More protection - State does it's own saving

This commit is contained in:
Marc Di Luzio 2024-08-13 23:43:15 +01:00
parent cbea7abca2
commit 57f65b265c
5 changed files with 27 additions and 29 deletions

View file

@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time
import cogs.match_button as match_button
import matching
from state import State, save_to_file, AuthScope
from state import State, AuthScope
import util
logger = logging.getLogger("cog")
@ -38,7 +38,6 @@ class MatchyCog(commands.Cog):
self.state.set_user_active_in_channel(
interaction.user.id, interaction.channel.id)
save_to_file(self.state)
await interaction.response.send_message(
f"Roger roger {interaction.user.mention}!\n"
+ f"Added you to {interaction.channel.mention}!",
@ -52,7 +51,6 @@ class MatchyCog(commands.Cog):
self.state.set_user_active_in_channel(
interaction.user.id, interaction.channel.id, False)
save_to_file(self.state)
await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@ -68,7 +66,6 @@ class MatchyCog(commands.Cog):
until = datetime.now() + timedelta(days=days)
self.state.set_user_paused_in_channel(
interaction.user.id, interaction.channel.id, until)
save_to_file(self.state)
await interaction.response.send_message(
f"Sure thing {interaction.user.mention}!\n"
+ f"Paused you until {util.format_day(until)}!",
@ -127,7 +124,6 @@ class MatchyCog(commands.Cog):
# Add the scheduled task and save
success = self.state.set_channel_match_task(
channel_id, members_min, weekday, hour, not cancel)
save_to_file(self.state)
# Let the user know what happened
if not cancel:

View file

@ -3,7 +3,7 @@ Owner bot cog
"""
import logging
from discord.ext import commands
from state import State, AuthScope, save_to_file
from state import State, AuthScope
logger = logging.getLogger("owner")
logger.setLevel(logging.INFO)
@ -49,7 +49,6 @@ class OwnerCog(commands.Cog):
"""
if user.isdigit():
self._state.set_user_scope(str(user), AuthScope.MATCHER)
save_to_file(self._state)
logger.info("Granting user %s matcher scope", user)
await ctx.reply("Done!", ephemeral=True)
else:

View file

@ -3,7 +3,7 @@ import logging
import discord
from datetime import datetime, timedelta
from typing import Protocol, runtime_checkable
from state import State, save_to_file, ts_to_datetime
from state import State, ts_to_datetime
import util
import config
@ -224,7 +224,6 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
# Save the groups to the history
state.log_groups(groups)
save_to_file(state)
logger.info("Done! Matched into %s groups.", len(groups))

View file

@ -174,10 +174,11 @@ def datetime_to_ts(ts: datetime) -> str:
class State():
def __init__(self, data: dict):
def __init__(self, data: dict, file: str | None = None):
"""Initialise and validate the state"""
self.validate(data)
self._dict = copy.deepcopy(data)
self._file = file
def validate(self, dict: dict = None):
"""Initialise and validate a state dict"""
@ -208,7 +209,7 @@ class State():
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups"""
ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
for group in groups:
# Update the matchee data with the matches
for m in group:
@ -220,7 +221,7 @@ class State():
def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, [])
@ -260,7 +261,7 @@ class State():
def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
@ -300,7 +301,7 @@ class State():
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
channel = safe_state._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, [])
@ -345,7 +346,7 @@ class State():
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {})
@ -355,7 +356,7 @@ class State():
channel[key] = value
@contextmanager
def _safe_wrap(self):
def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
@ -366,6 +367,14 @@ class State():
tmp_state.validate()
self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
def _save_to_file(self):
"""Saves the state out to the chosen file"""
files.save(self._file, self.dict_internal_copy)
def _migrate(dict: dict):
"""Migrate a dict through versions"""
@ -378,7 +387,7 @@ def _migrate(dict: dict):
def load_from_file(file: str = _STATE_FILE) -> State:
"""
Load the state from a file
Load the state from a files
Apply any required migrations
"""
loaded = _EMPTY_DICT
@ -388,14 +397,9 @@ def load_from_file(file: str = _STATE_FILE) -> State:
loaded = files.load(file)
_migrate(loaded)
st = State(loaded)
st = State(loaded, file)
# Save out the migrated (or new) file
files.save(file, st._dict)
return st
def save_to_file(state: State, file: str = _STATE_FILE):
"""Saves the state out to a file"""
files.save(file, state.dict_internal_copy)

View file

@ -18,10 +18,10 @@ def test_simple_load_reload():
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
st._save_to_file()
st = state.load_from_file(path)
state.save_to_file(st, path)
st._save_to_file()
st = state.load_from_file(path)
@ -30,13 +30,13 @@ def test_authscope():
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
st._save_to_file()
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
st = state.load_from_file(path)
st.set_user_scope(1, state.AuthScope.MATCHER)
state.save_to_file(st, path)
st._save_to_file()
st = state.load_from_file(path)
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
@ -50,13 +50,13 @@ def test_channeljoin():
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
st._save_to_file()
assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path)
st.set_user_active_in_channel(1, "2", True)
state.save_to_file(st, path)
st._save_to_file()
st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2")