More protection - State does it's own saving

This commit is contained in:
Marc Di Luzio 2024-08-13 23:43:15 +01:00
parent cbea7abca2
commit 57f65b265c
5 changed files with 27 additions and 29 deletions

View file

@ -174,10 +174,11 @@ def datetime_to_ts(ts: datetime) -> str:
class State():
def __init__(self, data: dict):
def __init__(self, data: dict, file: str | None = None):
"""Initialise and validate the state"""
self.validate(data)
self._dict = copy.deepcopy(data)
self._file = file
def validate(self, dict: dict = None):
"""Initialise and validate a state dict"""
@ -208,7 +209,7 @@ class State():
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups"""
ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
for group in groups:
# Update the matchee data with the matches
for m in group:
@ -220,7 +221,7 @@ class State():
def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, [])
@ -260,7 +261,7 @@ class State():
def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
@ -300,7 +301,7 @@ class State():
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
channel = safe_state._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, [])
@ -345,7 +346,7 @@ class State():
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper"""
with self._safe_wrap() as safe_state:
with self._safe_wrap_write() as safe_state:
# Dive in
user = safe_state._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {})
@ -355,7 +356,7 @@ class State():
channel[key] = value
@contextmanager
def _safe_wrap(self):
def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
@ -366,6 +367,14 @@ class State():
tmp_state.validate()
self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
def _save_to_file(self):
"""Saves the state out to the chosen file"""
files.save(self._file, self.dict_internal_copy)
def _migrate(dict: dict):
"""Migrate a dict through versions"""
@ -378,7 +387,7 @@ def _migrate(dict: dict):
def load_from_file(file: str = _STATE_FILE) -> State:
"""
Load the state from a file
Load the state from a files
Apply any required migrations
"""
loaded = _EMPTY_DICT
@ -388,14 +397,9 @@ def load_from_file(file: str = _STATE_FILE) -> State:
loaded = files.load(file)
_migrate(loaded)
st = State(loaded)
st = State(loaded, file)
# Save out the migrated (or new) file
files.save(file, st._dict)
return st
def save_to_file(state: State, file: str = _STATE_FILE):
"""Saves the state out to a file"""
files.save(file, state.dict_internal_copy)