Merge pull request #4 from mdiluz/allow-grant-scope

Add $grant to give users the matcher scope
This commit is contained in:
Marc Di Luzio 2024-08-13 23:46:22 +01:00 committed by GitHub
commit b810dedb26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 76 additions and 52 deletions

View file

@ -31,7 +31,10 @@ Allows a matcher to set a weekly schedule for matches in the channel, cancel can
Only usable by users with the `owner` scope. Only usable in a DM with the bot user. Only usable by users with the `owner` scope. Only usable in a DM with the bot user.
#### $sync and $close #### $sync and $close
Syncs bot commands and reloads the state file, or closes down the bot. Syncs bot commands or closes down the bot.
#### $grant [user: int]
Grant a given user the matcher scope to allow them to use `/match` and `/schedule`.
## Development ## Development
Current development is on Linux, though running on Mac or Windows should work fine. Current development is on Linux, though running on Mac or Windows should work fine.
@ -94,7 +97,6 @@ State is stored locally in a `state.json` file. This will be created by the bot.
## TODO ## TODO
* Implement better tests to the discordy parts of the codebase * Implement better tests to the discordy parts of the codebase
* Rethink the matcher scope, seems like maybe this could be simpler or removed
* Implement a .json file upgrade test * Implement a .json file upgrade test
* Track if matches were successful * Track if matches were successful
* Improve the weirdo * Improve the weirdo

View file

