Use a decorator for the safe write in State

This is a little cleaner to my eyes
This commit is contained in:
Marc Di Luzio 2024-08-16 22:53:10 +01:00
parent 37e1e7a7ae
commit ef4dd5c571

View file

@ -7,7 +7,7 @@ from typing import Protocol
import matchy.files.ops as ops import matchy.files.ops as ops
import copy import copy
import logging import logging
from contextlib import contextmanager from functools import wraps
logger = logging.getLogger("state") logger = logging.getLogger("state")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -182,6 +182,24 @@ class State():
dict = self._dict dict = self._dict
_SCHEMA.validate(dict) _SCHEMA.validate(dict)
@staticmethod
def safe_write(func):
"""
Wraps any function running it first on some temporary state
Validates the resulting state and only then attempts to save it out
before storing the dict back in the State
"""
@wraps(func)
def inner(self: State, *args, **kwargs):
tmp = State(self._dict, self._file)
func(tmp, *args, **kwargs)
tmp.validate()
if tmp._file:
tmp._save_to_file()
self._dict = tmp._dict
return inner
def get_history_timestamps(self, users: list[Member]) -> list[datetime]: def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
"""Grab all timestamps in the history""" """Grab all timestamps in the history"""
others = [m.id for m in users] others = [m.id for m in users]
@ -202,31 +220,31 @@ class State():
def get_user_matches(self, id: int) -> list[int]: def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {}) return self._users.get(str(id), {}).get(_Key.MATCHES, {})
@safe_write
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups""" """Log the groups"""
ts = datetime_to_ts(ts or datetime.now()) ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap_write() as safe_state: for group in groups:
for group in groups: # Update the matchee data with the matches
# Update the matchee data with the matches for m in group:
for m in group: matchee = self._users.setdefault(str(m.id), {})
matchee = safe_state._users.setdefault(str(m.id), {}) matchee_matches = matchee.setdefault(_Key.MATCHES, {})
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id): for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts matchee_matches[str(o.id)] = ts
@safe_write
def set_user_scope(self, id: str, scope: str, value: bool = True): def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user""" """Add an auth scope to a user"""
with self._safe_wrap_write() as safe_state: # Dive in
# Dive in user = self._users.setdefault(str(id), {})
user = safe_state._users.setdefault(str(id), {}) scopes = user.setdefault(_Key.SCOPES, [])
scopes = user.setdefault(_Key.SCOPES, [])
# Set the value # Set the value
if value and scope not in scopes: if value and scope not in scopes:
scopes.append(scope) scopes.append(scope)
elif not value and scope in scopes: elif not value and scope in scopes:
scopes.remove(scope) scopes.remove(scope)
def get_user_has_scope(self, id: str, scope: str) -> bool: def get_user_has_scope(self, id: str, scope: str) -> bool:
""" """
@ -255,17 +273,17 @@ class State():
self._set_user_channel_prop( self._set_user_channel_prop(
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until)) id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
@safe_write
def reactivate_users(self, channel_id: str): def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel""" """Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap_write() as safe_state: for user in self._users.values():
for user in safe_state._users.values(): channels = user.get(_Key.CHANNELS, {})
channels = user.get(_Key.CHANNELS, {}) channel = channels.get(str(channel_id), {})
channel = channels.get(str(channel_id), {}) if channel and not channel[_Key.ACTIVE]:
if channel and not channel[_Key.ACTIVE]: reactivate = channel.get(_Key.REACTIVATE, None)
reactivate = channel.get(_Key.REACTIVATE, None) # Check if we've gone past the reactivation time and re-activate
# Check if we've gone past the reactivation time and re-activate if reactivate and datetime.now() > ts_to_datetime(reactivate):
if reactivate and datetime.now() > ts_to_datetime(reactivate): channel[_Key.ACTIVE] = True
channel[_Key.ACTIVE] = True
def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]: def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]:
""" """
@ -295,37 +313,37 @@ class State():
for task in tasks: for task in tasks:
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN]) yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
@safe_write
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel""" """Set up a match task on a channel"""
with self._safe_wrap_write() as safe_state: channel = self._tasks.setdefault(str(channel_id), {})
channel = safe_state._tasks.setdefault(str(channel_id), {}) matches = channel.setdefault(_Key.MATCH_TASKS, [])
matches = channel.setdefault(_Key.MATCH_TASKS, [])
found = False found = False
for match in matches: for match in matches:
# Specifically check for the combination of weekday and hour # Specifically check for the combination of weekday and hour
if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour: if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour:
found = True found = True
if set: if set:
match[_Key.MEMBERS_MIN] = members_min match[_Key.MEMBERS_MIN] = members_min
else: else:
matches.remove(match) matches.remove(match)
# Return true as we've successfully changed the data in place
return True
# If we didn't find it, add it to the schedule
if not found and set:
matches.append({
_Key.MEMBERS_MIN: members_min,
_Key.WEEKDAY: weekday,
_Key.HOUR: hour,
})
# Return true as we've successfully changed the data in place
return True return True
# We did not manage to remove the schedule (or add it? though that should be impossible) # If we didn't find it, add it to the schedule
return False if not found and set:
matches.append({
_Key.MEMBERS_MIN: members_min,
_Key.WEEKDAY: weekday,
_Key.HOUR: hour,
})
return True
# We did not manage to remove the schedule (or add it? though that should be impossible)
return False
@property @property
def dict_internal_copy(self) -> dict: def dict_internal_copy(self) -> dict:
@ -340,33 +358,16 @@ class State():
def _tasks(self) -> dict[str]: def _tasks(self) -> dict[str]:
return self._dict[_Key.TASKS] return self._dict[_Key.TASKS]
@safe_write
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper""" """Set a user channel property helper"""
with self._safe_wrap_write() as safe_state: # Dive in
# Dive in user = self._users.setdefault(str(id), {})
user = safe_state._users.setdefault(str(id), {}) channels = user.setdefault(_Key.CHANNELS, {})
channels = user.setdefault(_Key.CHANNELS, {}) channel = channels.setdefault(str(channel_id), {})
channel = channels.setdefault(str(channel_id), {})
# Set the value # Set the value
channel[key] = value channel[key] = value
# TODO: Make this a decorator?
@contextmanager
def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
try:
yield tmp_state
finally:
# Validate and then overwrite our dict with the new one
tmp_state.validate()
self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
def _save_to_file(self): def _save_to_file(self):
"""Saves the state out to the chosen file""" """Saves the state out to the chosen file"""