Use a decorator for the safe write in State
This is a little cleaner to my eyes
This commit is contained in:
parent
37e1e7a7ae
commit
ef4dd5c571
1 changed files with 77 additions and 76 deletions
|
@ -7,7 +7,7 @@ from typing import Protocol
|
||||||
import matchy.files.ops as ops
|
import matchy.files.ops as ops
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from functools import wraps
|
||||||
|
|
||||||
logger = logging.getLogger("state")
|
logger = logging.getLogger("state")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
@ -182,6 +182,24 @@ class State():
|
||||||
dict = self._dict
|
dict = self._dict
|
||||||
_SCHEMA.validate(dict)
|
_SCHEMA.validate(dict)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def safe_write(func):
|
||||||
|
"""
|
||||||
|
Wraps any function running it first on some temporary state
|
||||||
|
Validates the resulting state and only then attempts to save it out
|
||||||
|
before storing the dict back in the State
|
||||||
|
"""
|
||||||
|
@wraps(func)
|
||||||
|
def inner(self: State, *args, **kwargs):
|
||||||
|
tmp = State(self._dict, self._file)
|
||||||
|
func(tmp, *args, **kwargs)
|
||||||
|
tmp.validate()
|
||||||
|
if tmp._file:
|
||||||
|
tmp._save_to_file()
|
||||||
|
self._dict = tmp._dict
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
|
def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
|
||||||
"""Grab all timestamps in the history"""
|
"""Grab all timestamps in the history"""
|
||||||
others = [m.id for m in users]
|
others = [m.id for m in users]
|
||||||
|
@ -202,31 +220,31 @@ class State():
|
||||||
def get_user_matches(self, id: int) -> list[int]:
|
def get_user_matches(self, id: int) -> list[int]:
|
||||||
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
|
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
|
||||||
|
|
||||||
|
@safe_write
|
||||||
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
|
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
|
||||||
"""Log the groups"""
|
"""Log the groups"""
|
||||||
ts = datetime_to_ts(ts or datetime.now())
|
ts = datetime_to_ts(ts or datetime.now())
|
||||||
with self._safe_wrap_write() as safe_state:
|
for group in groups:
|
||||||
for group in groups:
|
# Update the matchee data with the matches
|
||||||
# Update the matchee data with the matches
|
for m in group:
|
||||||
for m in group:
|
matchee = self._users.setdefault(str(m.id), {})
|
||||||
matchee = safe_state._users.setdefault(str(m.id), {})
|
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
|
||||||
matchee_matches = matchee.setdefault(_Key.MATCHES, {})
|
|
||||||
|
|
||||||
for o in (o for o in group if o.id != m.id):
|
for o in (o for o in group if o.id != m.id):
|
||||||
matchee_matches[str(o.id)] = ts
|
matchee_matches[str(o.id)] = ts
|
||||||
|
|
||||||
|
@safe_write
|
||||||
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||||
"""Add an auth scope to a user"""
|
"""Add an auth scope to a user"""
|
||||||
with self._safe_wrap_write() as safe_state:
|
# Dive in
|
||||||
# Dive in
|
user = self._users.setdefault(str(id), {})
|
||||||
user = safe_state._users.setdefault(str(id), {})
|
scopes = user.setdefault(_Key.SCOPES, [])
|
||||||
scopes = user.setdefault(_Key.SCOPES, [])
|
|
||||||
|
|
||||||
# Set the value
|
# Set the value
|
||||||
if value and scope not in scopes:
|
if value and scope not in scopes:
|
||||||
scopes.append(scope)
|
scopes.append(scope)
|
||||||
elif not value and scope in scopes:
|
elif not value and scope in scopes:
|
||||||
scopes.remove(scope)
|
scopes.remove(scope)
|
||||||
|
|
||||||
def get_user_has_scope(self, id: str, scope: str) -> bool:
|
def get_user_has_scope(self, id: str, scope: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -255,17 +273,17 @@ class State():
|
||||||
self._set_user_channel_prop(
|
self._set_user_channel_prop(
|
||||||
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
|
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
|
||||||
|
|
||||||
|
@safe_write
|
||||||
def reactivate_users(self, channel_id: str):
|
def reactivate_users(self, channel_id: str):
|
||||||
"""Reactivate any users who've passed their reactivation time on this channel"""
|
"""Reactivate any users who've passed their reactivation time on this channel"""
|
||||||
with self._safe_wrap_write() as safe_state:
|
for user in self._users.values():
|
||||||
for user in safe_state._users.values():
|
channels = user.get(_Key.CHANNELS, {})
|
||||||
channels = user.get(_Key.CHANNELS, {})
|
channel = channels.get(str(channel_id), {})
|
||||||
channel = channels.get(str(channel_id), {})
|
if channel and not channel[_Key.ACTIVE]:
|
||||||
if channel and not channel[_Key.ACTIVE]:
|
reactivate = channel.get(_Key.REACTIVATE, None)
|
||||||
reactivate = channel.get(_Key.REACTIVATE, None)
|
# Check if we've gone past the reactivation time and re-activate
|
||||||
# Check if we've gone past the reactivation time and re-activate
|
if reactivate and datetime.now() > ts_to_datetime(reactivate):
|
||||||
if reactivate and datetime.now() > ts_to_datetime(reactivate):
|
channel[_Key.ACTIVE] = True
|
||||||
channel[_Key.ACTIVE] = True
|
|
||||||
|
|
||||||
def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]:
|
def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]:
|
||||||
"""
|
"""
|
||||||
|
@ -295,37 +313,37 @@ class State():
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
|
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
|
||||||
|
|
||||||
|
@safe_write
|
||||||
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
|
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
|
||||||
"""Set up a match task on a channel"""
|
"""Set up a match task on a channel"""
|
||||||
with self._safe_wrap_write() as safe_state:
|
channel = self._tasks.setdefault(str(channel_id), {})
|
||||||
channel = safe_state._tasks.setdefault(str(channel_id), {})
|
matches = channel.setdefault(_Key.MATCH_TASKS, [])
|
||||||
matches = channel.setdefault(_Key.MATCH_TASKS, [])
|
|
||||||
|
|
||||||
found = False
|
found = False
|
||||||
for match in matches:
|
for match in matches:
|
||||||
# Specifically check for the combination of weekday and hour
|
# Specifically check for the combination of weekday and hour
|
||||||
if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour:
|
if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour:
|
||||||
found = True
|
found = True
|
||||||
if set:
|
if set:
|
||||||
match[_Key.MEMBERS_MIN] = members_min
|
match[_Key.MEMBERS_MIN] = members_min
|
||||||
else:
|
else:
|
||||||
matches.remove(match)
|
matches.remove(match)
|
||||||
|
|
||||||
# Return true as we've successfully changed the data in place
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If we didn't find it, add it to the schedule
|
|
||||||
if not found and set:
|
|
||||||
matches.append({
|
|
||||||
_Key.MEMBERS_MIN: members_min,
|
|
||||||
_Key.WEEKDAY: weekday,
|
|
||||||
_Key.HOUR: hour,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
# Return true as we've successfully changed the data in place
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# We did not manage to remove the schedule (or add it? though that should be impossible)
|
# If we didn't find it, add it to the schedule
|
||||||
return False
|
if not found and set:
|
||||||
|
matches.append({
|
||||||
|
_Key.MEMBERS_MIN: members_min,
|
||||||
|
_Key.WEEKDAY: weekday,
|
||||||
|
_Key.HOUR: hour,
|
||||||
|
})
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# We did not manage to remove the schedule (or add it? though that should be impossible)
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dict_internal_copy(self) -> dict:
|
def dict_internal_copy(self) -> dict:
|
||||||
|
@ -340,33 +358,16 @@ class State():
|
||||||
def _tasks(self) -> dict[str]:
|
def _tasks(self) -> dict[str]:
|
||||||
return self._dict[_Key.TASKS]
|
return self._dict[_Key.TASKS]
|
||||||
|
|
||||||
|
@safe_write
|
||||||
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
|
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
|
||||||
"""Set a user channel property helper"""
|
"""Set a user channel property helper"""
|
||||||
with self._safe_wrap_write() as safe_state:
|
# Dive in
|
||||||
# Dive in
|
user = self._users.setdefault(str(id), {})
|
||||||
user = safe_state._users.setdefault(str(id), {})
|
channels = user.setdefault(_Key.CHANNELS, {})
|
||||||
channels = user.setdefault(_Key.CHANNELS, {})
|
channel = channels.setdefault(str(channel_id), {})
|
||||||
channel = channels.setdefault(str(channel_id), {})
|
|
||||||
|
|
||||||
# Set the value
|
# Set the value
|
||||||
channel[key] = value
|
channel[key] = value
|
||||||
|
|
||||||
# TODO: Make this a decorator?
|
|
||||||
@contextmanager
|
|
||||||
def _safe_wrap_write(self):
|
|
||||||
"""Safely run any function wrapped in a validate"""
|
|
||||||
# Wrap in a temporary state to validate first to prevent corruption
|
|
||||||
tmp_state = State(self._dict)
|
|
||||||
try:
|
|
||||||
yield tmp_state
|
|
||||||
finally:
|
|
||||||
# Validate and then overwrite our dict with the new one
|
|
||||||
tmp_state.validate()
|
|
||||||
self._dict = tmp_state._dict
|
|
||||||
|
|
||||||
# Write this change out if we have a file
|
|
||||||
if self._file:
|
|
||||||
self._save_to_file()
|
|
||||||
|
|
||||||
def _save_to_file(self):
|
def _save_to_file(self):
|
||||||
"""Saves the state out to the chosen file"""
|
"""Saves the state out to the chosen file"""
|
||||||
|
|
Loading…
Add table
Reference in a new issue