@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time
import cogs.match_button as match_button import cogs.match_button as match_button
import matching import matching
from state import State, save_to_file, AuthScope from state import State, AuthScope
import util import util
logger = logging.getLogger("cog") logger = logging.getLogger("cog")
@ -38,7 +38,6 @@ class MatchyCog(commands.Cog):
self.state.set_user_active_in_channel( self.state.set_user_active_in_channel(
interaction.user.id, interaction.channel.id) interaction.user.id, interaction.channel.id)
save_to_file(self.state)
await interaction.response.send_message( await interaction.response.send_message(
f"Roger roger {interaction.user.mention}!\n" f"Roger roger {interaction.user.mention}!\n"
+ f"Added you to {interaction.channel.mention}!", + f"Added you to {interaction.channel.mention}!",
@ -52,7 +51,6 @@ class MatchyCog(commands.Cog):
self.state.set_user_active_in_channel( self.state.set_user_active_in_channel(
interaction.user.id, interaction.channel.id, False) interaction.user.id, interaction.channel.id, False)
save_to_file(self.state)
await interaction.response.send_message( await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@ -68,7 +66,6 @@ class MatchyCog(commands.Cog):
until = datetime.now() + timedelta(days=days) until = datetime.now() + timedelta(days=days)
self.state.set_user_paused_in_channel( self.state.set_user_paused_in_channel(
interaction.user.id, interaction.channel.id, until) interaction.user.id, interaction.channel.id, until)
save_to_file(self.state)
await interaction.response.send_message( await interaction.response.send_message(
f"Sure thing {interaction.user.mention}!\n" f"Sure thing {interaction.user.mention}!\n"
+ f"Paused you until {util.format_day(until)}!", + f"Paused you until {util.format_day(until)}!",
@ -127,7 +124,6 @@ class MatchyCog(commands.Cog):
# Add the scheduled task and save # Add the scheduled task and save
success = self.state.set_channel_match_task( success = self.state.set_channel_match_task(
channel_id, members_min, weekday, hour, not cancel) channel_id, members_min, weekday, hour, not cancel)
save_to_file(self.state)
# Let the user know what happened # Let the user know what happened
if not cancel: if not cancel:

View file

@ -3,22 +3,27 @@ Owner bot cog
""" """
import logging import logging
from discord.ext import commands from discord.ext import commands
from state import State, AuthScope
logger = logging.getLogger("owner") logger = logging.getLogger("owner")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class OwnerCog(commands.Cog): class OwnerCog(commands.Cog):
def __init__(self, bot: commands.Bot): def __init__(self, bot: commands.Bot, state: State):
self.bot = bot self._bot = bot
self._state = state
@commands.command() @commands.command()
@commands.dm_only() @commands.dm_only()
@commands.is_owner() @commands.is_owner()
async def sync(self, ctx: commands.Context): async def sync(self, ctx: commands.Context):
"""Handle sync command""" """
Sync the bot commands
You get rate limited if you do this too often so it's better to keep it on command
"""
msg = await ctx.reply(content="Syncing commands...", ephemeral=True) msg = await ctx.reply(content="Syncing commands...", ephemeral=True)
synced = await self.bot.tree.sync() synced = await self._bot.tree.sync()
logger.info("Synced %s command(s)", len(synced)) logger.info("Synced %s command(s)", len(synced))
await msg.edit(content="Done!") await msg.edit(content="Done!")
@ -26,7 +31,25 @@ class OwnerCog(commands.Cog):
@commands.dm_only() @commands.dm_only()
@commands.is_owner() @commands.is_owner()
async def close(self, ctx: commands.Context): async def close(self, ctx: commands.Context):
"""Handle restart command""" """
Handle close command
Shuts down the bot when needed
"""
await ctx.reply("Closing bot...", ephemeral=True) await ctx.reply("Closing bot...", ephemeral=True)
logger.info("Closing down the bot") logger.info("Closing down the bot")
await self.bot.close() await self._bot.close()
@commands.command()
@commands.dm_only()
@commands.is_owner()
async def grant(self, ctx: commands.Context, user: str):
"""
Handle grant command
Grant the matcher scope to a given user
"""
if user.isdigit():
self._state.set_user_scope(str(user), AuthScope.MATCHER)
logger.info("Granting user %s matcher scope", user)
await ctx.reply("Done!", ephemeral=True)
else:
await ctx.reply("Likely not a user...", ephemeral=True)

View file

@ -3,7 +3,7 @@ import logging
import discord import discord
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable
from state import State, save_to_file, ts_to_datetime from state import State, ts_to_datetime
import util import util
import config import config
@ -166,7 +166,7 @@ def iterate_all_shifts(list: list):
def members_to_groups(matchees: list[Member], def members_to_groups(matchees: list[Member],
state: State = State(), state: State,
per_group: int = 3, per_group: int = 3,
allow_fallback: bool = False) -> list[list[Member]]: allow_fallback: bool = False) -> list[list[Member]]:
"""Generate the groups from the set of matchees""" """Generate the groups from the set of matchees"""
@ -224,7 +224,6 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
# Save the groups to the history # Save the groups to the history
state.log_groups(groups) state.log_groups(groups)
save_to_file(state)
logger.info("Done! Matched into %s groups.", len(groups)) logger.info("Done! Matched into %s groups.", len(groups))

View file

@ -96,7 +96,7 @@ def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, p
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"]) ], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
def test_members_to_groups_no_history(matchees, per_group): def test_members_to_groups_no_history(matchees, per_group):
"""Test simple group matching works""" """Test simple group matching works"""
tmp_state = state.State() tmp_state = state.State(state._EMPTY_DICT)
members_to_groups_validate(matchees, tmp_state, per_group) members_to_groups_validate(matchees, tmp_state, per_group)
@ -328,7 +328,7 @@ def items_found_in_lists(list_of_lists, items):
], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3']) ], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3'])
def test_unique_regressions(history_data, matchees, per_group, checks): def test_unique_regressions(history_data, matchees, per_group, checks):
"""Test a bunch of unqiue failures that happened in the past""" """Test a bunch of unqiue failures that happened in the past"""
tmp_state = state.State() tmp_state = state.State(state._EMPTY_DICT)
# Replay the history # Replay the history
for d in history_data: for d in history_data:
@ -380,7 +380,7 @@ def test_stess_random_groups(per_group, num_members, num_history):
member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)] member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)]
# For each history item match up groups and log those # For each history item match up groups and log those
cumulative_state = state.State() cumulative_state = state.State(state._EMPTY_DICT)
for i in range(num_history+1): for i in range(num_history+1):
# Grab the num of members and replay # Grab the num of members and replay
@ -394,7 +394,7 @@ def test_stess_random_groups(per_group, num_members, num_history):
def test_auth_scopes(): def test_auth_scopes():
tmp_state = state.State() tmp_state = state.State(state._EMPTY_DICT)
id = "1" id = "1"
assert not tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) assert not tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)

