Implement user pausing with /pause
This commit is contained in:
		
							parent
							
								
									a480549ad3
								
							
						
					
					
						commit
						7efe781e66
					
				
					 4 changed files with 125 additions and 67 deletions
				
			
		|  | @ -30,7 +30,6 @@ Matchy is configured by a `config.json` file that takes this format: | ||||||
| 
 | 
 | ||||||
| ## TODO | ## TODO | ||||||
| * Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html) | * Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html) | ||||||
| * Implement /pause to pause a user for a little while |  | ||||||
| * Move more constants to the config | * Move more constants to the config | ||||||
| * Add scheduling functionality | * Add scheduling functionality | ||||||
| * Fix logging in some sub files (doesn't seem to actually be output?) | * Fix logging in some sub files (doesn't seem to actually be output?) | ||||||
|  |  | ||||||
							
								
								
									
										20
									
								
								py/matchy.py
									
										
									
									
									
								
							
							
						
						
									
										20
									
								
								py/matchy.py
									
										
									
									
									
								
							|  | @ -75,7 +75,7 @@ async def close(ctx: commands.Context): | ||||||
| @bot.tree.command(description="Join the matchees for this channel") | @bot.tree.command(description="Join the matchees for this channel") | ||||||
| @commands.guild_only() | @commands.guild_only() | ||||||
| async def join(interaction: discord.Interaction): | async def join(interaction: discord.Interaction): | ||||||
|     State.set_use_active_in_channel( |     State.set_user_active_in_channel( | ||||||
|         interaction.user.id, interaction.channel.id) |         interaction.user.id, interaction.channel.id) | ||||||
|     state.save_to_file(State, STATE_FILE) |     state.save_to_file(State, STATE_FILE) | ||||||
|     await interaction.response.send_message( |     await interaction.response.send_message( | ||||||
|  | @ -87,13 +87,26 @@ async def join(interaction: discord.Interaction): | ||||||
| @bot.tree.command(description="Leave the matchees for this channel") | @bot.tree.command(description="Leave the matchees for this channel") | ||||||
| @commands.guild_only() | @commands.guild_only() | ||||||
| async def leave(interaction: discord.Interaction): | async def leave(interaction: discord.Interaction): | ||||||
|     State.set_use_active_in_channel( |     State.set_user_active_in_channel( | ||||||
|         interaction.user.id, interaction.channel.id, False) |         interaction.user.id, interaction.channel.id, False) | ||||||
|     state.save_to_file(State, STATE_FILE) |     state.save_to_file(State, STATE_FILE) | ||||||
|     await interaction.response.send_message( |     await interaction.response.send_message( | ||||||
|         f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) |         f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @bot.tree.command(description="Pause your matching in this channel for a number of days") | ||||||
|  | @commands.guild_only() | ||||||
|  | @app_commands.describe(days="Days to pause for (defaults to 7)") | ||||||
|  | async def pause(interaction: discord.Interaction, days: int = None): | ||||||
|  |     if not days:  # Default to a week | ||||||
|  |         days = 7 | ||||||
|  |     State.set_user_paused_in_channel( | ||||||
|  |         interaction.user.id, interaction.channel.id, days) | ||||||
|  |     state.save_to_file(State, STATE_FILE) | ||||||
|  |     await interaction.response.send_message( | ||||||
|  |         f"Sure thing {interaction.user.mention}. Paused you for {days} days!", ephemeral=True, silent=True) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @bot.tree.command(description="List the matchees for this channel") | @bot.tree.command(description="List the matchees for this channel") | ||||||
| @commands.guild_only() | @commands.guild_only() | ||||||
| async def list(interaction: discord.Interaction): | async def list(interaction: discord.Interaction): | ||||||
|  | @ -195,6 +208,9 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], | ||||||
| 
 | 
 | ||||||
| def get_matchees_in_channel(channel: discord.channel): | def get_matchees_in_channel(channel: discord.channel): | ||||||
|     """Fetches the matchees in a channel""" |     """Fetches the matchees in a channel""" | ||||||
|  |     # Reactivate any unpaused users | ||||||
|  |     State.reactivate_users(channel.id) | ||||||
|  | 
 | ||||||
|     # Gather up the prospective matchees |     # Gather up the prospective matchees | ||||||
|     return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)] |     return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)] | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										119
									
								
								py/state.py
									
										
									
									
									
								
							
							
						
						
									
										119
									
								
								py/state.py
									
										
									
									
									
								
							|  | @ -1,11 +1,12 @@ | ||||||
