Use a decorator for the safe write in State
This is a little cleaner to my eyes
This commit is contained in:
		
							parent
							
								
									37e1e7a7ae
								
							
						
					
					
						commit
						ef4dd5c571
					
				
					 1 changed files with 77 additions and 76 deletions
				
			
		|  | @ -7,7 +7,7 @@ from typing import Protocol | |||
| import matchy.files.ops as ops | ||||
| import copy | ||||
| import logging | ||||
| from contextlib import contextmanager | ||||
| from functools import wraps | ||||
| 
 | ||||
| logger = logging.getLogger("state") | ||||
| logger.setLevel(logging.INFO) | ||||
|  | @ -182,6 +182,24 @@ class State(): | |||
|             dict = self._dict | ||||
|         _SCHEMA.validate(dict) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def safe_write(func): | ||||
|         """ | ||||
|         Wraps any function running it first on some temporary state | ||||
|         Validates the resulting state and only then attempts to save it out | ||||
|         before storing the dict back in the State | ||||
|         """ | ||||
|         @wraps(func) | ||||
|         def inner(self: State, *args, **kwargs): | ||||
|             tmp = State(self._dict, self._file) | ||||
|             func(tmp, *args, **kwargs) | ||||
|             tmp.validate() | ||||
|             if tmp._file: | ||||
|                 tmp._save_to_file() | ||||
|             self._dict = tmp._dict | ||||
| 
 | ||||
|         return inner | ||||
| 
 | ||||
|     def get_history_timestamps(self, users: list[Member]) -> list[datetime]: | ||||
|         """Grab all timestamps in the history""" | ||||
|         others = [m.id for m in users] | ||||
|  | @ -202,31 +220,31 @@ class State(): | |||
|     def get_user_matches(self, id: int) -> list[int]: | ||||
|         return self._users.get(str(id), {}).get(_Key.MATCHES, {}) | ||||
| 
 | ||||
|     @safe_write | ||||
|     def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: | ||||
|         """Log the groups""" | ||||
|         ts = datetime_to_ts(ts or datetime.now()) | ||||
|         with self._safe_wrap_write() as safe_state: | ||||
|             for group in groups: | ||||
|                 # Update the matchee data with the matches | ||||
|                 for m in group: | ||||
|                     matchee = safe_state._users.setdefault(str(m.id), {}) | ||||
|                     matchee_matches = matchee.setdefault(_Key.MATCHES, {}) | ||||
|         for group in groups: | ||||
|             # Update the matchee data with the matches | ||||
|             for m in group: | ||||
|                 matchee = self._users.setdefault(str(m.id), {}) | ||||
|                 matchee_matches = matchee.setdefault(_Key.MATCHES, {}) | ||||
| 
 | ||||
|                     for o in (o for o in group if o.id != m.id): | ||||
|                         matchee_matches[str(o.id)] = ts | ||||
|                 for o in (o for o in group if o.id != m.id): | ||||
|                     matchee_matches[str(o.id)] = ts | ||||
| 
 | ||||
|     @safe_write | ||||
|     def set_user_scope(self, id: str, scope: str, value: bool = True): | ||||
|         """Add an auth scope to a user""" | ||||
|         with self._safe_wrap_write() as safe_state: | ||||
|             # Dive in | ||||
|             user = safe_state._users.setdefault(str(id), {}) | ||||
|             scopes = user.setdefault(_Key.SCOPES, []) | ||||
|         # Dive in | ||||
|         user = self._users.setdefault(str(id), {}) | ||||
|         scopes = user.setdefault(_Key.SCOPES, []) | ||||
| 
 | ||||
|             # Set the value | ||||
|             if value and scope not in scopes: | ||||
|                 scopes.append(scope) | ||||
|             elif not value and scope in scopes: | ||||
|                 scopes.remove(scope) | ||||
|         # Set the value | ||||
|         if value and scope not in scopes: | ||||
|             scopes.append(scope) | ||||
|         elif not value and scope in scopes: | ||||
|             scopes.remove(scope) | ||||
| 
 | ||||
|     def get_user_has_scope(self, id: str, scope: str) -> bool: | ||||
|         """ | ||||
|  | @ -255,17 +273,17 @@ class State(): | |||
|         self._set_user_channel_prop( | ||||
|             id, channel_id, _Key.REACTIVATE, datetime_to_ts(until)) | ||||
| 
 | ||||
|     @safe_write | ||||
|     def reactivate_users(self, channel_id: str): | ||||
|         """Reactivate any users who've passed their reactivation time on this channel""" | ||||
|         with self._safe_wrap_write() as safe_state: | ||||
|             for user in safe_state._users.values(): | ||||
|                 channels = user.get(_Key.CHANNELS, {}) | ||||
|                 channel = channels.get(str(channel_id), {}) | ||||
|                 if channel and not channel[_Key.ACTIVE]: | ||||
|                     reactivate = channel.get(_Key.REACTIVATE, None) | ||||
|                     # Check if we've gone past the reactivation time and re-activate | ||||
|                     if reactivate and datetime.now() > ts_to_datetime(reactivate): | ||||
|                         channel[_Key.ACTIVE] = True | ||||
|         for user in self._users.values(): | ||||
|             channels = user.get(_Key.CHANNELS, {}) | ||||
|             channel = channels.get(str(channel_id), {}) | ||||
|             if channel and not channel[_Key.ACTIVE]: | ||||
|                 reactivate = channel.get(_Key.REACTIVATE, None) | ||||
|                 # Check if we've gone past the reactivation time and re-activate | ||||
|                 if reactivate and datetime.now() > ts_to_datetime(reactivate): | ||||
|                     channel[_Key.ACTIVE] = True | ||||
| 
 | ||||
