From 57f65b265cb1683d6d7ace9c44e4e466128e3022 Mon Sep 17 00:00:00 2001 From: Marc Di Luzio Date: Tue, 13 Aug 2024 23:43:15 +0100 Subject: [PATCH] More protection - State does it's own saving --- py/cogs/matchy_cog.py | 6 +----- py/cogs/owner_cog.py | 3 +-- py/matching.py | 3 +-- py/state.py | 32 ++++++++++++++++++-------------- py/state_test.py | 12 ++++++------ 5 files changed, 27 insertions(+), 29 deletions(-) diff --git a/py/cogs/matchy_cog.py b/py/cogs/matchy_cog.py index b7dd0ee..95cc1c7 100644 --- a/py/cogs/matchy_cog.py +++ b/py/cogs/matchy_cog.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time import cogs.match_button as match_button import matching -from state import State, save_to_file, AuthScope +from state import State, AuthScope import util logger = logging.getLogger("cog") @@ -38,7 +38,6 @@ class MatchyCog(commands.Cog): self.state.set_user_active_in_channel( interaction.user.id, interaction.channel.id) - save_to_file(self.state) await interaction.response.send_message( f"Roger roger {interaction.user.mention}!\n" + f"Added you to {interaction.channel.mention}!", @@ -52,7 +51,6 @@ class MatchyCog(commands.Cog): self.state.set_user_active_in_channel( interaction.user.id, interaction.channel.id, False) - save_to_file(self.state) await interaction.response.send_message( f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) @@ -68,7 +66,6 @@ class MatchyCog(commands.Cog): until = datetime.now() + timedelta(days=days) self.state.set_user_paused_in_channel( interaction.user.id, interaction.channel.id, until) - save_to_file(self.state) await interaction.response.send_message( f"Sure thing {interaction.user.mention}!\n" + f"Paused you until {util.format_day(until)}!", @@ -127,7 +124,6 @@ class MatchyCog(commands.Cog): # Add the scheduled task and save success = self.state.set_channel_match_task( channel_id, members_min, weekday, hour, not cancel) - save_to_file(self.state) # Let the user know what happened if not cancel: diff --git a/py/cogs/owner_cog.py b/py/cogs/owner_cog.py index 2ea2008..8590d7e 100644 --- a/py/cogs/owner_cog.py +++ b/py/cogs/owner_cog.py @@ -3,7 +3,7 @@ Owner bot cog """ import logging from discord.ext import commands -from state import State, AuthScope, save_to_file +from state import State, AuthScope logger = logging.getLogger("owner") logger.setLevel(logging.INFO) @@ -49,7 +49,6 @@ class OwnerCog(commands.Cog): """ if user.isdigit(): self._state.set_user_scope(str(user), AuthScope.MATCHER) - save_to_file(self._state) logger.info("Granting user %s matcher scope", user) await ctx.reply("Done!", ephemeral=True) else: diff --git a/py/matching.py b/py/matching.py index a0ea046..1f3b485 100644 --- a/py/matching.py +++ b/py/matching.py @@ -3,7 +3,7 @@ import logging import discord from datetime import datetime, timedelta from typing import Protocol, runtime_checkable -from state import State, save_to_file, ts_to_datetime +from state import State, ts_to_datetime import util import config @@ -224,7 +224,6 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i # Save the groups to the history state.log_groups(groups) - save_to_file(state) logger.info("Done! Matched into %s groups.", len(groups)) diff --git a/py/state.py b/py/state.py index ba57513..7a161f8 100644 --- a/py/state.py +++ b/py/state.py @@ -174,10 +174,11 @@ def datetime_to_ts(ts: datetime) -> str: class State(): - def __init__(self, data: dict): + def __init__(self, data: dict, file: str | None = None): """Initialise and validate the state""" self.validate(data) self._dict = copy.deepcopy(data) + self._file = file def validate(self, dict: dict = None): """Initialise and validate a state dict""" @@ -208,7 +209,7 @@ class State(): def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: """Log the groups""" ts = datetime_to_ts(ts or datetime.now()) - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: for group in groups: # Update the matchee data with the matches for m in group: @@ -220,7 +221,7 @@ class State(): def set_user_scope(self, id: str, scope: str, value: bool = True): """Add an auth scope to a user""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: # Dive in user = safe_state._users.setdefault(str(id), {}) scopes = user.setdefault(_Key.SCOPES, []) @@ -260,7 +261,7 @@ class State(): def reactivate_users(self, channel_id: str): """Reactivate any users who've passed their reactivation time on this channel""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: for user in safe_state._users.values(): channels = user.get(_Key.CHANNELS, {}) channel = channels.get(str(channel_id), {}) @@ -300,7 +301,7 @@ class State(): def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: """Set up a match task on a channel""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: channel = safe_state._tasks.setdefault(str(channel_id), {}) matches = channel.setdefault(_Key.MATCH_TASKS, []) @@ -345,7 +346,7 @@ class State(): def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): """Set a user channel property helper""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: # Dive in user = safe_state._users.setdefault(str(id), {}) channels = user.setdefault(_Key.CHANNELS, {}) @@ -355,7 +356,7 @@ class State(): channel[key] = value @contextmanager - def _safe_wrap(self): + def _safe_wrap_write(self): """Safely run any function wrapped in a validate""" # Wrap in a temporary state to validate first to prevent corruption tmp_state = State(self._dict) @@ -366,6 +367,14 @@ class State(): tmp_state.validate() self._dict = tmp_state._dict + # Write this change out if we have a file + if self._file: + self._save_to_file() + + def _save_to_file(self): + """Saves the state out to the chosen file""" + files.save(self._file, self.dict_internal_copy) + def _migrate(dict: dict): """Migrate a dict through versions""" @@ -378,7 +387,7 @@ def _migrate(dict: dict): def load_from_file(file: str = _STATE_FILE) -> State: """ - Load the state from a file + Load the state from a files Apply any required migrations """ loaded = _EMPTY_DICT @@ -388,14 +397,9 @@ def load_from_file(file: str = _STATE_FILE) -> State: loaded = files.load(file) _migrate(loaded) - st = State(loaded) + st = State(loaded, file) # Save out the migrated (or new) file files.save(file, st._dict) return st - - -def save_to_file(state: State, file: str = _STATE_FILE): - """Saves the state out to a file""" - files.save(file, state.dict_internal_copy) diff --git a/py/state_test.py b/py/state_test.py index b79a426..51a7f4b 100644 --- a/py/state_test.py +++ b/py/state_test.py @@ -18,10 +18,10 @@ def test_simple_load_reload(): with tempfile.TemporaryDirectory() as tmp: path = os.path.join(tmp, 'tmp.json') st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) @@ -30,13 +30,13 @@ def test_authscope(): with tempfile.TemporaryDirectory() as tmp: path = os.path.join(tmp, 'tmp.json') st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() assert not st.get_user_has_scope(1, state.AuthScope.MATCHER) st = state.load_from_file(path) st.set_user_scope(1, state.AuthScope.MATCHER) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) assert st.get_user_has_scope(1, state.AuthScope.MATCHER) @@ -50,13 +50,13 @@ def test_channeljoin(): with tempfile.TemporaryDirectory() as tmp: path = os.path.join(tmp, 'tmp.json') st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() assert not st.get_user_active_in_channel(1, "2") st = state.load_from_file(path) st.set_user_active_in_channel(1, "2", True) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) assert st.get_user_active_in_channel(1, "2")