| """Store bot state""" | """Store bot state""" | ||||||
| import os | import os | ||||||
| from datetime import datetime | from datetime import datetime, timedelta | ||||||
| from schema import Schema, And, Use, Optional | from schema import Schema, And, Use, Optional | ||||||
| from typing import Protocol | from typing import Protocol | ||||||
| import files | import files | ||||||
| import copy | import copy | ||||||
| import logging | import logging | ||||||
|  | from contextlib import contextmanager | ||||||
| 
 | 
 | ||||||
| logger = logging.getLogger("state") | logger = logging.getLogger("state") | ||||||
| logger.setLevel(logging.INFO) | logger.setLevel(logging.INFO) | ||||||
|  | @ -83,6 +84,8 @@ _SCHEMA = Schema( | ||||||
|                     Optional(str): { |                     Optional(str): { | ||||||
|                         # Whether the user is signed up in this channel |                         # Whether the user is signed up in this channel | ||||||
|                         _Key.ACTIVE: And(Use(bool)), |                         _Key.ACTIVE: And(Use(bool)), | ||||||
|  |                         # A timestamp for when to re-activate the user | ||||||
|  |                         Optional(_Key.REACTIVATE): And(Use(str)), | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  | @ -106,24 +109,21 @@ class Member(Protocol): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def ts_to_datetime(ts: str) -> datetime: | def ts_to_datetime(ts: str) -> datetime: | ||||||
|     """Convert a ts to datetime using the internal format""" |     """Convert a string ts to datetime using the internal format""" | ||||||
|     return datetime.strptime(ts, _TIME_FORMAT) |     return datetime.strptime(ts, _TIME_FORMAT) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def datetime_to_ts(ts: datetime) -> str: | ||||||
|  |     """Convert a datetime to a string ts using the internal format""" | ||||||
|  |     return datetime.strftime(ts, _TIME_FORMAT) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class State(): | class State(): | ||||||
|     def __init__(self, data: dict = _EMPTY_DICT): |     def __init__(self, data: dict = _EMPTY_DICT): | ||||||
|         """Initialise and validate the state""" |         """Initialise and validate the state""" | ||||||
|         self.validate(data) |         self.validate(data) | ||||||
|         self._dict = copy.deepcopy(data) |         self._dict = copy.deepcopy(data) | ||||||
| 
 | 
 | ||||||
|     @property |  | ||||||
|     def _history(self) -> dict[str]: |  | ||||||
|         return self._dict[_Key.HISTORY] |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     def _users(self) -> dict[str]: |  | ||||||
|         return self._dict[_Key.USERS] |  | ||||||
| 
 |  | ||||||
|     def validate(self, dict: dict = None): |     def validate(self, dict: dict = None): | ||||||
|         """Initialise and validate a state dict""" |         """Initialise and validate a state dict""" | ||||||
|         if not dict: |         if not dict: | ||||||
|  | @ -138,14 +138,13 @@ class State(): | ||||||
|     def get_user_matches(self, id: int) -> list[int]: |     def get_user_matches(self, id: int) -> list[int]: | ||||||
|         return self._users.get(str(id), {}).get(_Key.MATCHES, {}) |         return self._users.get(str(id), {}).get(_Key.MATCHES, {}) | ||||||
| 
 | 
 | ||||||
|     def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None: |     def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: | ||||||
|         """Log the groups""" |         """Log the groups""" | ||||||
|         tmp_state = State(self._dict) |         ts = datetime_to_ts(ts or datetime.now()) | ||||||
|         ts = datetime.strftime(ts, _TIME_FORMAT) |         with self._safe_wrap() as safe_state: | ||||||
| 
 |  | ||||||
|             # Grab or create the hitory item for this set of groups |             # Grab or create the hitory item for this set of groups | ||||||
|             history_item = {} |             history_item = {} | ||||||
|         tmp_state._history[ts] = history_item |             safe_state._history[ts] = history_item | ||||||
|             history_item_groups = [] |             history_item_groups = [] | ||||||
|             history_item[_Key.GROUPS] = history_item_groups |             history_item[_Key.GROUPS] = history_item_groups | ||||||
| 
 | 
 | ||||||
|  | @ -158,23 +157,20 @@ class State(): | ||||||
| 
 | 
 | ||||||
|                 # Update the matchee data with the matches |                 # Update the matchee data with the matches | ||||||
|                 for m in group: |                 for m in group: | ||||||
|                 matchee = tmp_state._users.get(str(m.id), {}) |                     matchee = safe_state._users.get(str(m.id), {}) | ||||||
|                     matchee_matches = matchee.get(_Key.MATCHES, {}) |                     matchee_matches = matchee.get(_Key.MATCHES, {}) | ||||||
| 
 | 
 | ||||||
|                     for o in (o for o in group if o.id != m.id): |                     for o in (o for o in group if o.id != m.id): | ||||||
|                         matchee_matches[str(o.id)] = ts |                         matchee_matches[str(o.id)] = ts | ||||||
| 
 | 
 | ||||||
|                     matchee[_Key.MATCHES] = matchee_matches |                     matchee[_Key.MATCHES] = matchee_matches | ||||||
|                 tmp_state._users[str(m.id)] = matchee |                     safe_state._users[str(m.id)] = matchee | ||||||
| 
 |  | ||||||
|         # Validate before storing the result |  | ||||||
|         tmp_state.validate() |  | ||||||
|         self._dict = tmp_state._dict |  | ||||||
| 
 | 
 | ||||||
|     def set_user_scope(self, id: str, scope: str, value: bool = True): |     def set_user_scope(self, id: str, scope: str, value: bool = True): | ||||||
|         """Add an auth scope to a user""" |         """Add an auth scope to a user""" | ||||||
|  |         with self._safe_wrap() as safe_state: | ||||||
|             # Dive in |             # Dive in | ||||||
|         user = self._users.get(str(id), {}) |             user = safe_state._users.get(str(id), {}) | ||||||
|             scopes = user.get(_Key.SCOPES, []) |             scopes = user.get(_Key.SCOPES, []) | ||||||
| 
 | 
 | ||||||
|             # Set the value |             # Set the value | ||||||
|  | @ -185,7 +181,7 @@ class State(): | ||||||
| 
 | 
 | ||||||
|             # Roll out |             # Roll out | ||||||
|             user[_Key.SCOPES] = scopes |             user[_Key.SCOPES] = scopes | ||||||
|         self._users[id] = user |             safe_state._users[str(id)] = user | ||||||
| 
 | 
 | ||||||
|     def get_user_has_scope(self, id: str, scope: str) -> bool: |     def get_user_has_scope(self, id: str, scope: str) -> bool: | ||||||
|         """ |         """ | ||||||
|  | @ -196,20 +192,9 @@ class State(): | ||||||
|         scopes = user.get(_Key.SCOPES, []) |         scopes = user.get(_Key.SCOPES, []) | ||||||
|         return AuthScope.OWNER in scopes or scope in scopes |         return AuthScope.OWNER in scopes or scope in scopes | ||||||
| 
 | 
 | ||||||
|     def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True): |     def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True): | ||||||
|         """Set a user as active (or not) on a given channel""" |         """Set a user as active (or not) on a given channel""" | ||||||
|         # Dive in |         self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active) | ||||||
|         user = self._users.get(str(id), {}) |  | ||||||
|         channels = user.get(_Key.CHANNELS, {}) |  | ||||||
|         channel = channels.get(str(channel_id), {}) |  | ||||||
| 
 |  | ||||||
|         # Set the value |  | ||||||
|         channel[_Key.ACTIVE] = active |  | ||||||
| 
 |  | ||||||
|         # Unroll |  | ||||||
|         channels[str(channel_id)] = channel |  | ||||||
|         user[_Key.CHANNELS] = channels |  | ||||||
|         self._users[str(id)] = user |  | ||||||
| 
 | 
 | ||||||
|     def get_user_active_in_channel(self, id: str, channel_id: str) -> bool: |     def get_user_active_in_channel(self, id: str, channel_id: str) -> bool: | ||||||
|         """Get a users active channels""" |         """Get a users active channels""" | ||||||
|  | @ -217,11 +202,69 @@ class State(): | ||||||
|         channels = user.get(_Key.CHANNELS, {}) |         channels = user.get(_Key.CHANNELS, {}) | ||||||
|         return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)] |         return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)] | ||||||
| 
 | 
 | ||||||
