Use a decorator for the safe write in State

This is a little cleaner to my eyes
This commit is contained in:
Marc Di Luzio 2024-08-16 22:53:10 +01:00
parent 37e1e7a7ae
commit ef4dd5c571

View file

@ -7,7 +7,7 @@ from typing import Protocol
import matchy.files.ops as ops
import copy
import logging
from contextlib import contextmanager
from functools import wraps
logger = logging.getLogger("state")
logger.setLevel(logging.INFO)
@ -182,6 +182,24 @@ class State():
dict = self._dict
_SCHEMA.validate(dict)
@staticmethod
def safe_write(func):
"""
Wraps any function running it first on some temporary state
Validates the resulting state and only then attempts to save it out
before storing the dict back in the State
"""
@wraps(func)
def inner(self: State, *args, **kwargs):
tmp = State(self._dict, self._file)
func(tmp, *args, **kwargs)
tmp.validate()
if tmp._file:
tmp._save_to_file()
self._dict = tmp._dict
return inner
def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
"""Grab all timestamps in the history"""
others = [m.id for m in users]
@ -202,24 +220,24 @@ class State():
def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
@safe_write
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups"""
ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap_write() as safe_state:
for group in groups:
# Update the matchee data with the matches
for m in group:
matchee = safe_state._users.setdefault(str(m.id), {})
matchee = self._users.setdefault(str(m.id), {})
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
@safe_write
def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user"""
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
user = self._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, [])
# Set the value
@ -255,10 +273,10 @@ class State():
self._set_user_channel_prop(
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
@safe_write
def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap_write() as safe_state:
for user in safe_state._users.values():
for user in self._users.values():
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
if channel and not channel[_Key.ACTIVE]:
@ -295,10 +313,10 @@ class State():
for task in tasks:
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
@safe_write
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel"""
with self._safe_wrap_write() as safe_state:
channel = safe_state._tasks.setdefault(str(channel_id), {})
channel = self._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, [])
found = False
@ -340,34 +358,17 @@ class State():
def _tasks(self) -> dict[str]:
return self._dict[_Key.TASKS]
@safe_write
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper"""
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
user = self._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {})
channel = channels.setdefault(str(channel_id), {})
# Set the value
channel[key] = value
# TODO: Make this a decorator?
@contextmanager
def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
try:
yield tmp_state
finally:
# Validate and then overwrite our dict with the new one
tmp_state.validate()
self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
def _save_to_file(self):
"""Saves the state out to the chosen file"""
ops.save(self._file, self.dict_internal_copy)