|     def get_active_match_tasks(self, time: datetime | None = None) -> Generator[str, int]: | ||||
|         """ | ||||
|  | @ -295,37 +313,37 @@ class State(): | |||
|             for task in tasks: | ||||
|                 yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN]) | ||||
| 
 | ||||
|     @safe_write | ||||
|     def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: | ||||
|         """Set up a match task on a channel""" | ||||
|         with self._safe_wrap_write() as safe_state: | ||||
|             channel = safe_state._tasks.setdefault(str(channel_id), {}) | ||||
|             matches = channel.setdefault(_Key.MATCH_TASKS, []) | ||||
|         channel = self._tasks.setdefault(str(channel_id), {}) | ||||
|         matches = channel.setdefault(_Key.MATCH_TASKS, []) | ||||
| 
 | ||||
|             found = False | ||||
|             for match in matches: | ||||
|                 # Specifically check for the combination of weekday and hour | ||||
|                 if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour: | ||||
|                     found = True | ||||
|                     if set: | ||||
|                         match[_Key.MEMBERS_MIN] = members_min | ||||
|                     else: | ||||
|                         matches.remove(match) | ||||
| 
 | ||||
|                     # Return true as we've successfully changed the data in place | ||||
|                     return True | ||||
| 
 | ||||
|             # If we didn't find it, add it to the schedule | ||||
|             if not found and set: | ||||
|                 matches.append({ | ||||
|                     _Key.MEMBERS_MIN: members_min, | ||||
|                     _Key.WEEKDAY: weekday, | ||||
|                     _Key.HOUR: hour, | ||||
|                 }) | ||||
|         found = False | ||||
|         for match in matches: | ||||
|             # Specifically check for the combination of weekday and hour | ||||
|             if match[_Key.WEEKDAY] == weekday and match[_Key.HOUR] == hour: | ||||
|                 found = True | ||||
|                 if set: | ||||
|                     match[_Key.MEMBERS_MIN] = members_min | ||||
|                 else: | ||||
|                     matches.remove(match) | ||||
| 
 | ||||
|                 # Return true as we've successfully changed the data in place | ||||
|                 return True | ||||
| 
 | ||||
|             # We did not manage to remove the schedule (or add it? though that should be impossible) | ||||
|             return False | ||||
|         # If we didn't find it, add it to the schedule | ||||
|         if not found and set: | ||||
|             matches.append({ | ||||
|                 _Key.MEMBERS_MIN: members_min, | ||||
|                 _Key.WEEKDAY: weekday, | ||||
|                 _Key.HOUR: hour, | ||||
|             }) | ||||
| 
 | ||||
|             return True | ||||
| 
 | ||||
|         # We did not manage to remove the schedule (or add it? though that should be impossible) | ||||
|         return False | ||||
| 
 | ||||
|     @property | ||||
|     def dict_internal_copy(self) -> dict: | ||||
|  | @ -340,33 +358,16 @@ class State(): | |||
|     def _tasks(self) -> dict[str]: | ||||
|         return self._dict[_Key.TASKS] | ||||
| 
 | ||||
|     @safe_write | ||||
|     def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): | ||||
|         """Set a user channel property helper""" | ||||
|         with self._safe_wrap_write() as safe_state: | ||||
|             # Dive in | ||||
|             user = safe_state._users.setdefault(str(id), {}) | ||||
|             channels = user.setdefault(_Key.CHANNELS, {}) | ||||
|             channel = channels.setdefault(str(channel_id), {}) | ||||
|         # Dive in | ||||
|         user = self._users.setdefault(str(id), {}) | ||||
|         channels = user.setdefault(_Key.CHANNELS, {}) | ||||
|         channel = channels.setdefault(str(channel_id), {}) | ||||
| 
 | ||||
|             # Set the value | ||||
|             channel[key] = value | ||||
| 
 | ||||
|     # TODO: Make this a decorator? | ||||
|     @contextmanager | ||||
|     def _safe_wrap_write(self): | ||||
|         """Safely run any function wrapped in a validate""" | ||||
|         # Wrap in a temporary state to validate first to prevent corruption | ||||
|         tmp_state = State(self._dict) | ||||
|         try: | ||||
|             yield tmp_state | ||||
|         finally: | ||||
|             # Validate and then overwrite our dict with the new one | ||||
|             tmp_state.validate() | ||||
|             self._dict = tmp_state._dict | ||||
| 
 | ||||
|             # Write this change out if we have a file | ||||
|             if self._file: | ||||
|                 self._save_to_file() | ||||
|         # Set the value | ||||
|         channel[key] = value | ||||
| 
 | ||||
|     def _save_to_file(self): | ||||
|         """Saves the state out to the chosen file""" | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue