diff --git a/README.md b/README.md index 4f89f6f..0610f2b 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,10 @@ Allows a matcher to set a weekly schedule for matches in the channel, cancel can Only usable by users with the `owner` scope. Only usable in a DM with the bot user. #### $sync and $close -Syncs bot commands and reloads the state file, or closes down the bot. +Syncs bot commands or closes down the bot. + +#### $grant [user: int] +Grant a given user the matcher scope to allow them to use `/match` and `/schedule`. ## Development Current development is on Linux, though running on Mac or Windows should work fine. @@ -94,7 +97,6 @@ State is stored locally in a `state.json` file. This will be created by the bot. ## TODO * Implement better tests to the discordy parts of the codebase -* Rethink the matcher scope, seems like maybe this could be simpler or removed * Implement a .json file upgrade test * Track if matches were successful * Improve the weirdo diff --git a/py/cogs/matchy_cog.py b/py/cogs/matchy_cog.py index b7dd0ee..95cc1c7 100644 --- a/py/cogs/matchy_cog.py +++ b/py/cogs/matchy_cog.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time import cogs.match_button as match_button import matching -from state import State, save_to_file, AuthScope +from state import State, AuthScope import util logger = logging.getLogger("cog") @@ -38,7 +38,6 @@ class MatchyCog(commands.Cog): self.state.set_user_active_in_channel( interaction.user.id, interaction.channel.id) - save_to_file(self.state) await interaction.response.send_message( f"Roger roger {interaction.user.mention}!\n" + f"Added you to {interaction.channel.mention}!", @@ -52,7 +51,6 @@ class MatchyCog(commands.Cog): self.state.set_user_active_in_channel( interaction.user.id, interaction.channel.id, False) - save_to_file(self.state) await interaction.response.send_message( f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) @@ -68,7 +66,6 @@ class MatchyCog(commands.Cog): until = datetime.now() + timedelta(days=days) self.state.set_user_paused_in_channel( interaction.user.id, interaction.channel.id, until) - save_to_file(self.state) await interaction.response.send_message( f"Sure thing {interaction.user.mention}!\n" + f"Paused you until {util.format_day(until)}!", @@ -127,7 +124,6 @@ class MatchyCog(commands.Cog): # Add the scheduled task and save success = self.state.set_channel_match_task( channel_id, members_min, weekday, hour, not cancel) - save_to_file(self.state) # Let the user know what happened if not cancel: diff --git a/py/cogs/owner_cog.py b/py/cogs/owner_cog.py index 06b677b..8590d7e 100644 --- a/py/cogs/owner_cog.py +++ b/py/cogs/owner_cog.py @@ -3,22 +3,27 @@ Owner bot cog """ import logging from discord.ext import commands +from state import State, AuthScope logger = logging.getLogger("owner") logger.setLevel(logging.INFO) class OwnerCog(commands.Cog): - def __init__(self, bot: commands.Bot): - self.bot = bot + def __init__(self, bot: commands.Bot, state: State): + self._bot = bot + self._state = state @commands.command() @commands.dm_only() @commands.is_owner() async def sync(self, ctx: commands.Context): - """Handle sync command""" + """ + Sync the bot commands + You get rate limited if you do this too often so it's better to keep it on command + """ msg = await ctx.reply(content="Syncing commands...", ephemeral=True) - synced = await self.bot.tree.sync() + synced = await self._bot.tree.sync() logger.info("Synced %s command(s)", len(synced)) await msg.edit(content="Done!") @@ -26,7 +31,25 @@ class OwnerCog(commands.Cog): @commands.dm_only() @commands.is_owner() async def close(self, ctx: commands.Context): - """Handle restart command""" + """ + Handle close command + Shuts down the bot when needed + """ await ctx.reply("Closing bot...", ephemeral=True) logger.info("Closing down the bot") - await self.bot.close() + await self._bot.close() + + @commands.command() + @commands.dm_only() + @commands.is_owner() + async def grant(self, ctx: commands.Context, user: str): + """ + Handle grant command + Grant the matcher scope to a given user + """ + if user.isdigit(): + self._state.set_user_scope(str(user), AuthScope.MATCHER) + logger.info("Granting user %s matcher scope", user) + await ctx.reply("Done!", ephemeral=True) + else: + await ctx.reply("Likely not a user...", ephemeral=True) diff --git a/py/matching.py b/py/matching.py index 49aa4d3..1f3b485 100644 --- a/py/matching.py +++ b/py/matching.py @@ -3,7 +3,7 @@ import logging import discord from datetime import datetime, timedelta from typing import Protocol, runtime_checkable -from state import State, save_to_file, ts_to_datetime +from state import State, ts_to_datetime import util import config @@ -166,7 +166,7 @@ def iterate_all_shifts(list: list): def members_to_groups(matchees: list[Member], - state: State = State(), + state: State, per_group: int = 3, allow_fallback: bool = False) -> list[list[Member]]: """Generate the groups from the set of matchees""" @@ -224,7 +224,6 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i # Save the groups to the history state.log_groups(groups) - save_to_file(state) logger.info("Done! Matched into %s groups.", len(groups)) diff --git a/py/matching_test.py b/py/matching_test.py index 29e6e2b..591cac2 100644 --- a/py/matching_test.py +++ b/py/matching_test.py @@ -96,7 +96,7 @@ def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, p ], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"]) def test_members_to_groups_no_history(matchees, per_group): """Test simple group matching works""" - tmp_state = state.State() + tmp_state = state.State(state._EMPTY_DICT) members_to_groups_validate(matchees, tmp_state, per_group) @@ -328,7 +328,7 @@ def items_found_in_lists(list_of_lists, items): ], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3']) def test_unique_regressions(history_data, matchees, per_group, checks): """Test a bunch of unqiue failures that happened in the past""" - tmp_state = state.State() + tmp_state = state.State(state._EMPTY_DICT) # Replay the history for d in history_data: @@ -380,7 +380,7 @@ def test_stess_random_groups(per_group, num_members, num_history): member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)] # For each history item match up groups and log those - cumulative_state = state.State() + cumulative_state = state.State(state._EMPTY_DICT) for i in range(num_history+1): # Grab the num of members and replay @@ -394,7 +394,7 @@ def test_stess_random_groups(per_group, num_members, num_history): def test_auth_scopes(): - tmp_state = state.State() + tmp_state = state.State(state._EMPTY_DICT) id = "1" assert not tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) diff --git a/py/matchy.py b/py/matchy.py index 63f3f09..0bccb49 100755 --- a/py/matchy.py +++ b/py/matchy.py @@ -5,12 +5,12 @@ import logging import discord from discord.ext import commands import config -import state +from state import load_from_file from cogs.matchy_cog import MatchyCog from cogs.owner_cog import OwnerCog -State = state.load_from_file() - +_STATE_FILE = "state.json" +state = load_from_file(_STATE_FILE) logger = logging.getLogger("matchy") logger.setLevel(logging.INFO) @@ -24,8 +24,8 @@ bot = commands.Bot(command_prefix='$', @bot.event async def setup_hook(): - await bot.add_cog(MatchyCog(bot, State)) - await bot.add_cog(OwnerCog(bot)) + await bot.add_cog(MatchyCog(bot, state)) + await bot.add_cog(OwnerCog(bot, state)) @bot.event diff --git a/py/cogs/owner_cog_test.py b/py/owner_cog_test.py similarity index 81% rename from py/cogs/owner_cog_test.py rename to py/owner_cog_test.py index 7fdbd70..0ea4387 100644 --- a/py/cogs/owner_cog_test.py +++ b/py/owner_cog_test.py @@ -2,9 +2,10 @@ import discord import discord.ext.commands as commands import pytest import pytest_asyncio +import state import discord.ext.test as dpytest -from owner_cog import OwnerCog +from cogs.owner_cog import OwnerCog # Primarily borrowing from https://dpytest.readthedocs.io/en/latest/tutorials/using_pytest.html # TODO: Test more somehow, though it seems like dpytest is pretty incomplete @@ -19,7 +20,7 @@ async def bot(): b = commands.Bot(command_prefix="$", intents=intents) await b._async_setup_hook() - await b.add_cog(OwnerCog(b)) + await b.add_cog(OwnerCog(b, state.State(state._EMPTY_DICT))) dpytest.configure(b) yield b await dpytest.empty_queue() @@ -32,3 +33,6 @@ async def test_must_be_owner(bot): with pytest.raises(commands.errors.NotOwner): await dpytest.message("$close") + + with pytest.raises(commands.errors.NotOwner): + await dpytest.message("$grant") diff --git a/py/state.py b/py/state.py index 81aafae..bbf9fab 100644 --- a/py/state.py +++ b/py/state.py @@ -13,10 +13,6 @@ logger = logging.getLogger("state") logger.setLevel(logging.INFO) -# Location of the default state file -_STATE_FILE = "state.json" - - # Warning: Changing any of the below needs proper thought to ensure backwards compatibility _VERSION = 4 @@ -174,10 +170,11 @@ def datetime_to_ts(ts: datetime) -> str: class State(): - def __init__(self, data: dict = _EMPTY_DICT): + def __init__(self, data: dict, file: str | None = None): """Initialise and validate the state""" self.validate(data) self._dict = copy.deepcopy(data) + self._file = file def validate(self, dict: dict = None): """Initialise and validate a state dict""" @@ -208,7 +205,7 @@ class State(): def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: """Log the groups""" ts = datetime_to_ts(ts or datetime.now()) - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: for group in groups: # Update the matchee data with the matches for m in group: @@ -220,7 +217,7 @@ class State(): def set_user_scope(self, id: str, scope: str, value: bool = True): """Add an auth scope to a user""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: # Dive in user = safe_state._users.setdefault(str(id), {}) scopes = user.setdefault(_Key.SCOPES, []) @@ -260,7 +257,7 @@ class State(): def reactivate_users(self, channel_id: str): """Reactivate any users who've passed their reactivation time on this channel""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: for user in safe_state._users.values(): channels = user.get(_Key.CHANNELS, {}) channel = channels.get(str(channel_id), {}) @@ -300,7 +297,7 @@ class State(): def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: """Set up a match task on a channel""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: channel = safe_state._tasks.setdefault(str(channel_id), {}) matches = channel.setdefault(_Key.MATCH_TASKS, []) @@ -345,7 +342,7 @@ class State(): def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): """Set a user channel property helper""" - with self._safe_wrap() as safe_state: + with self._safe_wrap_write() as safe_state: # Dive in user = safe_state._users.setdefault(str(id), {}) channels = user.setdefault(_Key.CHANNELS, {}) @@ -355,7 +352,7 @@ class State(): channel[key] = value @contextmanager - def _safe_wrap(self): + def _safe_wrap_write(self): """Safely run any function wrapped in a validate""" # Wrap in a temporary state to validate first to prevent corruption tmp_state = State(self._dict) @@ -366,6 +363,14 @@ class State(): tmp_state.validate() self._dict = tmp_state._dict + # Write this change out if we have a file + if self._file: + self._save_to_file() + + def _save_to_file(self): + """Saves the state out to the chosen file""" + files.save(self._file, self.dict_internal_copy) + def _migrate(dict: dict): """Migrate a dict through versions""" @@ -376,9 +381,9 @@ def _migrate(dict: dict): dict[_Key.VERSION] = _VERSION -def load_from_file(file: str = _STATE_FILE) -> State: +def load_from_file(file: str) -> State: """ - Load the state from a file + Load the state from a files Apply any required migrations """ loaded = _EMPTY_DICT @@ -388,14 +393,9 @@ def load_from_file(file: str = _STATE_FILE) -> State: loaded = files.load(file) _migrate(loaded) - st = State(loaded) + st = State(loaded, file) # Save out the migrated (or new) file files.save(file, st._dict) return st - - -def save_to_file(state: State, file: str = _STATE_FILE): - """Saves the state out to a file""" - files.save(file, state.dict_internal_copy) diff --git a/py/state_test.py b/py/state_test.py index b79a426..51a7f4b 100644 --- a/py/state_test.py +++ b/py/state_test.py @@ -18,10 +18,10 @@ def test_simple_load_reload(): with tempfile.TemporaryDirectory() as tmp: path = os.path.join(tmp, 'tmp.json') st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) @@ -30,13 +30,13 @@ def test_authscope(): with tempfile.TemporaryDirectory() as tmp: path = os.path.join(tmp, 'tmp.json') st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() assert not st.get_user_has_scope(1, state.AuthScope.MATCHER) st = state.load_from_file(path) st.set_user_scope(1, state.AuthScope.MATCHER) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) assert st.get_user_has_scope(1, state.AuthScope.MATCHER) @@ -50,13 +50,13 @@ def test_channeljoin(): with tempfile.TemporaryDirectory() as tmp: path = os.path.join(tmp, 'tmp.json') st = state.load_from_file(path) - state.save_to_file(st, path) + st._save_to_file() assert not st.get_user_active_in_channel(1, "2") st = state.load_from_file(path) st.set_user_active_in_channel(1, "2", True) - state.save_to_file(st, path) + st._save_to_file() st = state.load_from_file(path) assert st.get_user_active_in_channel(1, "2")