Use a decorator for the safe write in State

This is a little cleaner to my eyes
This commit is contained in:
Marc Di Luzio 2024-08-16 22:53:10 +01:00
parent 37e1e7a7ae
commit ef4dd5c571

View file

@ -7,7 +7,7 @@ from typing import Protocol
import matchy.files.ops as ops import matchy.files.ops as ops
import copy import copy
import logging import logging
from contextlib import contextmanager from functools import wraps
logger = logging.getLogger("state") logger = logging.getLogger("state")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -182,6 +182,24 @@ class State():
dict = self._dict dict = self._dict
_SCHEMA.validate(dict) _SCHEMA.validate(dict)
@staticmethod
def safe_write(func):
"""
Wraps any function running it first on some temporary state
Validates the resulting state and only then attempts to save it out
before storing the dict back in the State
"""
@wraps(func)
def inner(self: State, *args, **kwargs):
tmp = State(self._dict, self._file)
func(tmp, *args, **kwargs)
tmp.validate()
if tmp._file:
tmp._save_to_file()
self._dict = tmp._dict
return inner
def get_history_timestamps(self, users: list[Member]) -> list[datetime]: def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
"""Grab all timestamps in the history""" """Grab all timestamps in the history"""
others = [m.id for m in users] others = [m.id for m in users]
@ -202,24 +220,24 @@ class State():
def get_user_matches(self, id: int) -> list[int]: def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {}) return self._users.get(str(id), {}).get(_Key.MATCHES, {})
@safe_write
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups""" """Log the groups"""
ts = datetime_to_ts(ts or datetime.now()) ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap_write() as safe_state:
for group in groups: for group in groups:
# Update the matchee data with the matches # Update the matchee data with the matches
for m in group: for m in group:
matchee = safe_state._users.setdefault(str(m.id), {}) matchee = self._users.setdefault(str(m.id), {})
matchee_matches = matchee.setdefault(_Key.MATCHES, {}) matchee_matches = matchee.setdefault(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id): for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts matchee_matches[str(o.id)] = ts
@safe_write
def set_user_scope(self, id: str, scope: str, value: bool = True): def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user""" """Add an auth scope to a user"""
with self._safe_wrap_write() as safe_state:
# Dive in # Dive in
user = safe_state._users.setdefault(str(id), {}) user = self._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, []) scopes = user.setdefault(_Key.SCOPES, [])
# Set the value # Set the value
@ -255,10 +273,10 @@ class State():
self._set_user_channel_prop( self._set_user_channel_prop(
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until)) id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
@safe_write
def reactivate_users(self, channel_id: str): def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel""" """Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap_write() as safe_state: for user in self._users.values():
for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {}) channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {}) channel = channels.get(str(channel_id), {})
if channel and not channel[_Key.ACTIVE]: if channel and not channel[_Key.ACTIVE]:
@ -295,10 +313,10 @@ class State():
for task in tasks: for task in tasks:
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN]) yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
@safe_write
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel""" """Set up a match task on a channel"""
with self._safe_wrap_write() as safe_state: channel = self._tasks.setdefault(str(channel_id), {})
channel = safe_state._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, []) matches = channel.setdefault(_Key.MATCH_TASKS, [])
found = False found = False
@ -340,34 +358,17 @@ class State():
def _tasks(self) -> dict[str]: def _tasks(self) -> dict[str]:
return self._dict[_Key.TASKS] return self._dict[_Key.TASKS]
@safe_write
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper""" """Set a user channel property helper"""
with self._safe_wrap_write() as safe_state:
# Dive in # Dive in
user = safe_state._users.setdefault(str(id), {}) user = self._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {}) channels = user.setdefault(_Key.CHANNELS, {})
channel = channels.setdefault(str(channel_id), {}) channel = channels.setdefault(str(channel_id), {})
# Set the value # Set the value
channel[key] = value channel[key] = value
# TODO: Make this a decorator?
@contextmanager
def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
try:
yield tmp_state
finally:
# Validate and then overwrite our dict with the new one
tmp_state.validate()
self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
def _save_to_file(self): def _save_to_file(self):
"""Saves the state out to the chosen file""" """Saves the state out to the chosen file"""
ops.save(self._file, self.dict_internal_copy) ops.save(self._file, self.dict_internal_copy)