|  |     def set_user_paused_in_channel(self, id: str, channel_id: str, days: int): | ||||||
|  |         """Sets a user as paused in a channel""" | ||||||
|  |         # Deactivate the user in the channel first | ||||||
|  |         self.set_user_active_in_channel(id, channel_id, False) | ||||||
|  | 
 | ||||||
|  |         # Set the reactivate time the number of days in the future | ||||||
|  |         ts = datetime.now() + timedelta(days=days) | ||||||
|  |         self._set_user_channel_prop( | ||||||
|  |             id, channel_id, _Key.REACTIVATE, datetime_to_ts(ts)) | ||||||
|  | 
 | ||||||
|  |     def reactivate_users(self, channel_id: str): | ||||||
|  |         """Reactivate any users who've passed their reactivation time on this channel""" | ||||||
|  |         with self._safe_wrap() as safe_state: | ||||||
|  |             for user in safe_state._users.values(): | ||||||
|  |                 channels = user.get(_Key.CHANNELS, {}) | ||||||
|  |                 channel = channels.get(str(channel_id), {}) | ||||||
|  |                 if channel and not channel[_Key.ACTIVE]: | ||||||
|  |                     reactivate = channel.get(_Key.REACTIVATE, None) | ||||||
|  |                     # Check if we've gone past the reactivation time and re-activate | ||||||
|  |                     if reactivate and datetime.now() > ts_to_datetime(reactivate): | ||||||
|  |                         channel[_Key.ACTIVE] = True | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def dict_internal(self) -> dict: |     def dict_internal_copy(self) -> dict: | ||||||
|         """Only to be used to get the internal dict as a copy""" |         """Only to be used to get the internal dict as a copy""" | ||||||
|         return copy.deepcopy(self._dict) |         return copy.deepcopy(self._dict) | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def _history(self) -> dict[str]: | ||||||
|  |         return self._dict[_Key.HISTORY] | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def _users(self) -> dict[str]: | ||||||
|  |         return self._dict[_Key.USERS] | ||||||
|  | 
 | ||||||
