"""Store bot state"""
import copy
import logging
import os
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Protocol

from schema import Schema, And, Use, Optional

import files

logger = logging.getLogger("state")
logger.setLevel(logging.INFO)


# Warning: Changing any of the below needs proper thought to ensure
# backwards compatibility
_VERSION = 1


def _migrate_to_v1(d: dict):
    """Migrate a version 0 dict in place by renaming "matchees" to "users"."""
    logger.info("Renaming %s to %s", _Key.MATCHEES, _Key.USERS)
    # Only rename when the legacy key is actually present, so a dict that
    # never stored any matchees does not abort the migration with a KeyError.
    if _Key.MATCHEES in d:
        d[_Key.USERS] = d.pop(_Key.MATCHEES)


# Ordered list of migration functions to apply.
# _MIGRATIONS[i] upgrades a dict from version i to version i + 1.
_MIGRATIONS = [
    _migrate_to_v1
]


class AuthScope(str):
    """Various auth scopes"""
    OWNER = "owner"
    MATCHER = "matcher"


class _Key(str):
    """Various keys used in the schema"""
    HISTORY = "history"
    GROUPS = "groups"
    MEMBERS = "members"

    USERS = "users"
    SCOPES = "scopes"
    MATCHES = "matches"
    ACTIVE = "active"
    CHANNELS = "channels"
    REACTIVATE = "reactivate"

    VERSION = "version"

    # Unused (only referenced by the v1 migration)
    MATCHEES = "matchees"


# Format used for every timestamp string stored in the state
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"


_SCHEMA = Schema(
    {
        # The current version
        _Key.VERSION: And(Use(int)),

        Optional(_Key.HISTORY): {
            # A datetime
            Optional(str): {
                _Key.GROUPS: [
                    {
                        _Key.MEMBERS: [
                            # The ID of each matchee in the match
                            And(Use(int))
                        ]
                    }
                ]
            }
        },

        Optional(_Key.USERS): {
            Optional(str): {
                Optional(_Key.SCOPES): And(Use(list[str])),
                Optional(_Key.MATCHES): {
                    # Matchee ID and Datetime pair
                    Optional(str): And(Use(str))
                },
                Optional(_Key.CHANNELS): {
                    # The channel ID
                    Optional(str): {
                        # Whether the user is signed up in this channel
                        _Key.ACTIVE: And(Use(bool)),
                        # A timestamp for when to re-activate the user
                        Optional(_Key.REACTIVATE): And(Use(str)),
                    }
                }
            }
        },
    }
)

# Empty but schema-valid internal dict
_EMPTY_DICT = {
    _Key.HISTORY: {},
    _Key.USERS: {},
    _Key.VERSION: _VERSION
}
assert _SCHEMA.validate(_EMPTY_DICT)


class Member(Protocol):
    """Structural type for anything exposing an integer ``id`` property."""

    @property
    def id(self) -> int:
        pass


def ts_to_datetime(ts: str) -> datetime:
    """Convert a string ts to datetime using the internal format"""
    return datetime.strptime(ts, _TIME_FORMAT)


def datetime_to_ts(ts: datetime) -> str:
    """Convert a datetime to a string ts using the internal format"""
    return ts.strftime(_TIME_FORMAT)


