diff --git a/matchy.py b/matchy.py index 9b24a8a..7b6a1ae 100644 --- a/matchy.py +++ b/matchy.py @@ -5,13 +5,9 @@ import logging import discord from discord.ext import commands import os -from matchy.state import load_from_file import matchy.cogs.matcher import matchy.cogs.owner -_STATE_FILE = ".matchy/state.json" -state = load_from_file(_STATE_FILE) - logger = logging.getLogger("matchy") logger.setLevel(logging.INFO) @@ -24,8 +20,8 @@ bot = commands.Bot(command_prefix='$', @bot.event async def setup_hook(): - await bot.add_cog(matchy.cogs.matcher.MatcherCog(bot, state)) - await bot.add_cog(matchy.cogs.owner.OwnerCog(bot, state)) + await bot.add_cog(matchy.cogs.matcher.MatcherCog(bot)) + await bot.add_cog(matchy.cogs.owner.OwnerCog(bot)) @bot.event diff --git a/matchy/cogs/matcher.py b/matchy/cogs/matcher.py index 4252f0a..420c60e 100644 --- a/matchy/cogs/matcher.py +++ b/matchy/cogs/matcher.py @@ -20,15 +20,14 @@ logger.setLevel(logging.INFO) class MatcherCog(commands.Cog): - def __init__(self, bot: commands.Bot, state: State): + def __init__(self, bot: commands.Bot): self.bot = bot - self.state = state @commands.Cog.listener() async def on_ready(self): """Bot is ready and connected""" self.run_hourly_tasks.start() - self.bot.add_dynamic_items(DynamicGroupButton) + self.bot.add_dynamic_items(MatchDynamicButton) activity = discord.Game("/join") await self.bot.change_presence(status=discord.Status.online, activity=activity) logger.info("Bot is up and ready!") @@ -39,7 +38,7 @@ class MatcherCog(commands.Cog): logger.info("Handling /join in %s %s from %s", interaction.guild.name, interaction.channel, interaction.user.name) - self.state.set_user_active_in_channel( + state.State.set_user_active_in_channel( interaction.user.id, interaction.channel.id) await interaction.response.send_message( strings.acknowledgement(interaction.user.mention) + "\n" @@ -52,7 +51,7 @@ class MatcherCog(commands.Cog): logger.info("Handling /leave in %s %s from %s", interaction.guild.name, interaction.channel, interaction.user.name) - self.state.set_user_active_in_channel( + state.State.set_user_active_in_channel( interaction.user.id, interaction.channel.id, False) await interaction.response.send_message( strings.user_leave(interaction.user.mention), ephemeral=True, silent=True) @@ -67,7 +66,7 @@ class MatcherCog(commands.Cog): if days is None: # Default to a week days = 7 until = datetime.now() + timedelta(days=days) - self.state.set_user_paused_in_channel( + state.State.set_user_paused_in_channel( interaction.user.id, interaction.channel.id, until) await interaction.response.send_message( strings.acknowledgement(interaction.user.mention) + "\n" @@ -80,8 +79,7 @@ class MatcherCog(commands.Cog): logger.info("Handling /list command in %s %s from %s", interaction.guild.name, interaction.channel, interaction.user.name) - (matchees, paused) = matching.get_matchees_in_channel( - self.state, interaction.channel) + (matchees, paused) = matching.get_matchees_in_channel(interaction.channel) msg = "" @@ -93,7 +91,7 @@ class MatcherCog(commands.Cog): mentions = [m.mention for m in paused] msg += "\n" + strings.paused_matchees(mentions) + "\n" - tasks = self.state.get_channel_match_tasks(interaction.channel.id) + tasks = state.State.get_channel_match_tasks(interaction.channel.id) for (day, hour, min) in tasks: next_run = util.get_next_datetime(day, hour) msg += "\n" + strings.scheduled(next_run, min) @@ -125,13 +123,13 @@ class MatcherCog(commands.Cog): channel_id = str(interaction.channel.id) # Bail if not a matcher - if not self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): + if not state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): await interaction.response.send_message(strings.need_matcher_scope(), ephemeral=True, silent=True) return # Add the scheduled task and save - self.state.set_channel_match_task( + state.State.set_channel_match_task( channel_id, members_min, weekday, hour) # Let the user know what happened @@ -139,23 +137,26 @@ class MatcherCog(commands.Cog): channel_id, members_min, weekday, hour) next_run = util.get_next_datetime(weekday, hour) + view = discord.ui.View(timeout=None) + view.add_item(ScheduleButton()) + await interaction.response.send_message( strings.scheduled_success(next_run), - ephemeral=True, silent=True) + ephemeral=True, silent=True, view=view) @app_commands.command(description="Cancel all scheduled matches in this channel") @commands.guild_only() async def cancel(self, interaction: discord.Interaction): """Cancel scheduled matches in this channel""" # Bail if not a matcher - if not self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): + if not state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): await interaction.response.send_message(strings.need_matcher_scope(), ephemeral=True, silent=True) return # Add the scheduled task and save channel_id = str(interaction.channel.id) - self.state.remove_channel_match_tasks(channel_id) + state.State.remove_channel_match_tasks(channel_id) await interaction.response.send_message( strings.cancelled(), ephemeral=True, silent=True) @@ -176,7 +177,7 @@ class MatcherCog(commands.Cog): # Grab the groups groups = matching.active_members_to_groups( - self.state, interaction.channel, members_min) + interaction.channel, members_min) # Let the user know when there's nobody to match if not groups: @@ -189,11 +190,11 @@ class MatcherCog(commands.Cog): msg = strings.generated_groups(groups_list) view = discord.utils.MISSING - if self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): + if state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): # Otherwise set up the button msg += "\n\n" + strings.click_to_match() + "\n" view = discord.ui.View(timeout=None) - view.add_item(DynamicGroupButton(members_min)) + view.add_item(MatchDynamicButton(members_min)) else: # Let a non-matcher know why they don't have the button msg += "\n\n" + strings.need_matcher_to_post() @@ -206,12 +207,12 @@ class MatcherCog(commands.Cog): async def run_hourly_tasks(self): """Run any hourly tasks we have""" - for (channel, min) in self.state.get_active_match_tasks(): + for (channel, min) in state.State.get_active_match_tasks(): logger.info("Scheduled match task triggered in %s", channel) msg_channel = self.bot.get_channel(int(channel)) - await match_groups_in_channel(self.state, msg_channel, min) + await match_groups_in_channel(msg_channel, min) - for (channel, _) in self.state.get_active_match_tasks(datetime.now() + timedelta(days=1)): + for (channel, _) in state.State.get_active_match_tasks(datetime.now() + timedelta(days=1)): logger.info("Reminding about scheduled task in %s", channel) msg_channel = self.bot.get_channel(int(channel)) await msg_channel.send(strings.reminder()) @@ -222,7 +223,7 @@ _MATCH_BUTTON_CUSTOM_ID_VERSION = 1 _MATCH_BUTTON_CUSTOM_ID_PREFIX = f'match:v{_MATCH_BUTTON_CUSTOM_ID_VERSION}:' -class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], +class MatchDynamicButton(discord.ui.DynamicItem[discord.ui.Button], template=_MATCH_BUTTON_CUSTOM_ID_PREFIX + r'min:(?P[0-9]+)'): """ Describes a simple button that lets the user trigger a match @@ -237,7 +238,6 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], ) ) self.min: int = min - self.state = state.load_from_file() # This is called when the button is clicked and the custom_id matches the template. @classmethod @@ -256,10 +256,10 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], await intrctn.response.send_message(content=strings.matching(), ephemeral=True) # Perform the match - await match_groups_in_channel(self.state, intrctn.channel, self.min) + await match_groups_in_channel(intrctn.channel, self.min) -async def match_groups_in_channel(state: State, channel: discord.channel, min: int): +async def match_groups_in_channel(channel: discord.channel, min: int): """Match up the groups in a given channel""" groups = matching.active_members_to_groups(state, channel, min) @@ -277,6 +277,39 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i # Close off with a message await channel.send(strings.matching_done()) # Save the groups to the history - state.log_groups(groups) + state.State.log_groups(groups) logger.info("Done! Matched into %s groups.", len(groups)) + + +class ScheduleButton(discord.ui.Button): + """ + Describes a simple button that lets the user post the schedule to the channel + """ + + def __init__(self) -> None: + super().__init__( + label='Post schedule', + style=discord.ButtonStyle.blurple + ) + + async def callback(self, interaction: discord.Interaction) -> None: + """Post about the current schedule when requested""" + logger.info("Handling schedule button press byuser %s from %s in #%s", + interaction.user, interaction.guild.name, interaction.channel.name) + + tasks = state.State.get_channel_match_tasks(interaction.channel.id) + + msg = f"{interaction.user.mention} added a match to this channel!\n" + msg += "Current scheduled matches are:" + + if tasks: + for (day, hour, min) in tasks: + next_run = util.get_next_datetime(day, hour) + date_str = util.datetime_as_discord_time(next_run) + msg += f"\n{date_str} with {min} members per group\n" + + await interaction.channel.send(msg) + await interaction.response.send_message(content="Posted :)", ephemeral=True) + else: + await interaction.response.send_message(content="No scheduled matches to post :(", ephemeral=True) diff --git a/matchy/cogs/owner.py b/matchy/cogs/owner.py index d421251..332622d 100644 --- a/matchy/cogs/owner.py +++ b/matchy/cogs/owner.py @@ -3,16 +3,15 @@ Owner bot cog """ import logging from discord.ext import commands -from matchy.state import State, AuthScope +import matchy.state as state logger = logging.getLogger("owner") logger.setLevel(logging.INFO) class OwnerCog(commands.Cog): - def __init__(self, bot: commands.Bot, state: State): + def __init__(self, bot: commands.Bot): self._bot = bot - self._state = state @commands.command() @commands.dm_only() @@ -48,7 +47,7 @@ class OwnerCog(commands.Cog): Grant the matcher scope to a given user """ if user.isdigit(): - self._state.set_user_scope(str(user), AuthScope.MATCHER) + state.State.set_user_scope(str(user), state.AuthScope.MATCHER) logger.info("Granting user %s matcher scope", user) await ctx.reply("Done!", ephemeral=True) else: diff --git a/matchy/matching.py b/matchy/matching.py index 25292d4..465dac5 100644 --- a/matchy/matching.py +++ b/matchy/matching.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import Protocol, runtime_checkable from matchy.state import State, ts_to_datetime import matchy.util as util +import matchy.state as state class _ScoreFactors(int): @@ -95,7 +96,6 @@ def get_member_group_eligibility_score(member: Member, def attempt_create_groups(matchees: list[Member], - state: State, oldest_relevant_ts: datetime, per_group: int) -> tuple[bool, list[list[Member]]]: """History aware group matching""" @@ -110,10 +110,10 @@ def attempt_create_groups(matchees: list[Member], while matchees_left: # Get the next matchee to place matchee = matchees_left.pop() - matchee_matches = state.get_user_matches(matchee.id) + matchee_matches = state.State.get_user_matches(matchee.id) relevant_matches = [int(id) for id, ts in matchee_matches.items() - if ts_to_datetime(ts) >= oldest_relevant_ts] + if state.ts_to_datetime(ts) >= oldest_relevant_ts] # Try every single group from the current group onwards # Progressing through the groups like this ensures we slowly fill them up with compatible people @@ -143,7 +143,6 @@ def attempt_create_groups(matchees: list[Member], def members_to_groups(matchees: list[Member], - state: State, per_group: int = 3, allow_fallback: bool = False) -> list[list[Member]]: """Generate the groups from the set of matchees""" @@ -155,14 +154,14 @@ def members_to_groups(matchees: list[Member], return [] # Walk from the start of history until now trying to match up groups - for oldest_relevant_datetime in state.get_history_timestamps(matchees) + [datetime.now()]: + for oldest_relevant_datetime in state.State.get_history_timestamps(matchees) + [datetime.now()]: # Attempt with each starting matchee for shifted_matchees in util.iterate_all_shifts(matchees): attempts += 1 groups = attempt_create_groups( - shifted_matchees, state, oldest_relevant_datetime, per_group) + shifted_matchees, oldest_relevant_datetime, per_group) # Fail the match if our groups aren't big enough if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)): @@ -179,19 +178,21 @@ def members_to_groups(matchees: list[Member], assert False -def get_matchees_in_channel(state: State, channel: discord.channel): +def get_matchees_in_channel(channel: discord.channel): """Fetches the matchees in a channel""" # Reactivate any unpaused users - state.reactivate_users(channel.id) + state.State.reactivate_users(channel.id) # Gather up the prospective matchees - active = [m for m in channel.members if state.get_user_active_in_channel(m.id, channel.id)] - paused = [m for m in channel.members if state.get_user_paused_in_channel(m.id, channel.id)] + active = [m for m in channel.members if state.State.get_user_active_in_channel( + m.id, channel.id)] + paused = [m for m in channel.members if state.State.get_user_paused_in_channel( + m.id, channel.id)] return (active, paused) -def active_members_to_groups(state: State, channel: discord.channel, min_members: int): +def active_members_to_groups(channel: discord.channel, min_members: int): """Helper to create groups from channel members""" # Gather up the prospective matchees - matchees = get_matchees_in_channel(state, channel) + matchees = get_matchees_in_channel(channel) # Create our groups! - return members_to_groups(matchees, state, min_members, allow_fallback=True) + return members_to_groups(matchees, min_members, allow_fallback=True) diff --git a/matchy/state.py b/matchy/state.py index 8803d6b..d498af6 100644 --- a/matchy/state.py +++ b/matchy/state.py @@ -15,7 +15,6 @@ import matchy.util as util logger = logging.getLogger("state") logger.setLevel(logging.INFO) - # Warning: Changing any of the below needs proper thought to ensure backwards compatibility _VERSION = 4 @@ -193,7 +192,7 @@ def _save(file: str, content: dict): shutil.move(intermediate, file) -class State(): +class _State(): def __init__(self, data: dict, file: str | None = None): """Copy the data, migrate if needed, and validate""" self._dict = copy.deepcopy(data) @@ -216,7 +215,7 @@ class State(): """ @wraps(func) def inner(self, *args, **kwargs): - tmp = State(self._dict, self._file) + tmp = _State(self._dict, self._file) func(tmp, *args, **kwargs) _SCHEMA.validate(tmp._dict) if tmp._file: @@ -380,11 +379,15 @@ class State(): return self._dict[_Key.TASKS] -def load_from_file(file: str) -> State: +def load_from_file(file: str) -> _State: """ Load the state from a files """ loaded = _load(file) if os.path.isfile(file) else _EMPTY_DICT - st = State(loaded, file) + st = _State(loaded, file) _save(file, st._dict) return st + + +_STATE_FILE = ".matchy/state.json" +State = load_from_file(_STATE_FILE) diff --git a/tests/matching_test.py b/tests/matching_test.py index b893cfb..6a15e86 100644 --- a/tests/matching_test.py +++ b/tests/matching_test.py @@ -11,6 +11,12 @@ import itertools from datetime import datetime, timedelta +@pytest.fixture(autouse=True) +def clean_state(): + """Ensure every single one of these tests has a clean state""" + state.State = state._State(state._EMPTY_DICT) + + def test_protocols(): """Verify the protocols we're using match the discord ones""" assert isinstance(discord.Member, matching.Member) @@ -59,16 +65,16 @@ class Member(): return self._id -def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, per_group: int): +def members_to_groups_validate(matchees: list[Member], per_group: int): """Inner function to validate the main output of the groups function""" - groups = matching.members_to_groups(matchees, tmp_state, per_group) + groups = matching.members_to_groups(matchees, per_group) # We should always have one group assert len(groups) # Log the groups to history # This will validate the internals - tmp_state.log_groups(groups) + state.State.log_groups(groups) # Ensure each group contains within the bounds of expected members for group in groups: @@ -96,8 +102,7 @@ def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, p ], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"]) def test_members_to_groups_no_history(matchees, per_group): """Test simple group matching works""" - tmp_state = state.State(state._EMPTY_DICT) - members_to_groups_validate(matchees, tmp_state, per_group) + members_to_groups_validate(matchees, per_group) def items_found_in_lists(list_of_lists, items): @@ -328,13 +333,12 @@ def items_found_in_lists(list_of_lists, items): ], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3']) def test_unique_regressions(history_data, matchees, per_group, checks): """Test a bunch of unqiue failures that happened in the past""" - tmp_state = state.State(state._EMPTY_DICT) # Replay the history for d in history_data: - tmp_state.log_groups(d["groups"], d["ts"]) + state.State.log_groups(d["groups"], d["ts"]) - groups = members_to_groups_validate(matchees, tmp_state, per_group) + groups = members_to_groups_validate(matchees, per_group) # Run the custom validate functions for check in checks: @@ -380,28 +384,25 @@ def test_stess_random_groups(per_group, num_members, num_history): member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)] # For each history item match up groups and log those - cumulative_state = state.State(state._EMPTY_DICT) for i in range(num_history+1): # Grab the num of members and replay rand.shuffle(possible_members) members = copy.deepcopy(possible_members[:num_members]) - groups = members_to_groups_validate( - members, cumulative_state, per_group) - cumulative_state.log_groups( + groups = members_to_groups_validate(members, per_group) + state.State.log_groups( groups, datetime.now() - timedelta(days=num_history-i)) def test_auth_scopes(): - tmp_state = state.State(state._EMPTY_DICT) id = "1" - assert not tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) + assert not state.State.get_user_has_scope(id, state.AuthScope.MATCHER) id = "2" - tmp_state.set_user_scope(id, state.AuthScope.MATCHER) - assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) + state.State.set_user_scope(id, state.AuthScope.MATCHER) + assert state.State.get_user_has_scope(id, state.AuthScope.MATCHER) # Validate the state by constucting a new one - _ = state.State(tmp_state._dict) + _ = state._State(state.State._dict) diff --git a/tests/owner_cog_test.py b/tests/owner_cog_test.py index 009faa8..6f11058 100644 --- a/tests/owner_cog_test.py +++ b/tests/owner_cog_test.py @@ -2,7 +2,6 @@ import discord import discord.ext.commands as commands import pytest import pytest_asyncio -import matchy.state as state import discord.ext.test as dpytest from matchy.cogs.owner import OwnerCog @@ -20,7 +19,7 @@ async def bot(): b = commands.Bot(command_prefix="$", intents=intents) await b._async_setup_hook() - await b.add_cog(OwnerCog(b, state.State(state._EMPTY_DICT))) + await b.add_cog(OwnerCog(b)) dpytest.configure(b) yield b await dpytest.empty_queue()