View file

@ -5,12 +5,12 @@ import logging
import discord import discord
from discord.ext import commands from discord.ext import commands
import config import config
import state from state import load_from_file
from cogs.matchy_cog import MatchyCog from cogs.matchy_cog import MatchyCog
from cogs.owner_cog import OwnerCog from cogs.owner_cog import OwnerCog
State = state.load_from_file() _STATE_FILE = "state.json"
state = load_from_file(_STATE_FILE)
logger = logging.getLogger("matchy") logger = logging.getLogger("matchy")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -24,8 +24,8 @@ bot = commands.Bot(command_prefix='$',
@bot.event @bot.event
async def setup_hook(): async def setup_hook():
await bot.add_cog(MatchyCog(bot, State)) await bot.add_cog(MatchyCog(bot, state))
await bot.add_cog(OwnerCog(bot)) await bot.add_cog(OwnerCog(bot, state))
@bot.event @bot.event

View file

@ -2,9 +2,10 @@ import discord
import discord.ext.commands as commands import discord.ext.commands as commands
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import state
import discord.ext.test as dpytest import discord.ext.test as dpytest
from owner_cog import OwnerCog from cogs.owner_cog import OwnerCog
# Primarily borrowing from https://dpytest.readthedocs.io/en/latest/tutorials/using_pytest.html # Primarily borrowing from https://dpytest.readthedocs.io/en/latest/tutorials/using_pytest.html
# TODO: Test more somehow, though it seems like dpytest is pretty incomplete # TODO: Test more somehow, though it seems like dpytest is pretty incomplete
@ -19,7 +20,7 @@ async def bot():
b = commands.Bot(command_prefix="$", b = commands.Bot(command_prefix="$",
intents=intents) intents=intents)
await b._async_setup_hook() await b._async_setup_hook()
await b.add_cog(OwnerCog(b)) await b.add_cog(OwnerCog(b, state.State(state._EMPTY_DICT)))
dpytest.configure(b) dpytest.configure(b)
yield b yield b
await dpytest.empty_queue() await dpytest.empty_queue()
@ -32,3 +33,6 @@ async def test_must_be_owner(bot):
with pytest.raises(commands.errors.NotOwner): with pytest.raises(commands.errors.NotOwner):
await dpytest.message("$close") await dpytest.message("$close")
with pytest.raises(commands.errors.NotOwner):
await dpytest.message("$grant")

View file

@ -13,10 +13,6 @@ logger = logging.getLogger("state")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
# Location of the default state file
_STATE_FILE = "state.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility # Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 4 _VERSION = 4
@ -174,10 +170,11 @@ def datetime_to_ts(ts: datetime) -> str:
class State(): class State():
def __init__(self, data: dict = _EMPTY_DICT): def __init__(self, data: dict, file: str | None = None):
"""Initialise and validate the state""" """Initialise and validate the state"""
self.validate(data) self.validate(data)
self._dict = copy.deepcopy(data) self._dict = copy.deepcopy(data)
self._file = file
def validate(self, dict: dict = None): def validate(self, dict: dict = None):
"""Initialise and validate a state dict""" """Initialise and validate a state dict"""
@ -208,7 +205,7 @@ class State():
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups""" """Log the groups"""
ts = datetime_to_ts(ts or datetime.now()) ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap() as safe_state: with self._safe_wrap_write() as safe_state:
for group in groups: for group in groups:
# Update the matchee data with the matches # Update the matchee data with the matches
for m in group: for m in group:
@ -220,7 +217,7 @@ class State():
def set_user_scope(self, id: str, scope: str, value: bool = True): def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user""" """Add an auth scope to a user"""
with self._safe_wrap() as safe_state: with self._safe_wrap_write() as safe_state:
# Dive in # Dive in
user = safe_state._users.setdefault(str(id), {}) user = safe_state._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, []) scopes = user.setdefault(_Key.SCOPES, [])
@ -260,7 +257,7 @@ class State():
def reactivate_users(self, channel_id: str): def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel""" """Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap() as safe_state: with self._safe_wrap_write() as safe_state:
for user in safe_state._users.values(): for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {}) channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {}) channel = channels.get(str(channel_id), {})
@ -300,7 +297,7 @@ class State():
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel""" """Set up a match task on a channel"""
with self._safe_wrap() as safe_state: with self._safe_wrap_write() as safe_state:
channel = safe_state._tasks.setdefault(str(channel_id), {}) channel = safe_state._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, []) matches = channel.setdefault(_Key.MATCH_TASKS, [])
@ -345,7 +342,7 @@ class State():
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper""" """Set a user channel property helper"""
with self._safe_wrap() as safe_state: with self._safe_wrap_write() as safe_state:
# Dive in # Dive in
user = safe_state._users.setdefault(str(id), {}) user = safe_state._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {}) channels = user.setdefault(_Key.CHANNELS, {})
@ -355,7 +352,7 @@ class State():
channel[key] = value channel[key] = value
@contextmanager @contextmanager
def _safe_wrap(self): def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate""" """Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption # Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict) tmp_state = State(self._dict)
@ -366,6 +363,14 @@ class State():
tmp_state.validate() tmp_state.validate()
self._dict = tmp_state._dict self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
def _save_to_file(self):
"""Saves the state out to the chosen file"""
files.save(self._file, self.dict_internal_copy)
def _migrate(dict: dict): def _migrate(dict: dict):
"""Migrate a dict through versions""" """Migrate a dict through versions"""
@ -376,9 +381,9 @@ def _migrate(dict: dict):
dict[_Key.VERSION] = _VERSION dict[_Key.VERSION] = _VERSION
def load_from_file(file: str = _STATE_FILE) -> State: def load_from_file(file: str) -> State:
""" """
Load the state from a file Load the state from a files
Apply any required migrations Apply any required migrations
""" """
loaded = _EMPTY_DICT loaded = _EMPTY_DICT
@ -388,14 +393,9 @@ def load_from_file(file: str = _STATE_FILE) -> State:
loaded = files.load(file) loaded = files.load(file)
_migrate(loaded) _migrate(loaded)
st = State(loaded) st = State(loaded, file)
# Save out the migrated (or new) file # Save out the migrated (or new) file
files.save(file, st._dict) files.save(file, st._dict)
return st return st
def save_to_file(state: State, file: str = _STATE_FILE):
"""Saves the state out to a file"""
files.save(file, state.dict_internal_copy)