class State():
    """
    Schema-validated wrapper around the bot's state dict

    Mutating methods work on a validated deep copy (see _safe_wrap) and only
    replace the internal dict once the modified copy passes validation.
    """

    def __init__(self, data: dict = _EMPTY_DICT):
        """
        Initialise and validate the state

        The given dict is validated against the schema and then deep copied,
        so the caller's dict (including the shared default) is never mutated.
        Raises whatever the schema raises if the data is invalid.
        """
        self.validate(data)
        self._dict = copy.deepcopy(data)

    def validate(self, dict: dict = None):
        """
        Validate a state dict against the schema

        Validates the internal dict when no dict is given.
        Raises whatever the schema raises on invalid data.
        """
        # Compare against None explicitly: an empty dict is a real (invalid)
        # input that must be rejected by the schema, not silently swapped
        # for the internal dict
        if dict is None:
            dict = self._dict
        _SCHEMA.validate(dict)

    def get_oldest_timestamp(self) -> datetime:
        """
        Grab the oldest timestamp in history

        Returns None if there is no history.
        """
        times = (ts_to_datetime(dt) for dt in self._history)
        # History keys are not guaranteed to be stored in chronological
        # order, so pick the minimum rather than the first entry
        return min(times, default=None)

    def get_user_matches(self, id: int) -> dict[str, str]:
        """
        Get the matches recorded for a user

        Returns a mapping of matched user ID (as a string) to the timestamp
        string of the most recently logged match, or an empty dict if the
        user is unknown or has no matches.
        """
        return self._users.get(str(id), {}).get(_Key.MATCHES, {})

    def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
        """
        Log the groups

        Adds a history entry for the groups under the given time (defaulting
        to now) and records, for every member, the timestamp against each
        other member of their group.
        """
        ts = datetime_to_ts(ts or datetime.now())
        with self._safe_wrap() as safe_state:
            # Create the history item for this set of groups
            history_item_groups = []
            safe_state._history[ts] = {_Key.GROUPS: history_item_groups}

            for group in groups:
                # Add the group data
                history_item_groups.append({
                    _Key.MEMBERS: [m.id for m in group]
                })

                # Update the matchee data with the matches
                for m in group:
                    matchee = safe_state._users.setdefault(str(m.id), {})
                    matchee_matches = matchee.setdefault(_Key.MATCHES, {})

                    for o in group:
                        if o.id != m.id:
                            matchee_matches[str(o.id)] = ts

    def set_user_scope(self, id: str, scope: str, value: bool = True):
        """
        Add an auth scope to a user

        Removes the scope instead when value is False.
        """
        with self._safe_wrap() as safe_state:
            # Dive in, creating the user and scope list if needed
            user = safe_state._users.setdefault(str(id), {})
            scopes = user.setdefault(_Key.SCOPES, [])

            # Set the value
            if value and scope not in scopes:
                scopes.append(scope)
            elif not value and scope in scopes:
                scopes.remove(scope)

    def get_user_has_scope(self, id: str, scope: str) -> bool:
        """
        Check if a user has an auth scope
        "owner" users have all scopes
        """
        user = self._users.get(str(id), {})
        scopes = user.get(_Key.SCOPES, [])
        return AuthScope.OWNER in scopes or scope in scopes

    def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True):
        """Set a user as active (or not) on a given channel"""
        self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active)

    def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
        """Check whether a user is active on a given channel"""
        user = self._users.get(str(id), {})
        channels = user.get(_Key.CHANNELS, {})
        channel = channels.get(str(channel_id), {})
        return bool(channel.get(_Key.ACTIVE, False))

    def set_user_paused_in_channel(self, id: str, channel_id: str, days: int):
        """
        Sets a user as paused in a channel

        The user is deactivated in the channel and given a reactivation
        timestamp the given number of days from now.
        """
        # Deactivate the user in the channel first
        self.set_user_active_in_channel(id, channel_id, False)

        # Set the reactivate time the number of days in the future
        ts = datetime.now() + timedelta(days=days)
        self._set_user_channel_prop(
            id, channel_id, _Key.REACTIVATE, datetime_to_ts(ts))

    def reactivate_users(self, channel_id: str):
        """Reactivate any users who've passed their reactivation time on this channel"""
        # Take the time once so every user is compared against the same instant
        now = datetime.now()
        with self._safe_wrap() as safe_state:
            for user in safe_state._users.values():
                channels = user.get(_Key.CHANNELS, {})
                channel = channels.get(str(channel_id), {})

                # Nothing to do for users not in the channel or already active
                if not channel or channel[_Key.ACTIVE]:
                    continue

                # Check if we've gone past the reactivation time and re-activate
                reactivate = channel.get(_Key.REACTIVATE, None)
                if reactivate and now > ts_to_datetime(reactivate):
                    channel[_Key.ACTIVE] = True

    @property
    def dict_internal_copy(self) -> dict:
        """Only to be used to get the internal dict as a copy"""
        return copy.deepcopy(self._dict)

    @property
    def _history(self) -> dict[str, dict]:
        """The history section of the internal dict, keyed by timestamp string"""
        # The schema marks this key as optional, so create it on demand
        # rather than failing on a valid dict that omits it
        return self._dict.setdefault(_Key.HISTORY, {})

    @property
    def _users(self) -> dict[str, dict]:
        """The users section of the internal dict, keyed by user ID string"""
        # The schema marks this key as optional, so create it on demand
        # rather than failing on a valid dict that omits it
        return self._dict.setdefault(_Key.USERS, {})

    def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
        """Set a user channel property helper"""
        with self._safe_wrap() as safe_state:
            # Dive in, creating each level if needed
            user = safe_state._users.setdefault(str(id), {})
            channels = user.setdefault(_Key.CHANNELS, {})
            channel = channels.setdefault(str(channel_id), {})

            # Set the value
            channel[key] = value

    @contextmanager
    def _safe_wrap(self):
        """
        Safely run any function wrapped in a validate

        Yields a temporary State holding a deep copy of the internal dict.
        When the with-body finishes without raising, the copy is validated
        and then replaces the internal dict. If the body raises, or the
        modified copy fails validation, the internal dict is left untouched
        and the exception propagates.
        """
        # Wrap in a temporary state to validate first to prevent corruption
        tmp_state = State(self._dict)
        yield tmp_state

        # Only reached when the body succeeded: an exception raised inside
        # the with-body is re-raised at the yield above, skipping the commit
        tmp_state.validate()
        self._dict = tmp_state._dict


def _migrate(dict: dict):
    """
    Migrate a dict through versions

    Applies, in place, every migration from the dict's stored version
    (0 when there is no version key) up to the current version.
    """
    version = int(dict.get(_Key.VERSION, 0))
    for i in range(version, _VERSION):
        logger.info("Migrating from v%s to v%s", i, i + 1)
        _MIGRATIONS[i](dict)
        # Stamp each completed step so the stored version always matches
        # the migrations that have actually been applied
        dict[_Key.VERSION] = i + 1


def load_from_file(file: str) -> State:
    """
    Load the state from a file
    Apply any required migrations

    An empty state is used if the file does not exist. The resulting
    (migrated or new) state is written back to the file before returning.
    """
    loaded = _EMPTY_DICT

    # If there's a file load it and try to migrate
    if os.path.isfile(file):
        loaded = files.load(file)
        _migrate(loaded)

    st = State(loaded)

    # Save out the migrated (or new) file
    files.save(file, st._dict)

    return st


def save_to_file(state: State, file: str):
    """Saves the state out to a file"""
    files.save(file, state.dict_internal_copy)