| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | """Store bot state""" | 
					
						
							|  |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  | from datetime import datetime, timedelta | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  | from schema import Schema, Use, Optional | 
					
						
							|  |  |  | from collections.abc import Generator | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | from typing import Protocol | 
					
						
							|  |  |  | import files | 
					
						
							|  |  |  | import copy | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  | from contextlib import contextmanager | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | logger = logging.getLogger("state") | 
					
						
							|  |  |  | logger.setLevel(logging.INFO) | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Warning: Changing any of the below needs proper thought to ensure backwards compatibility | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  | _VERSION = 4 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _migrate_to_v1(d: dict): | 
					
						
							| 
									
										
										
										
											2024-08-11 22:31:20 +01:00
										 |  |  |     """v1 simply renamed matchees to users""" | 
					
						
							|  |  |  |     logger.info("Renaming %s to %s", _Key._MATCHEES, _Key.USERS) | 
					
						
							|  |  |  |     d[_Key.USERS] = d[_Key._MATCHEES] | 
					
						
							|  |  |  |     del d[_Key._MATCHEES] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _migrate_to_v2(d: dict): | 
					
						
							|  |  |  |     """v2 swapped the date over to a less silly format""" | 
					
						
							|  |  |  |     logger.info("Fixing up date format from %s to %s", | 
					
						
							|  |  |  |                 _TIME_FORMAT_OLD, _TIME_FORMAT) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def old_to_new_ts(ts: str) -> str: | 
					
						
							|  |  |  |         return datetime.strftime(datetime.strptime(ts, _TIME_FORMAT_OLD), _TIME_FORMAT) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Adjust all the history keys | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |     d[_Key._HISTORY] = { | 
					
						
							| 
									
										
										
										
											2024-08-11 22:31:20 +01:00
										 |  |  |         old_to_new_ts(ts): entry | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |         for ts, entry in d[_Key._HISTORY].items() | 
					
						
							| 
									
										
										
										
											2024-08-11 22:31:20 +01:00
										 |  |  |     } | 
					
						
							|  |  |  |     # Adjust all the user parts | 
					
						
							|  |  |  |     for user in d[_Key.USERS].values(): | 
					
						
							|  |  |  |         # Update the match dates | 
					
						
							|  |  |  |         matches = user.get(_Key.MATCHES, {}) | 
					
						
							|  |  |  |         for id, ts in matches.items(): | 
					
						
							|  |  |  |             matches[id] = old_to_new_ts(ts) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Update any reactivation dates | 
					
						
							|  |  |  |         channels = user.get(_Key.CHANNELS, {}) | 
					
						
							|  |  |  |         for id, channel in channels.items(): | 
					
						
							|  |  |  |             old_ts = channel.get(_Key.REACTIVATE, None) | 
					
						
							|  |  |  |             if old_ts: | 
					
						
							|  |  |  |                 channel[_Key.REACTIVATE] = old_to_new_ts(old_ts) | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  | def _migrate_to_v3(d: dict): | 
					
						
							|  |  |  |     """v3 simply added the tasks entry""" | 
					
						
							|  |  |  |     d[_Key.TASKS] = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  | def _migrate_to_v4(d: dict): | 
					
						
							|  |  |  |     """v4 removed verbose history tracking""" | 
					
						
							|  |  |  |     del d[_Key._HISTORY] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | # Set of migration functions to apply | 
					
						
							|  |  |  | _MIGRATIONS = [ | 
					
						
							| 
									
										
										
										
											2024-08-11 22:31:20 +01:00
										 |  |  |     _migrate_to_v1, | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |     _migrate_to_v2, | 
					
						
							|  |  |  |     _migrate_to_v3, | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |     _migrate_to_v4, | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class AuthScope(str): | 
					
						
							|  |  |  |     """Various auth scopes""" | 
					
						
							|  |  |  |     OWNER = "owner" | 
					
						
							|  |  |  |     MATCHER = "matcher" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class _Key(str): | 
					
						
							|  |  |  |     """Various keys used in the schema""" | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |     VERSION = "version" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |     USERS = "users" | 
					
						
							|  |  |  |     SCOPES = "scopes" | 
					
						
							|  |  |  |     MATCHES = "matches" | 
					
						
							|  |  |  |     ACTIVE = "active" | 
					
						
							|  |  |  |     CHANNELS = "channels" | 
					
						
							|  |  |  |     REACTIVATE = "reactivate" | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     TASKS = "tasks" | 
					
						
							|  |  |  |     MATCH_TASKS = "match_tasks" | 
					
						
							|  |  |  |     MEMBERS_MIN = "members_min" | 
					
						
							|  |  |  |     WEEKDAY = "weekdays" | 
					
						
							|  |  |  |     HOUR = "hours" | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Unused | 
					
						
							| 
									
										
										
										
											2024-08-11 22:31:20 +01:00
										 |  |  |     _MATCHEES = "matchees" | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |     _HISTORY = "history" | 
					
						
							|  |  |  |     _GROUPS = "groups" | 
					
						
							|  |  |  |     _MEMBERS = "members" | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 22:31:20 +01:00
										 |  |  | _TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f" | 
					
						
							|  |  |  | _TIME_FORMAT_OLD = "%a %b %d %H:%M:%S %Y" | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | _SCHEMA = Schema( | 
					
						
							|  |  |  |     { | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |         # The current version | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |         _Key.VERSION: Use(int), | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |         _Key.USERS: { | 
					
						
							|  |  |  |             # User ID as string | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  |             Optional(str): { | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |                 Optional(_Key.SCOPES): Use(list[str]), | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |                 Optional(_Key.MATCHES): { | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  |                     # Matchee ID and Datetime pair | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |                     Optional(str): Use(str) | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |                 }, | 
					
						
							|  |  |  |                 Optional(_Key.CHANNELS): { | 
					
						
							|  |  |  |                     # The channel ID | 
					
						
							|  |  |  |                     Optional(str): { | 
					
						
							|  |  |  |                         # Whether the user is signed up in this channel | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |                         _Key.ACTIVE: Use(bool), | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |                         # A timestamp for when to re-activate the user | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |                         Optional(_Key.REACTIVATE): Use(str), | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |                     } | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |         }, | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |         _Key.TASKS: { | 
					
						
							|  |  |  |             # Channel ID as string | 
					
						
							|  |  |  |             Optional(str): { | 
					
						
							|  |  |  |                 Optional(_Key.MATCH_TASKS): [ | 
					
						
							|  |  |  |                     { | 
					
						
							|  |  |  |                         _Key.MEMBERS_MIN: Use(int), | 
					
						
							|  |  |  |                         _Key.WEEKDAY: Use(int), | 
					
						
							|  |  |  |                         _Key.HOUR: Use(int), | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 ] | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  |     } | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | # Empty but schema-valid internal dict | 
					
						
							|  |  |  | _EMPTY_DICT = { | 
					
						
							|  |  |  |     _Key.USERS: {}, | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |     _Key.TASKS: {}, | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |     _Key.VERSION: _VERSION | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | assert _SCHEMA.validate(_EMPTY_DICT) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | class Member(Protocol): | 
					
						
							|  |  |  |     @property | 
					
						
							|  |  |  |     def id(self) -> int: | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def ts_to_datetime(ts: str) -> datetime: | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |     """Convert a string ts to datetime using the internal format""" | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  |     return datetime.strptime(ts, _TIME_FORMAT) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  | def datetime_to_ts(ts: datetime) -> str: | 
					
						
							|  |  |  |     """Convert a datetime to a string ts using the internal format""" | 
					
						
							|  |  |  |     return datetime.strftime(ts, _TIME_FORMAT) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | class State(): | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |     def __init__(self, data: dict = _EMPTY_DICT): | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  |         """Initialise and validate the state""" | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |         self.validate(data) | 
					
						
							|  |  |  |         self._dict = copy.deepcopy(data) | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |     def validate(self, dict: dict = None): | 
					
						
							|  |  |  |         """Initialise and validate a state dict""" | 
					
						
							|  |  |  |         if not dict: | 
					
						
							|  |  |  |             dict = self._dict | 
					
						
							|  |  |  |         _SCHEMA.validate(dict) | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |     def get_history_timestamps(self, users: list[Member]) -> list[datetime]: | 
					
						
							| 
									
										
										
										
											2024-08-11 22:07:43 +01:00
										 |  |  |         """Grab all timestamps in the history""" | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |         others = [m.id for m in users] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Fetch all the interaction times in history | 
					
						
							|  |  |  |         # But only for interactions in the given user group | 
					
						
							|  |  |  |         times = set() | 
					
						
							|  |  |  |         for data in (data for id, data in self._users.items() if int(id) in others): | 
					
						
							|  |  |  |             matches = data.get(_Key.MATCHES, {}) | 
					
						
							|  |  |  |             for ts in (ts for id, ts in matches.items() if int(id) in others): | 
					
						
							|  |  |  |                 times.add(ts) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Convert to datetimes and sort | 
					
						
							|  |  |  |         datetimes = [ts_to_datetime(ts) for ts in times] | 
					
						
							|  |  |  |         datetimes.sort() | 
					
						
							|  |  |  |         return datetimes | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_user_matches(self, id: int) -> list[int]: | 
					
						
							|  |  |  |         return self._users.get(str(id), {}).get(_Key.MATCHES, {}) | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |     def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  |         """Log the groups""" | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |         ts = datetime_to_ts(ts or datetime.now()) | 
					
						
							|  |  |  |         with self._safe_wrap() as safe_state: | 
					
						
							|  |  |  |             for group in groups: | 
					
						
							|  |  |  |                 # Update the matchee data with the matches | 
					
						
							|  |  |  |                 for m in group: | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |                     matchee = safe_state._users.setdefault(str(m.id), {}) | 
					
						
							|  |  |  |                     matchee_matches = matchee.setdefault(_Key.MATCHES, {}) | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |                     for o in (o for o in group if o.id != m.id): | 
					
						
							|  |  |  |                         matchee_matches[str(o.id)] = ts | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |     def set_user_scope(self, id: str, scope: str, value: bool = True): | 
					
						
							|  |  |  |         """Add an auth scope to a user""" | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |         with self._safe_wrap() as safe_state: | 
					
						
							|  |  |  |             # Dive in | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |             user = safe_state._users.setdefault(str(id), {}) | 
					
						
							|  |  |  |             scopes = user.setdefault(_Key.SCOPES, []) | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |             # Set the value | 
					
						
							|  |  |  |             if value and scope not in scopes: | 
					
						
							|  |  |  |                 scopes.append(scope) | 
					
						
							|  |  |  |             elif not value and scope in scopes: | 
					
						
							|  |  |  |                 scopes.remove(scope) | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_user_has_scope(self, id: str, scope: str) -> bool: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |             Check if a user has an auth scope | 
					
						
							|  |  |  |             "owner" users have all scopes | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         user = self._users.get(str(id), {}) | 
					
						
							|  |  |  |         scopes = user.get(_Key.SCOPES, []) | 
					
						
							|  |  |  |         return AuthScope.OWNER in scopes or scope in scopes | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |     def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True): | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |         """Set a user as active (or not) on a given channel""" | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |         self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active) | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_user_active_in_channel(self, id: str, channel_id: str) -> bool: | 
					
						
							|  |  |  |         """Get a users active channels""" | 
					
						
							|  |  |  |         user = self._users.get(str(id), {}) | 
					
						
							|  |  |  |         channels = user.get(_Key.CHANNELS, {}) | 
					
						
							|  |  |  |         return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-13 14:17:36 +01:00
										 |  |  |     def set_user_paused_in_channel(self, id: str, channel_id: str, until: datetime): | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |         """Sets a user as paused in a channel""" | 
					
						
							|  |  |  |         # Deactivate the user in the channel first | 
					
						
							|  |  |  |         self.set_user_active_in_channel(id, channel_id, False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self._set_user_channel_prop( | 
					
						
							| 
									
										
										
										
											2024-08-13 14:17:36 +01:00
										 |  |  |             id, channel_id, _Key.REACTIVATE, datetime_to_ts(until)) | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def reactivate_users(self, channel_id: str): | 
					
						
							|  |  |  |         """Reactivate any users who've passed their reactivation time on this channel""" | 
					
						
							|  |  |  |         with self._safe_wrap() as safe_state: | 
					
						
							|  |  |  |             for user in safe_state._users.values(): | 
					
						
							| 
									
										
										
										
											2024-08-13 14:17:36 +01:00
										 |  |  |                 channels = user.get(_Key.CHANNELS, {}) | 
					
						
							|  |  |  |                 channel = channels.get(str(channel_id), {}) | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |                 if channel and not channel[_Key.ACTIVE]: | 
					
						
							|  |  |  |                     reactivate = channel.get(_Key.REACTIVATE, None) | 
					
						
							|  |  |  |                     # Check if we've gone past the reactivation time and re-activate | 
					
						
							|  |  |  |                     if reactivate and datetime.now() > ts_to_datetime(reactivate): | 
					
						
							|  |  |  |                         channel[_Key.ACTIVE] = True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-12 23:29:40 +01:00
										 |  |  |     def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]: | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-08-12 23:29:40 +01:00
										 |  |  |         Get any active match tasks at the given time | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |         returns list of channel,members_min pairs | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-08-12 23:29:40 +01:00
										 |  |  |         if not time: | 
					
						
							|  |  |  |             time = datetime.now() | 
					
						
							|  |  |  |         weekday = time.weekday() | 
					
						
							|  |  |  |         hour = time.hour | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |         for channel, tasks in self._tasks.items(): | 
					
						
							|  |  |  |             for match in tasks.get(_Key.MATCH_TASKS, []): | 
					
						
							|  |  |  |                 if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour: | 
					
						
							|  |  |  |                     yield (channel, match[_Key.MEMBERS_MIN]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_channel_match_tasks(self, channel_id: str) -> Generator[int, int, int]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get all match tasks for the channel | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         all_tasks = ( | 
					
						
							|  |  |  |             tasks.get(_Key.MATCH_TASKS, []) | 
					
						
							|  |  |  |             for channel, tasks in self._tasks.items() | 
					
						
							|  |  |  |             if str(channel) == str(channel_id) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         for tasks in all_tasks: | 
					
						
							|  |  |  |             for task in tasks: | 
					
						
							|  |  |  |                 yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: | 
					
						
							|  |  |  |         """Set up a match task on a channel""" | 
					
						
							|  |  |  |         with self._safe_wrap() as safe_state: | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |             channel = safe_state._tasks.setdefault(str(channel_id), {}) | 
					
						
							|  |  |  |             matches = channel.setdefault(_Key.MATCH_TASKS, []) | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |             found = False | 
					
						
							|  |  |  |             for match in matches: | 
					
						
							|  |  |  |                 # Specifically check for the combination of weekday and hour | 
					
						
							|  |  |  |                 if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour: | 
					
						
							|  |  |  |                     found = True | 
					
						
							|  |  |  |                     if set: | 
					
						
							|  |  |  |                         match[_Key.MEMBERS_MIN] = members_min | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         matches.remove(match) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     # Return true as we've successfully changed the data in place | 
					
						
							|  |  |  |                     return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # If we didn't find it, add it to the schedule | 
					
						
							|  |  |  |             if not found and set: | 
					
						
							|  |  |  |                 matches.append({ | 
					
						
							|  |  |  |                     _Key.MEMBERS_MIN: members_min, | 
					
						
							|  |  |  |                     _Key.WEEKDAY: weekday, | 
					
						
							|  |  |  |                     _Key.HOUR: hour, | 
					
						
							|  |  |  |                 }) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # We did not manage to remove the schedule (or add it? though that should be impossible) | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |     @property | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |     def dict_internal_copy(self) -> dict: | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |         """Only to be used to get the internal dict as a copy""" | 
					
						
							|  |  |  |         return copy.deepcopy(self._dict) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |     @property | 
					
						
							|  |  |  |     def _users(self) -> dict[str]: | 
					
						
							|  |  |  |         return self._dict[_Key.USERS] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-12 23:00:49 +01:00
										 |  |  |     @property | 
					
						
							|  |  |  |     def _tasks(self) -> dict[str]: | 
					
						
							|  |  |  |         return self._dict[_Key.TASKS] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |     def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): | 
					
						
							|  |  |  |         """Set a user channel property helper""" | 
					
						
							|  |  |  |         with self._safe_wrap() as safe_state: | 
					
						
							|  |  |  |             # Dive in | 
					
						
							| 
									
										
										
										
											2024-08-13 00:12:30 +01:00
										 |  |  |             user = safe_state._users.setdefault(str(id), {}) | 
					
						
							|  |  |  |             channels = user.setdefault(_Key.CHANNELS, {}) | 
					
						
							|  |  |  |             channel = channels.setdefault(str(channel_id), {}) | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Set the value | 
					
						
							|  |  |  |             channel[key] = value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @contextmanager | 
					
						
							|  |  |  |     def _safe_wrap(self): | 
					
						
							|  |  |  |         """Safely run any function wrapped in a validate""" | 
					
						
							|  |  |  |         # Wrap in a temporary state to validate first to prevent corruption | 
					
						
							|  |  |  |         tmp_state = State(self._dict) | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             yield tmp_state | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             # Validate and then overwrite our dict with the new one | 
					
						
							|  |  |  |             tmp_state.validate() | 
					
						
							|  |  |  |             self._dict = tmp_state._dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | def _migrate(dict: dict): | 
					
						
							|  |  |  |     """Migrate a dict through versions""" | 
					
						
							|  |  |  |     version = dict.get("version", 0) | 
					
						
							|  |  |  |     for i in range(version, _VERSION): | 
					
						
							|  |  |  |         logger.info("Migrating from v%s to v%s", version, version+1) | 
					
						
							|  |  |  |         _MIGRATIONS[i](dict) | 
					
						
							|  |  |  |         dict[_Key.VERSION] = _VERSION | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def load_from_file(file: str) -> State: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Load the state from a file | 
					
						
							|  |  |  |     Apply any required migrations | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     loaded = _EMPTY_DICT | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # If there's a file load it and try to migrate | 
					
						
							|  |  |  |     if os.path.isfile(file): | 
					
						
							|  |  |  |         loaded = files.load(file) | 
					
						
							|  |  |  |         _migrate(loaded) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     st = State(loaded) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Save out the migrated (or new) file | 
					
						
							|  |  |  |     files.save(file, st._dict) | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  |     return st | 
					
						
							| 
									
										
										
										
											2024-08-11 12:16:23 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-11 17:53:37 +01:00
										 |  |  | def save_to_file(state: State, file: str): | 
					
						
							|  |  |  |     """Saves the state out to a file""" | 
					
						
							| 
									
										
										
										
											2024-08-11 19:02:47 +01:00
										 |  |  |     files.save(file, state.dict_internal_copy) |