View file

@ -18,10 +18,10 @@ def test_simple_load_reload():
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json') path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path) st = state.load_from_file(path)
state.save_to_file(st, path) st._save_to_file()
st = state.load_from_file(path) st = state.load_from_file(path)
state.save_to_file(st, path) st._save_to_file()
st = state.load_from_file(path) st = state.load_from_file(path)
@ -30,13 +30,13 @@ def test_authscope():
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json') path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path) st = state.load_from_file(path)
state.save_to_file(st, path) st._save_to_file()
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER) assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
st = state.load_from_file(path) st = state.load_from_file(path)
st.set_user_scope(1, state.AuthScope.MATCHER) st.set_user_scope(1, state.AuthScope.MATCHER)
state.save_to_file(st, path) st._save_to_file()
st = state.load_from_file(path) st = state.load_from_file(path)
assert st.get_user_has_scope(1, state.AuthScope.MATCHER) assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
@ -50,13 +50,13 @@ def test_channeljoin():
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json') path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path) st = state.load_from_file(path)
state.save_to_file(st, path) st._save_to_file()
assert not st.get_user_active_in_channel(1, "2") assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path) st = state.load_from_file(path)
st.set_user_active_in_channel(1, "2", True) st.set_user_active_in_channel(1, "2", True)
state.save_to_file(st, path) st._save_to_file()
st = state.load_from_file(path) st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2") assert st.get_user_active_in_channel(1, "2")