diff --git a/README.md b/README.md index c729f5e..142b3c8 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,6 @@ Matchy is configured by a `config.json` file that takes this format: ## TODO * Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html) -* Implement /pause to pause a user for a little while * Move more constants to the config * Add scheduling functionality * Fix logging in some sub files (doesn't seem to actually be output?) diff --git a/py/matchy.py b/py/matchy.py index 7f5fc8e..62bdaaa 100755 --- a/py/matchy.py +++ b/py/matchy.py @@ -75,7 +75,7 @@ async def close(ctx: commands.Context): @bot.tree.command(description="Join the matchees for this channel") @commands.guild_only() async def join(interaction: discord.Interaction): - State.set_use_active_in_channel( + State.set_user_active_in_channel( interaction.user.id, interaction.channel.id) state.save_to_file(State, STATE_FILE) await interaction.response.send_message( @@ -87,13 +87,26 @@ async def join(interaction: discord.Interaction): @bot.tree.command(description="Leave the matchees for this channel") @commands.guild_only() async def leave(interaction: discord.Interaction): - State.set_use_active_in_channel( + State.set_user_active_in_channel( interaction.user.id, interaction.channel.id, False) state.save_to_file(State, STATE_FILE) await interaction.response.send_message( f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) +@bot.tree.command(description="Pause your matching in this channel for a number of days") +@commands.guild_only() +@app_commands.describe(days="Days to pause for (defaults to 7)") +async def pause(interaction: discord.Interaction, days: int = None): + if not days: # Default to a week + days = 7 + State.set_user_paused_in_channel( + interaction.user.id, interaction.channel.id, days) + state.save_to_file(State, STATE_FILE) + await interaction.response.send_message( + f"Sure thing {interaction.user.mention}. Paused you for {days} days!", ephemeral=True, silent=True) + + @bot.tree.command(description="List the matchees for this channel") @commands.guild_only() async def list(interaction: discord.Interaction): @@ -195,6 +208,9 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], def get_matchees_in_channel(channel: discord.channel): """Fetches the matchees in a channel""" + # Reactivate any unpaused users + State.reactivate_users(channel.id) + # Gather up the prospective matchees return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)] diff --git a/py/state.py b/py/state.py index 344a00f..859b999 100644 --- a/py/state.py +++ b/py/state.py @@ -1,11 +1,12 @@ """Store bot state""" import os -from datetime import datetime +from datetime import datetime, timedelta from schema import Schema, And, Use, Optional from typing import Protocol import files import copy import logging +from contextlib import contextmanager logger = logging.getLogger("state") logger.setLevel(logging.INFO) @@ -83,6 +84,8 @@ _SCHEMA = Schema( Optional(str): { # Whether the user is signed up in this channel _Key.ACTIVE: And(Use(bool)), + # A timestamp for when to re-activate the user + Optional(_Key.REACTIVATE): And(Use(str)), } } } @@ -106,24 +109,21 @@ class Member(Protocol): def ts_to_datetime(ts: str) -> datetime: - """Convert a ts to datetime using the internal format""" + """Convert a string ts to datetime using the internal format""" return datetime.strptime(ts, _TIME_FORMAT) +def datetime_to_ts(ts: datetime) -> str: + """Convert a datetime to a string ts using the internal format""" + return datetime.strftime(ts, _TIME_FORMAT) + + class State(): def __init__(self, data: dict = _EMPTY_DICT): """Initialise and validate the state""" self.validate(data) self._dict = copy.deepcopy(data) - @property - def _history(self) -> dict[str]: - return self._dict[_Key.HISTORY] - - @property - def _users(self) -> dict[str]: - return self._dict[_Key.USERS] - def validate(self, dict: dict = None): """Initialise and validate a state dict""" if not dict: @@ -138,54 +138,50 @@ class State(): def get_user_matches(self, id: int) -> list[int]: return self._users.get(str(id), {}).get(_Key.MATCHES, {}) - def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None: + def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: """Log the groups""" - tmp_state = State(self._dict) - ts = datetime.strftime(ts, _TIME_FORMAT) + ts = datetime_to_ts(ts or datetime.now()) + with self._safe_wrap() as safe_state: + # Grab or create the hitory item for this set of groups + history_item = {} + safe_state._history[ts] = history_item + history_item_groups = [] + history_item[_Key.GROUPS] = history_item_groups - # Grab or create the hitory item for this set of groups - history_item = {} - tmp_state._history[ts] = history_item - history_item_groups = [] - history_item[_Key.GROUPS] = history_item_groups + for group in groups: - for group in groups: + # Add the group data + history_item_groups.append({ + _Key.MEMBERS: [m.id for m in group] + }) - # Add the group data - history_item_groups.append({ - _Key.MEMBERS: [m.id for m in group] - }) + # Update the matchee data with the matches + for m in group: + matchee = safe_state._users.get(str(m.id), {}) + matchee_matches = matchee.get(_Key.MATCHES, {}) - # Update the matchee data with the matches - for m in group: - matchee = tmp_state._users.get(str(m.id), {}) - matchee_matches = matchee.get(_Key.MATCHES, {}) + for o in (o for o in group if o.id != m.id): + matchee_matches[str(o.id)] = ts - for o in (o for o in group if o.id != m.id): - matchee_matches[str(o.id)] = ts - - matchee[_Key.MATCHES] = matchee_matches - tmp_state._users[str(m.id)] = matchee - - # Validate before storing the result - tmp_state.validate() - self._dict = tmp_state._dict + matchee[_Key.MATCHES] = matchee_matches + safe_state._users[str(m.id)] = matchee def set_user_scope(self, id: str, scope: str, value: bool = True): """Add an auth scope to a user""" - # Dive in - user = self._users.get(str(id), {}) - scopes = user.get(_Key.SCOPES, []) + with self._safe_wrap() as safe_state: + # Dive in + user = safe_state._users.get(str(id), {}) + scopes = user.get(_Key.SCOPES, []) - # Set the value - if value and scope not in scopes: - scopes.append(scope) - elif not value and scope in scopes: - scopes.remove(scope) + # Set the value + if value and scope not in scopes: + scopes.append(scope) + elif not value and scope in scopes: + scopes.remove(scope) - # Roll out - user[_Key.SCOPES] = scopes - self._users[id] = user + # Roll out + user[_Key.SCOPES] = scopes + safe_state._users[str(id)] = user def get_user_has_scope(self, id: str, scope: str) -> bool: """ @@ -196,20 +192,9 @@ class State(): scopes = user.get(_Key.SCOPES, []) return AuthScope.OWNER in scopes or scope in scopes - def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True): + def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True): """Set a user as active (or not) on a given channel""" - # Dive in - user = self._users.get(str(id), {}) - channels = user.get(_Key.CHANNELS, {}) - channel = channels.get(str(channel_id), {}) - - # Set the value - channel[_Key.ACTIVE] = active - - # Unroll - channels[str(channel_id)] = channel - user[_Key.CHANNELS] = channels - self._users[str(id)] = user + self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active) def get_user_active_in_channel(self, id: str, channel_id: str) -> bool: """Get a users active channels""" @@ -217,11 +202,69 @@ class State(): channels = user.get(_Key.CHANNELS, {}) return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)] + def set_user_paused_in_channel(self, id: str, channel_id: str, days: int): + """Sets a user as paused in a channel""" + # Deactivate the user in the channel first + self.set_user_active_in_channel(id, channel_id, False) + + # Set the reactivate time the number of days in the future + ts = datetime.now() + timedelta(days=days) + self._set_user_channel_prop( + id, channel_id, _Key.REACTIVATE, datetime_to_ts(ts)) + + def reactivate_users(self, channel_id: str): + """Reactivate any users who've passed their reactivation time on this channel""" + with self._safe_wrap() as safe_state: + for user in safe_state._users.values(): + channels = user.get(_Key.CHANNELS, {}) + channel = channels.get(str(channel_id), {}) + if channel and not channel[_Key.ACTIVE]: + reactivate = channel.get(_Key.REACTIVATE, None) + # Check if we've gone past the reactivation time and re-activate + if reactivate and datetime.now() > ts_to_datetime(reactivate): + channel[_Key.ACTIVE] = True + @property - def dict_internal(self) -> dict: + def dict_internal_copy(self) -> dict: """Only to be used to get the internal dict as a copy""" return copy.deepcopy(self._dict) + @property + def _history(self) -> dict[str]: + return self._dict[_Key.HISTORY] + + @property + def _users(self) -> dict[str]: + return self._dict[_Key.USERS] + + def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): + """Set a user channel property helper""" + with self._safe_wrap() as safe_state: + # Dive in + user = safe_state._users.get(str(id), {}) + channels = user.get(_Key.CHANNELS, {}) + channel = channels.get(str(channel_id), {}) + + # Set the value + channel[key] = value + + # Unroll + channels[str(channel_id)] = channel + user[_Key.CHANNELS] = channels + safe_state._users[str(id)] = user + + @contextmanager + def _safe_wrap(self): + """Safely run any function wrapped in a validate""" + # Wrap in a temporary state to validate first to prevent corruption + tmp_state = State(self._dict) + try: + yield tmp_state + finally: + # Validate and then overwrite our dict with the new one + tmp_state.validate() + self._dict = tmp_state._dict + def _migrate(dict: dict): """Migrate a dict through versions""" @@ -254,4 +297,4 @@ def load_from_file(file: str) -> State: def save_to_file(state: State, file: str): """Saves the state out to a file""" - files.save(file, state.dict_internal) + files.save(file, state.dict_internal_copy) diff --git a/py/state_test.py b/py/state_test.py index bd97648..b79a426 100644 --- a/py/state_test.py +++ b/py/state_test.py @@ -55,11 +55,11 @@ def test_channeljoin(): assert not st.get_user_active_in_channel(1, "2") st = state.load_from_file(path) - st.set_use_active_in_channel(1, "2", True) + st.set_user_active_in_channel(1, "2", True) state.save_to_file(st, path) st = state.load_from_file(path) assert st.get_user_active_in_channel(1, "2") - st.set_use_active_in_channel(1, "2", False) + st.set_user_active_in_channel(1, "2", False) assert not st.get_user_active_in_channel(1, "2")