Use a decorator for the safe write in State
This is a little cleaner to my eyes
This commit is contained in:
parent
37e1e7a7ae
commit
ef4dd5c571
1 changed files with 77 additions and 76 deletions
|
@ -7,7 +7,7 @@ from typing import Protocol
|
|||
import matchy.files.ops as ops
|
||||
import copy
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
logger = logging.getLogger("state")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -182,6 +182,24 @@ class State():
|
|||
dict = self._dict
|
||||
_SCHEMA.validate(dict)
|
||||
|
||||
@staticmethod
|
||||
def safe_write(func):
|
||||
"""
|
||||
Wraps any function running it first on some temporary state
|
||||
Validates the resulting state and only then attempts to save it out
|
||||
before storing the dict back in the State
|
||||
"""
|
||||
@wraps(func)
|
||||
def inner(self: State, *args, **kwargs):
|
||||
tmp = State(self._dict, self._file)
|
||||
func(tmp, *args, **kwargs)
|
||||
tmp.validate()
|
||||
if tmp._file:
|
||||
tmp._save_to_file()
|
||||
self._dict = tmp._dict
|
||||
|
||||
return inner
|
||||
|
||||
def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
|
||||
"""Grab all timestamps in the history"""
|
||||
others = [m.id for m in users]
|
||||
|
@ -202,31 +220,31 @@ class State():
|
|||
def get_user_matches(self, id: int) -> list[int]:
|
||||
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
|
||||
|
||||
@safe_write
|
||||
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
|
||||
"""Log the groups"""
|
||||
ts = datetime_to_ts(ts or datetime.now())
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
for group in groups:
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
matchee = safe_state._users.setdefault(str(m.id), {})
|
||||
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
|
||||
for group in groups:
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
matchee = self._users.setdefault(str(m.id), {})
|
||||
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
|
||||
|
||||
for o in (o for o in group if o.id != m.id):
|
||||
matchee_matches[str(o.id)] = ts
|
||||
for o in (o for o in group if o.id != m.id):
|
||||
matchee_matches[str(o.id)] = ts
|
||||
|
||||
@safe_write
|
||||
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||
"""Add an auth scope to a user"""
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.setdefault(str(id), {})
|
||||
scopes = user.setdefault(_Key.SCOPES, [])
|
||||
# Dive in
|
||||
user = self._users.setdefault(str(id), {})
|
||||
scopes = user.setdefault(_Key.SCOPES, [])
|
||||
|
||||
# Set the value
|
||||
if value and scope not in scopes:
|
||||
scopes.append(scope)
|
||||
elif not value and scope in scopes:
|
||||
scopes.remove(scope)
|
||||
# Set the value
|
||||
if value and scope not in scopes:
|
||||
scopes.append(scope)
|
||||
elif not value and scope in scopes:
|
||||
scopes.remove(scope)
|
||||
|
||||
def get_user_has_scope(self, id: str, scope: str) -> bool:
|
||||
"""
|
||||
|
@ -255,17 +273,17 @@ class State():
|
|||
self._set_user_channel_prop(
|
||||
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
|
||||
|
||||
@safe_write
|
||||
def reactivate_users(self, channel_id: str):
|
||||
"""Reactivate any users who've passed their reactivation time on this channel"""
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
for user in safe_state._users.values():
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
if channel and not channel[_Key.ACTIVE]:
|
||||
reactivate = channel.get(_Key.REACTIVATE, None)
|
||||
# Check if we've gone past the reactivation time and re-activate
|
||||
if reactivate and datetime.now() > ts_to_datetime(reactivate):
|
||||
channel[_Key.ACTIVE] = True
|
||||
for user in self._users.values():
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
if channel and not channel[_Key.ACTIVE]:
|
||||
reactivate = channel.get(_Key.REACTIVATE, None)
|
||||
# Check if we've gone past the reactivation time and re-activate
|
||||
if reactivate and datetime.now() > ts_to_datetime(reactivate):
|
||||
channel[_Key.ACTIVE] = True
|
||||
|
||||
def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]:
|
||||
"""
|
||||
|
@ -295,37 +313,37 @@ class State():
|
|||
for task in tasks:
|
||||
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
|
||||
|
||||
@safe_write
|
||||
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
|
||||
"""Set up a match task on a channel"""
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
channel = safe_state._tasks.setdefault(str(channel_id), {})
|
||||
matches = channel.setdefault(_Key.MATCH_TASKS, [])
|
||||
channel = self._tasks.setdefault(str(channel_id), {})
|
||||
matches = channel.setdefault(_Key.MATCH_TASKS, [])
|
||||
|
||||
found = False
|
||||
for match in matches:
|
||||
# Specifically check for the combination of weekday and hour
|
||||
if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour:
|
||||
found = True
|
||||
if set:
|
||||
match[_Key.MEMBERS_MIN] = members_min
|
||||
else:
|
||||
matches.remove(match)
|
||||
|
||||
# Return true as we've successfully changed the data in place
|
||||
return True
|
||||
|
||||
# If we didn't find it, add it to the schedule
|
||||
if not found and set:
|
||||
matches.append({
|
||||
_Key.MEMBERS_MIN: members_min,
|
||||
_Key.WEEKDAY: weekday,
|
||||
_Key.HOUR: hour,
|
||||
})
|
||||
found = False
|
||||
for match in matches:
|
||||
# Specifically check for the combination of weekday and hour
|
||||
if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour:
|
||||
found = True
|
||||
if set:
|
||||
match[_Key.MEMBERS_MIN] = members_min
|
||||
else:
|
||||
matches.remove(match)
|
||||
|
||||
# Return true as we've successfully changed the data in place
|
||||
return True
|
||||
|
||||
# We did not manage to remove the schedule (or add it? though that should be impossible)
|
||||
return False
|
||||
# If we didn't find it, add it to the schedule
|
||||
if not found and set:
|
||||
matches.append({
|
||||
_Key.MEMBERS_MIN: members_min,
|
||||
_Key.WEEKDAY: weekday,
|
||||
_Key.HOUR: hour,
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
# We did not manage to remove the schedule (or add it? though that should be impossible)
|
||||
return False
|
||||
|
||||
@property
|
||||
def dict_internal_copy(self) -> dict:
|
||||
|
@ -340,33 +358,16 @@ class State():
|
|||
def _tasks(self) -> dict[str]:
|
||||
return self._dict[_Key.TASKS]
|
||||
|
||||
@safe_write
|
||||
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
|
||||
"""Set a user channel property helper"""
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.setdefault(str(id), {})
|
||||
channels = user.setdefault(_Key.CHANNELS, {})
|
||||
channel = channels.setdefault(str(channel_id), {})
|
||||
# Dive in
|
||||
user = self._users.setdefault(str(id), {})
|
||||
channels = user.setdefault(_Key.CHANNELS, {})
|
||||
channel = channels.setdefault(str(channel_id), {})
|
||||
|
||||
# Set the value
|
||||
channel[key] = value
|
||||
|
||||
# TODO: Make this a decorator?
|
||||
@contextmanager
|
||||
def _safe_wrap_write(self):
|
||||
"""Safely run any function wrapped in a validate"""
|
||||
# Wrap in a temporary state to validate first to prevent corruption
|
||||
tmp_state = State(self._dict)
|
||||
try:
|
||||
yield tmp_state
|
||||
finally:
|
||||
# Validate and then overwrite our dict with the new one
|
||||
tmp_state.validate()
|
||||
self._dict = tmp_state._dict
|
||||
|
||||
# Write this change out if we have a file
|
||||
if self._file:
|
||||
self._save_to_file()
|
||||
# Set the value
|
||||
channel[key] = value
|
||||
|
||||
def _save_to_file(self):
|
||||
"""Saves the state out to the chosen file"""
|
||||
|
|
Loading…
Add table
Reference in a new issue