Merge pull request #4 from mdiluz/allow-grant-scope
Add $grant to give users the matcher scope
This commit is contained in:
commit
b810dedb26
9 changed files with 76 additions and 52 deletions
|
@ -31,7 +31,10 @@ Allows a matcher to set a weekly schedule for matches in the channel, cancel can
|
|||
Only usable by users with the `owner` scope. Only usable in a DM with the bot user.
|
||||
|
||||
#### $sync and $close
|
||||
Syncs bot commands and reloads the state file, or closes down the bot.
|
||||
Syncs bot commands or closes down the bot.
|
||||
|
||||
#### $grant [user: int]
|
||||
Grant a given user the matcher scope to allow them to use `/match` and `/schedule`.
|
||||
|
||||
## Development
|
||||
Current development is on Linux, though running on Mac or Windows should work fine.
|
||||
|
@ -94,7 +97,6 @@ State is stored locally in a `state.json` file. This will be created by the bot.
|
|||
|
||||
## TODO
|
||||
* Implement better tests to the discordy parts of the codebase
|
||||
* Rethink the matcher scope, seems like maybe this could be simpler or removed
|
||||
* Implement a .json file upgrade test
|
||||
* Track if matches were successful
|
||||
* Improve the weirdo
|
||||
|
|
|
@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time
|
|||
|
||||
import cogs.match_button as match_button
|
||||
import matching
|
||||
from state import State, save_to_file, AuthScope
|
||||
from state import State, AuthScope
|
||||
import util
|
||||
|
||||
logger = logging.getLogger("cog")
|
||||
|
@ -38,7 +38,6 @@ class MatchyCog(commands.Cog):
|
|||
|
||||
self.state.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id)
|
||||
save_to_file(self.state)
|
||||
await interaction.response.send_message(
|
||||
f"Roger roger {interaction.user.mention}!\n"
|
||||
+ f"Added you to {interaction.channel.mention}!",
|
||||
|
@ -52,7 +51,6 @@ class MatchyCog(commands.Cog):
|
|||
|
||||
self.state.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id, False)
|
||||
save_to_file(self.state)
|
||||
await interaction.response.send_message(
|
||||
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
||||
|
||||
|
@ -68,7 +66,6 @@ class MatchyCog(commands.Cog):
|
|||
until = datetime.now() + timedelta(days=days)
|
||||
self.state.set_user_paused_in_channel(
|
||||
interaction.user.id, interaction.channel.id, until)
|
||||
save_to_file(self.state)
|
||||
await interaction.response.send_message(
|
||||
f"Sure thing {interaction.user.mention}!\n"
|
||||
+ f"Paused you until {util.format_day(until)}!",
|
||||
|
@ -127,7 +124,6 @@ class MatchyCog(commands.Cog):
|
|||
# Add the scheduled task and save
|
||||
success = self.state.set_channel_match_task(
|
||||
channel_id, members_min, weekday, hour, not cancel)
|
||||
save_to_file(self.state)
|
||||
|
||||
# Let the user know what happened
|
||||
if not cancel:
|
||||
|
|
|
@ -3,22 +3,27 @@ Owner bot cog
|
|||
"""
|
||||
import logging
|
||||
from discord.ext import commands
|
||||
from state import State, AuthScope
|
||||
|
||||
logger = logging.getLogger("owner")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class OwnerCog(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
def __init__(self, bot: commands.Bot, state: State):
|
||||
self._bot = bot
|
||||
self._state = state
|
||||
|
||||
@commands.command()
|
||||
@commands.dm_only()
|
||||
@commands.is_owner()
|
||||
async def sync(self, ctx: commands.Context):
|
||||
"""Handle sync command"""
|
||||
"""
|
||||
Sync the bot commands
|
||||
You get rate limited if you do this too often so it's better to keep it on command
|
||||
"""
|
||||
msg = await ctx.reply(content="Syncing commands...", ephemeral=True)
|
||||
synced = await self.bot.tree.sync()
|
||||
synced = await self._bot.tree.sync()
|
||||
logger.info("Synced %s command(s)", len(synced))
|
||||
await msg.edit(content="Done!")
|
||||
|
||||
|
@ -26,7 +31,25 @@ class OwnerCog(commands.Cog):
|
|||
@commands.dm_only()
|
||||
@commands.is_owner()
|
||||
async def close(self, ctx: commands.Context):
|
||||
"""Handle restart command"""
|
||||
"""
|
||||
Handle close command
|
||||
Shuts down the bot when needed
|
||||
"""
|
||||
await ctx.reply("Closing bot...", ephemeral=True)
|
||||
logger.info("Closing down the bot")
|
||||
await self.bot.close()
|
||||
await self._bot.close()
|
||||
|
||||
@commands.command()
|
||||
@commands.dm_only()
|
||||
@commands.is_owner()
|
||||
async def grant(self, ctx: commands.Context, user: str):
|
||||
"""
|
||||
Handle grant command
|
||||
Grant the matcher scope to a given user
|
||||
"""
|
||||
if user.isdigit():
|
||||
self._state.set_user_scope(str(user), AuthScope.MATCHER)
|
||||
logger.info("Granting user %s matcher scope", user)
|
||||
await ctx.reply("Done!", ephemeral=True)
|
||||
else:
|
||||
await ctx.reply("Likely not a user...", ephemeral=True)
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import discord
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Protocol, runtime_checkable
|
||||
from state import State, save_to_file, ts_to_datetime
|
||||
from state import State, ts_to_datetime
|
||||
import util
|
||||
import config
|
||||
|
||||
|
@ -166,7 +166,7 @@ def iterate_all_shifts(list: list):
|
|||
|
||||
|
||||
def members_to_groups(matchees: list[Member],
|
||||
state: State = State(),
|
||||
state: State,
|
||||
per_group: int = 3,
|
||||
allow_fallback: bool = False) -> list[list[Member]]:
|
||||
"""Generate the groups from the set of matchees"""
|
||||
|
@ -224,7 +224,6 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
|
|||
|
||||
# Save the groups to the history
|
||||
state.log_groups(groups)
|
||||
save_to_file(state)
|
||||
|
||||
logger.info("Done! Matched into %s groups.", len(groups))
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, p
|
|||
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
|
||||
def test_members_to_groups_no_history(matchees, per_group):
|
||||
"""Test simple group matching works"""
|
||||
tmp_state = state.State()
|
||||
tmp_state = state.State(state._EMPTY_DICT)
|
||||
members_to_groups_validate(matchees, tmp_state, per_group)
|
||||
|
||||
|
||||
|
@ -328,7 +328,7 @@ def items_found_in_lists(list_of_lists, items):
|
|||
], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3'])
|
||||
def test_unique_regressions(history_data, matchees, per_group, checks):
|
||||
"""Test a bunch of unqiue failures that happened in the past"""
|
||||
tmp_state = state.State()
|
||||
tmp_state = state.State(state._EMPTY_DICT)
|
||||
|
||||
# Replay the history
|
||||
for d in history_data:
|
||||
|
@ -380,7 +380,7 @@ def test_stess_random_groups(per_group, num_members, num_history):
|
|||
member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)]
|
||||
|
||||
# For each history item match up groups and log those
|
||||
cumulative_state = state.State()
|
||||
cumulative_state = state.State(state._EMPTY_DICT)
|
||||
for i in range(num_history+1):
|
||||
|
||||
# Grab the num of members and replay
|
||||
|
@ -394,7 +394,7 @@ def test_stess_random_groups(per_group, num_members, num_history):
|
|||
|
||||
|
||||
def test_auth_scopes():
|
||||
tmp_state = state.State()
|
||||
tmp_state = state.State(state._EMPTY_DICT)
|
||||
|
||||
id = "1"
|
||||
assert not tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||
|
|
10
py/matchy.py
10
py/matchy.py
|
@ -5,12 +5,12 @@ import logging
|
|||
import discord
|
||||
from discord.ext import commands
|
||||
import config
|
||||
import state
|
||||
from state import load_from_file
|
||||
from cogs.matchy_cog import MatchyCog
|
||||
from cogs.owner_cog import OwnerCog
|
||||
|
||||
State = state.load_from_file()
|
||||
|
||||
_STATE_FILE = "state.json"
|
||||
state = load_from_file(_STATE_FILE)
|
||||
|
||||
logger = logging.getLogger("matchy")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -24,8 +24,8 @@ bot = commands.Bot(command_prefix='$',
|
|||
|
||||
@bot.event
|
||||
async def setup_hook():
|
||||
await bot.add_cog(MatchyCog(bot, State))
|
||||
await bot.add_cog(OwnerCog(bot))
|
||||
await bot.add_cog(MatchyCog(bot, state))
|
||||
await bot.add_cog(OwnerCog(bot, state))
|
||||
|
||||
|
||||
@bot.event
|
||||
|
|
|
@ -2,9 +2,10 @@ import discord
|
|||
import discord.ext.commands as commands
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import state
|
||||
import discord.ext.test as dpytest
|
||||
|
||||
from owner_cog import OwnerCog
|
||||
from cogs.owner_cog import OwnerCog
|
||||
|
||||
# Primarily borrowing from https://dpytest.readthedocs.io/en/latest/tutorials/using_pytest.html
|
||||
# TODO: Test more somehow, though it seems like dpytest is pretty incomplete
|
||||
|
@ -19,7 +20,7 @@ async def bot():
|
|||
b = commands.Bot(command_prefix="$",
|
||||
intents=intents)
|
||||
await b._async_setup_hook()
|
||||
await b.add_cog(OwnerCog(b))
|
||||
await b.add_cog(OwnerCog(b, state.State(state._EMPTY_DICT)))
|
||||
dpytest.configure(b)
|
||||
yield b
|
||||
await dpytest.empty_queue()
|
||||
|
@ -32,3 +33,6 @@ async def test_must_be_owner(bot):
|
|||
|
||||
with pytest.raises(commands.errors.NotOwner):
|
||||
await dpytest.message("$close")
|
||||
|
||||
with pytest.raises(commands.errors.NotOwner):
|
||||
await dpytest.message("$grant")
|
38
py/state.py
38
py/state.py
|
@ -13,10 +13,6 @@ logger = logging.getLogger("state")
|
|||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# Location of the default state file
|
||||
_STATE_FILE = "state.json"
|
||||
|
||||
|
||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||
_VERSION = 4
|
||||
|
||||
|
@ -174,10 +170,11 @@ def datetime_to_ts(ts: datetime) -> str:
|
|||
|
||||
|
||||
class State():
|
||||
def __init__(self, data: dict = _EMPTY_DICT):
|
||||
def __init__(self, data: dict, file: str | None = None):
|
||||
"""Initialise and validate the state"""
|
||||
self.validate(data)
|
||||
self._dict = copy.deepcopy(data)
|
||||
self._file = file
|
||||
|
||||
def validate(self, dict: dict = None):
|
||||
"""Initialise and validate a state dict"""
|
||||
|
@ -208,7 +205,7 @@ class State():
|
|||
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
|
||||
"""Log the groups"""
|
||||
ts = datetime_to_ts(ts or datetime.now())
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
for group in groups:
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
|
@ -220,7 +217,7 @@ class State():
|
|||
|
||||
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||
"""Add an auth scope to a user"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.setdefault(str(id), {})
|
||||
scopes = user.setdefault(_Key.SCOPES, [])
|
||||
|
@ -260,7 +257,7 @@ class State():
|
|||
|
||||
def reactivate_users(self, channel_id: str):
|
||||
"""Reactivate any users who've passed their reactivation time on this channel"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
for user in safe_state._users.values():
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
|
@ -300,7 +297,7 @@ class State():
|
|||
|
||||
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
|
||||
"""Set up a match task on a channel"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
channel = safe_state._tasks.setdefault(str(channel_id), {})
|
||||
matches = channel.setdefault(_Key.MATCH_TASKS, [])
|
||||
|
||||
|
@ -345,7 +342,7 @@ class State():
|
|||
|
||||
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
|
||||
"""Set a user channel property helper"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.setdefault(str(id), {})
|
||||
channels = user.setdefault(_Key.CHANNELS, {})
|
||||
|
@ -355,7 +352,7 @@ class State():
|
|||
channel[key] = value
|
||||
|
||||
@contextmanager
|
||||
def _safe_wrap(self):
|
||||
def _safe_wrap_write(self):
|
||||
"""Safely run any function wrapped in a validate"""
|
||||
# Wrap in a temporary state to validate first to prevent corruption
|
||||
tmp_state = State(self._dict)
|
||||
|
@ -366,6 +363,14 @@ class State():
|
|||
tmp_state.validate()
|
||||
self._dict = tmp_state._dict
|
||||
|
||||
# Write this change out if we have a file
|
||||
if self._file:
|
||||
self._save_to_file()
|
||||
|
||||
def _save_to_file(self):
|
||||
"""Saves the state out to the chosen file"""
|
||||
files.save(self._file, self.dict_internal_copy)
|
||||
|
||||
|
||||
def _migrate(dict: dict):
|
||||
"""Migrate a dict through versions"""
|
||||
|
@ -376,9 +381,9 @@ def _migrate(dict: dict):
|
|||
dict[_Key.VERSION] = _VERSION
|
||||
|
||||
|
||||
def load_from_file(file: str = _STATE_FILE) -> State:
|
||||
def load_from_file(file: str) -> State:
|
||||
"""
|
||||
Load the state from a file
|
||||
Load the state from a files
|
||||
Apply any required migrations
|
||||
"""
|
||||
loaded = _EMPTY_DICT
|
||||
|
@ -388,14 +393,9 @@ def load_from_file(file: str = _STATE_FILE) -> State:
|
|||
loaded = files.load(file)
|
||||
_migrate(loaded)
|
||||
|
||||
st = State(loaded)
|
||||
st = State(loaded, file)
|
||||
|
||||
# Save out the migrated (or new) file
|
||||
files.save(file, st._dict)
|
||||
|
||||
return st
|
||||
|
||||
|
||||
def save_to_file(state: State, file: str = _STATE_FILE):
|
||||
"""Saves the state out to a file"""
|
||||
files.save(file, state.dict_internal_copy)
|
||||
|
|
|
@ -18,10 +18,10 @@ def test_simple_load_reload():
|
|||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
st = state.load_from_file(path)
|
||||
|
||||
|
||||
|
@ -30,13 +30,13 @@ def test_authscope():
|
|||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_user_scope(1, state.AuthScope.MATCHER)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
@ -50,13 +50,13 @@ def test_channeljoin():
|
|||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
assert not st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_user_active_in_channel(1, "2", True)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_active_in_channel(1, "2")
|
||||
|
|
Loading…
Add table
Reference in a new issue