diff --git a/matchy/files/state.py b/matchy/files/state.py index 3e36c1f..3b2201b 100644 --- a/matchy/files/state.py +++ b/matchy/files/state.py @@ -7,7 +7,7 @@ from typing import Protocol import matchy.files.ops as ops import copy import logging -from contextlib import contextmanager +from functools import wraps logger = logging.getLogger("state") logger.setLevel(logging.INFO) @@ -182,6 +182,24 @@ class State(): dict = self._dict _SCHEMA.validate(dict) + @staticmethod + def safe_write(func): + """ + Wraps any function running it first on some temporary state + Validates the resulting state and only then attempts to save it out + before storing the dict back in the State + """ + @wraps(func) + def inner(self: State, *args, **kwargs): + tmp = State(self._dict, self._file) + func(tmp, *args, **kwargs) + tmp.validate() + if tmp._file: + tmp._save_to_file() + self._dict = tmp._dict + + return inner + def get_history_timestamps(self, users: list[Member]) -> list[datetime]: """Grab all timestamps in the history""" others = [m.id for m in users] @@ -202,31 +220,31 @@ class State(): def get_user_matches(self, id: int) -> list[int]: return self._users.get(str(id), {}).get(_Key.MATCHES, {}) + @safe_write def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: """Log the groups""" ts = datetime_to_ts(ts or datetime.now()) - with self._safe_wrap_write() as safe_state: - for group in groups: - # Update the matchee data with the matches - for m in group: - matchee = safe_state._users.setdefault(str(m.id), {}) - matchee_matches = matchee.setdefault(_Key.MATCHES, {}) + for group in groups: + # Update the matchee data with the matches + for m in group: + matchee = self._users.setdefault(str(m.id), {}) + matchee_matches = matchee.setdefault(_Key.MATCHES, {}) - for o in (o for o in group if o.id != m.id): - matchee_matches[str(o.id)] = ts + for o in (o for o in group if o.id != m.id): + matchee_matches[str(o.id)] = ts + @safe_write def set_user_scope(self, id: str, scope: str, value: bool = True): """Add an auth scope to a user""" - with self._safe_wrap_write() as safe_state: - # Dive in - user = safe_state._users.setdefault(str(id), {}) - scopes = user.setdefault(_Key.SCOPES, []) + # Dive in + user = self._users.setdefault(str(id), {}) + scopes = user.setdefault(_Key.SCOPES, []) - # Set the value - if value and scope not in scopes: - scopes.append(scope) - elif not value and scope in scopes: - scopes.remove(scope) + # Set the value + if value and scope not in scopes: + scopes.append(scope) + elif not value and scope in scopes: + scopes.remove(scope) def get_user_has_scope(self, id: str, scope: str) -> bool: """ @@ -255,17 +273,17 @@ class State(): self._set_user_channel_prop( id, channel_id, _Key.REACTIVATE, datetime_to_ts(until)) + @safe_write def reactivate_users(self, channel_id: str): """Reactivate any users who've passed their reactivation time on this channel""" - with self._safe_wrap_write() as safe_state: - for user in safe_state._users.values(): - channels = user.get(_Key.CHANNELS, {}) - channel = channels.get(str(channel_id), {}) - if channel and not channel[_Key.ACTIVE]: - reactivate = channel.get(_Key.REACTIVATE, None) - # Check if we've gone past the reactivation time and re-activate - if reactivate and datetime.now() > ts_to_datetime(reactivate): - channel[_Key.ACTIVE] = True + for user in self._users.values(): + channels = user.get(_Key.CHANNELS, {}) + channel = channels.get(str(channel_id), {}) + if channel and not channel[_Key.ACTIVE]: + reactivate = channel.get(_Key.REACTIVATE, None) + # Check if we've gone past the reactivation time and re-activate + if reactivate and datetime.now() > ts_to_datetime(reactivate): + channel[_Key.ACTIVE] = True def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]: """ @@ -295,37 +313,37 @@ class State(): for task in tasks: yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN]) + @safe_write def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: """Set up a match task on a channel""" - with self._safe_wrap_write() as safe_state: - channel = safe_state._tasks.setdefault(str(channel_id), {}) - matches = channel.setdefault(_Key.MATCH_TASKS, []) + channel = self._tasks.setdefault(str(channel_id), {}) + matches = channel.setdefault(_Key.MATCH_TASKS, []) - found = False - for match in matches: - # Specifically check for the combination of weekday and hour - if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour: - found = True - if set: - match[_Key.MEMBERS_MIN] = members_min - else: - matches.remove(match) - - # Return true as we've successfully changed the data in place - return True - - # If we didn't find it, add it to the schedule - if not found and set: - matches.append({ - _Key.MEMBERS_MIN: members_min, - _Key.WEEKDAY: weekday, - _Key.HOUR: hour, - }) + found = False + for match in matches: + # Specifically check for the combination of weekday and hour + if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour: + found = True + if set: + match[_Key.MEMBERS_MIN] = members_min + else: + matches.remove(match) + # Return true as we've successfully changed the data in place return True - # We did not manage to remove the schedule (or add it? though that should be impossible) - return False + # If we didn't find it, add it to the schedule + if not found and set: + matches.append({ + _Key.MEMBERS_MIN: members_min, + _Key.WEEKDAY: weekday, + _Key.HOUR: hour, + }) + + return True + + # We did not manage to remove the schedule (or add it? though that should be impossible) + return False @property def dict_internal_copy(self) -> dict: @@ -340,33 +358,16 @@ class State(): def _tasks(self) -> dict[str]: return self._dict[_Key.TASKS] + @safe_write def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): """Set a user channel property helper""" - with self._safe_wrap_write() as safe_state: - # Dive in - user = safe_state._users.setdefault(str(id), {}) - channels = user.setdefault(_Key.CHANNELS, {}) - channel = channels.setdefault(str(channel_id), {}) + # Dive in + user = self._users.setdefault(str(id), {}) + channels = user.setdefault(_Key.CHANNELS, {}) + channel = channels.setdefault(str(channel_id), {}) - # Set the value - channel[key] = value - - # TODO: Make this a decorator? - @contextmanager - def _safe_wrap_write(self): - """Safely run any function wrapped in a validate""" - # Wrap in a temporary state to validate first to prevent corruption - tmp_state = State(self._dict) - try: - yield tmp_state - finally: - # Validate and then overwrite our dict with the new one - tmp_state.validate() - self._dict = tmp_state._dict - - # Write this change out if we have a file - if self._file: - self._save_to_file() + # Set the value + channel[key] = value def _save_to_file(self): """Saves the state out to the chosen file"""