Use a decorator for the safe write in State

This is a little cleaner to my eyes
This commit is contained in:
Marc Di Luzio 2024-08-16 22:53:10 +01:00
parent 37e1e7a7ae
commit ef4dd5c571

View file

@ -7,7 +7,7 @@ from typing import Protocol
import matchy.files.ops as ops
import copy
import logging
from contextlib import contextmanager
from functools import wraps
logger = logging.getLogger("state")
logger.setLevel(logging.INFO)
@ -182,6 +182,24 @@ class State():
dict = self._dict
_SCHEMA.validate(dict)
@staticmethod
def safe_write(func):
"""
Wraps any function running it first on some temporary state
Validates the resulting state and only then attempts to save it out
before storing the dict back in the State
"""
@wraps(func)
def inner(self: State, *args, **kwargs):
tmp = State(self._dict, self._file)
func(tmp, *args, **kwargs)
tmp.validate()
if tmp._file:
tmp._save_to_file()
self._dict = tmp._dict
return inner
def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
"""Grab all timestamps in the history"""
others = [m.id for m in users]
@ -202,31 +220,31 @@ class State():
def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
@safe_write
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups"""
ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap_write() as safe_state:
for group in groups:
# Update the matchee data with the matches
for m in group:
matchee = safe_state._users.setdefault(str(m.id), {})
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
for group in groups:
# Update the matchee data with the matches
for m in group:
matchee = self._users.setdefault(str(m.id), {})
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
@safe_write
def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user"""
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, [])
# Dive in
user = self._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, [])
# Set the value
if value and scope not in scopes:
scopes.append(scope)
elif not value and scope in scopes:
scopes.remove(scope)
# Set the value
if value and scope not in scopes:
scopes.append(scope)
elif not value and scope in scopes:
scopes.remove(scope)
def get_user_has_scope(self, id: str, scope: str) -> bool:
"""
@ -255,17 +273,17 @@ class State():
self._set_user_channel_prop(
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
@safe_write
def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap_write() as safe_state:
for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
if channel and not channel[_Key.ACTIVE]:
reactivate = channel.get(_Key.REACTIVATE, None)
# Check if we've gone past the reactivation time and re-activate
if reactivate and datetime.now() > ts_to_datetime(reactivate):
channel[_Key.ACTIVE] = True
for user in self._users.values():
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
if channel and not channel[_Key.ACTIVE]:
reactivate = channel.get(_Key.REACTIVATE, None)
# Check if we've gone past the reactivation time and re-activate
if reactivate and datetime.now() > ts_to_datetime(reactivate):
channel[_Key.ACTIVE] = True
def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]:
"""
@ -295,37 +313,37 @@ class State():
for task in tasks:
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
@safe_write
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel"""
with self._safe_wrap_write() as safe_state:
channel = safe_state._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, [])
channel = self._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, [])
found = False
for match in matches:
# Specifically check for the combination of weekday and hour
if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour:
found = True
if set:
match[_Key.MEMBERS_MIN] = members_min
else:
matches.remove(match)
# Return true as we've successfully changed the data in place
return True
# If we didn't find it, add it to the schedule
if not found and set:
matches.append({
_Key.MEMBERS_MIN: members_min,
_Key.WEEKDAY: weekday,
_Key.HOUR: hour,
})
found = False
for match in matches:
# Specifically check for the combination of weekday and hour
if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour:
found = True
if set:
match[_Key.MEMBERS_MIN] = members_min
else:
matches.remove(match)
# Return true as we've successfully changed the data in place
return True
# We did not manage to remove the schedule (or add it? though that should be impossible)
return False
# If we didn't find it, add it to the schedule
if not found and set:
matches.append({
_Key.MEMBERS_MIN: members_min,
_Key.WEEKDAY: weekday,
_Key.HOUR: hour,
})
return True
# We did not manage to remove the schedule (or add it? though that should be impossible)
return False
@property
def dict_internal_copy(self) -> dict:
@ -340,33 +358,16 @@ class State():
def _tasks(self) -> dict[str]:
return self._dict[_Key.TASKS]
@safe_write
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper"""
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {})
channel = channels.setdefault(str(channel_id), {})
# Dive in
user = self._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {})
channel = channels.setdefault(str(channel_id), {})
# Set the value
channel[key] = value
# TODO: Make this a decorator?
@contextmanager
def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
try:
yield tmp_state
finally:
# Validate and then overwrite our dict with the new one
tmp_state.validate()
self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
# Set the value
channel[key] = value
def _save_to_file(self):
"""Saves the state out to the chosen file"""