|  |     def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): | ||||||
|  |         """Set a user channel property helper""" | ||||||
|  |         with self._safe_wrap() as safe_state: | ||||||
|  |             # Dive in | ||||||
|  |             user = safe_state._users.get(str(id), {}) | ||||||
|  |             channels = user.get(_Key.CHANNELS, {}) | ||||||
|  |             channel = channels.get(str(channel_id), {}) | ||||||
|  | 
 | ||||||
|  |             # Set the value | ||||||
|  |             channel[key] = value | ||||||
|  | 
 | ||||||
|  |             # Unroll | ||||||
|  |             channels[str(channel_id)] = channel | ||||||
|  |             user[_Key.CHANNELS] = channels | ||||||
|  |             safe_state._users[str(id)] = user | ||||||
|  | 
 | ||||||
|  |     @contextmanager | ||||||
|  |     def _safe_wrap(self): | ||||||
|  |         """Safely run any function wrapped in a validate""" | ||||||
|  |         # Wrap in a temporary state to validate first to prevent corruption | ||||||
|  |         tmp_state = State(self._dict) | ||||||
|  |         try: | ||||||
|  |             yield tmp_state | ||||||
|  |         finally: | ||||||
|  |             # Validate and then overwrite our dict with the new one | ||||||
|  |             tmp_state.validate() | ||||||
|  |             self._dict = tmp_state._dict | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def _migrate(dict: dict): | def _migrate(dict: dict): | ||||||
|     """Migrate a dict through versions""" |     """Migrate a dict through versions""" | ||||||
|  | @ -254,4 +297,4 @@ def load_from_file(file: str) -> State: | ||||||
| 
 | 
 | ||||||
| def save_to_file(state: State, file: str): | def save_to_file(state: State, file: str): | ||||||
|     """Saves the state out to a file""" |     """Saves the state out to a file""" | ||||||
|     files.save(file, state.dict_internal) |     files.save(file, state.dict_internal_copy) | ||||||
|  |  | ||||||
|  | @ -55,11 +55,11 @@ def test_channeljoin(): | ||||||
|         assert not st.get_user_active_in_channel(1, "2") |         assert not st.get_user_active_in_channel(1, "2") | ||||||
| 
 | 
 | ||||||
|         st = state.load_from_file(path) |         st = state.load_from_file(path) | ||||||
|         st.set_use_active_in_channel(1, "2", True) |         st.set_user_active_in_channel(1, "2", True) | ||||||
|         state.save_to_file(st, path) |         state.save_to_file(st, path) | ||||||
| 
 | 
 | ||||||
|         st = state.load_from_file(path) |         st = state.load_from_file(path) | ||||||
|         assert st.get_user_active_in_channel(1, "2") |         assert st.get_user_active_in_channel(1, "2") | ||||||
| 
 | 
 | ||||||
|         st.set_use_active_in_channel(1, "2", False) |         st.set_user_active_in_channel(1, "2", False) | ||||||
|         assert not st.get_user_active_in_channel(1, "2") |         assert not st.get_user_active_in_channel(1, "2") | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue