Significant set of changes

* Use /join and /leave instead of roles
* Use scopes to check for user rights rather than using the config file
* Add /list to show the set of current people signed up
* Add a bunch more testing for various things
* Version both the config and the state
This commit is contained in:
Marc Di Luzio 2024-08-11 17:53:37 +01:00
parent 78834f5319
commit d3a22ff090
8 changed files with 537 additions and 295 deletions

View file

@ -20,18 +20,16 @@ Only usable by `OWNER` users, reloads the config and syncs commands, or closes d
Matchy is configured by a `config.json` file that takes this format: Matchy is configured by a `config.json` file that takes this format:
``` ```
{ {
"version": 1,
"token": "<<github bot token>>", "token": "<<github bot token>>",
"owners": [
<<owner id>>
]
} }
``` ```
User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user. User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user.
## TODO ## TODO
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html) * Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
* Implement /pause to pause a user for a little while
* Move more constants to the config
* Add scheduling functionality * Add scheduling functionality
* Version the config and history files * Fix logging in some sub files
* Implement /signup rather than using roles
* Implement authorisation scopes instead of just OWNER values
* Improve the weirdo * Improve the weirdo

View file

@ -1,38 +1,77 @@
"""Very simple config loading library""" """Very simple config loading library"""
from schema import Schema, And, Use from schema import Schema, And, Use
import files import files
import os
import logging
logger = logging.getLogger("config")
logger.setLevel(logging.INFO)
_FILE = "config.json" _FILE = "config.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 1
class _Keys():
TOKEN = "token"
VERSION = "version"
# Removed
OWNERS = "owners"
_SCHEMA = Schema( _SCHEMA = Schema(
{ {
# Discord bot token # The current version
"token": And(Use(str)), _Keys.VERSION: And(Use(int)),
# ids of owners authorised to use owner-only commands # Discord bot token
"owners": And(Use(list[int])), _Keys.TOKEN: And(Use(str)),
} }
) )
def _migrate_to_v1(d: dict):
# Owners moved to History in v1
# Note: owners will be required to be re-added to the state.json
owners = d.pop(_Keys.OWNERS)
logger.warn(
"Migration removed owners from config, these must be re-added to the state.json")
logger.warn("Owners: %s", owners)
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1
]
class Config(): class Config():
def __init__(self, data: dict): def __init__(self, data: dict):
"""Initialise and validate the config""" """Initialise and validate the config"""
_SCHEMA.validate(data) _SCHEMA.validate(data)
self.__dict__ = data self._dict = data
@property @property
def token(self) -> str: def token(self) -> str:
return self.__dict__["token"] return self._dict["token"]
@property
def owners(self) -> list[int]:
return self.__dict__["owners"]
def reload(self) -> None:
"""Reload the config back into the dict"""
self.__dict__ = load().__dict__
def load() -> Config: def _migrate(dict: dict):
"""Load the config""" """Migrate a dict through versions"""
return Config(files.load(_FILE)) version = dict.get("version", 0)
for i in range(version, _VERSION):
_MIGRATIONS[i](dict)
dict["version"] = _VERSION
def load_from_file(file: str = _FILE) -> Config:
"""
Load the state from a file
Apply any required migrations
"""
assert os.path.isfile(file)
loaded = files.load(file)
_migrate(loaded)
return Config(loaded)

View file

@ -1,125 +0,0 @@
"""Store matching history"""
import os
from datetime import datetime
from schema import Schema, And, Use, Optional
from typing import Protocol
import files
import copy
_FILE = "history.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_DEFAULT_DICT = {
"history": {},
"matchees": {}
}
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
_SCHEMA = Schema(
{
Optional("history"): {
Optional(str): { # a datetime
"groups": [
{
"members": [
# The ID of each matchee in the match
And(Use(int))
]
}
]
}
},
Optional("matchees"): {
Optional(str): {
Optional("matches"): {
# Matchee ID and Datetime pair
Optional(str): And(Use(str))
}
}
}
}
)
class Member(Protocol):
@property
def id(self) -> int:
pass
def ts_to_datetime(ts: str) -> datetime:
"""Convert a ts to datetime using the history format"""
return datetime.strptime(ts, _TIME_FORMAT)
def validate(dict: dict):
"""Initialise and validate the history"""
_SCHEMA.validate(dict)
class History():
def __init__(self, data: dict = _DEFAULT_DICT):
"""Initialise and validate the history"""
validate(data)
self.__dict__ = copy.deepcopy(data)
@property
def history(self) -> list[dict]:
return self.__dict__["history"]
@property
def matchees(self) -> dict[str, dict]:
return self.__dict__["matchees"]
def save(self) -> None:
"""Save out the history"""
files.save(_FILE, self.__dict__)
def oldest(self) -> datetime:
"""Grab the oldest timestamp in history"""
if not self.history:
return None
times = (ts_to_datetime(dt) for dt in self.history.keys())
return sorted(times)[0]
def log_groups_to_history(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
"""Log the groups"""
tmp_history = History(self.__dict__)
ts = datetime.strftime(ts, _TIME_FORMAT)
# Grab or create the hitory item for this set of groups
history_item = {}
tmp_history.history[ts] = history_item
history_item_groups = []
history_item["groups"] = history_item_groups
for group in groups:
# Add the group data
history_item_groups.append({
"members": list(m.id for m in group)
})
# Update the matchee data with the matches
for m in group:
matchee = tmp_history.matchees.get(str(m.id), {})
matchee_matches = matchee.get("matches", {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
matchee["matches"] = matchee_matches
tmp_history.matchees[str(m.id)] = matchee
# Validate before storing the result
validate(self.__dict__)
self.__dict__ = tmp_history.__dict__
def save_groups_to_history(self, groups: list[list[Member]]) -> None:
"""Save out the groups to the history file"""
self.log_groups_to_history(groups)
self.save()
def load() -> History:
"""Load the history"""
return History(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)

View file

@ -1,6 +1,5 @@
"""Utility functions for matchy""" """Utility functions for matchy"""
import logging import logging
import random
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable
import state import state
@ -9,17 +8,16 @@ import state
# Number of days to step forward from the start of history for each match attempt # Number of days to step forward from the start of history for each match attempt
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7) _ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
# Attempts for each of those time periods
_ATTEMPTS_PER_TIMESTEP = 3
# Various eligability scoring factors for group meetups class _ScoreFactors(int):
_SCORE_CURRENT_MEMBERS = 2**1 """Various eligability scoring factors for group meetups"""
_SCORE_REPEAT_ROLE = 2**2 REPEAT_ROLE = 2**2
_SCORE_REPEAT_MATCH = 2**3 REPEAT_MATCH = 2**3
_SCORE_EXTRA_MEMBERS = 2**4 EXTRA_MEMBER = 2**5
# Scores higher than this are fully rejected # Scores higher than this are fully rejected
_SCORE_UPPER_THRESHOLD = 2**6 UPPER_THRESHOLD = 2**6
logger = logging.getLogger("matching") logger = logging.getLogger("matching")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -69,33 +67,42 @@ def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bo
def get_member_group_eligibility_score(member: Member, def get_member_group_eligibility_score(member: Member,
group: list[Member], group: list[Member],
relevant_matches: list[int], prior_matches: list[int],
per_group: int) -> int: per_group: int) -> float:
"""Rates a member against a group""" """Rates a member against a group"""
rating = len(group) * _SCORE_CURRENT_MEMBERS # An empty group is a "perfect" score atomatically
rating = 0
if not group:
return rating
repeat_meetings = sum(m.id in relevant_matches for m in group) # Add score based on prior matchups of this user
rating += repeat_meetings * _SCORE_REPEAT_MATCH rating += sum(m.id in prior_matches for m in group) * \
_ScoreFactors.REPEAT_MATCH
repeat_roles = sum(r in member.roles for r in (m.roles for m in group)) # Calculate the number of roles that match
rating += (repeat_roles * _SCORE_REPEAT_ROLE) all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr)
member_role_ids = [r.id for r in member.roles]
repeat_roles = sum(id in member_role_ids for id in all_role_ids)
rating += repeat_roles * _ScoreFactors.REPEAT_ROLE
extra_members = len(group) - per_group # Add score based on the number of extra members
if extra_members > 0: # Calculate the member offset (+1 for this user)
rating += extra_members * _SCORE_EXTRA_MEMBERS extra_members = (len(group) - per_group) + 1
if extra_members >= 0:
rating += extra_members * _ScoreFactors.EXTRA_MEMBER
return rating return rating
def attempt_create_groups(matchees: list[Member], def attempt_create_groups(matchees: list[Member],
hist: state.State, current_state: state.State,
oldest_relevant_ts: datetime, oldest_relevant_ts: datetime,
per_group: int) -> tuple[bool, list[list[Member]]]: per_group: int) -> tuple[bool, list[list[Member]]]:
"""History aware group matching""" """History aware group matching"""
num_groups = max(len(matchees)//per_group, 1) num_groups = max(len(matchees)//per_group, 1)
# Set up the groups in place # Set up the groups in place
groups = list([] for _ in range(num_groups)) groups = [[] for _ in range(num_groups)]
matchees_left = matchees.copy() matchees_left = matchees.copy()
@ -103,21 +110,21 @@ def attempt_create_groups(matchees: list[Member],
while matchees_left: while matchees_left:
# Get the next matchee to place # Get the next matchee to place
matchee = matchees_left.pop() matchee = matchees_left.pop()
matchee_matches = hist.matchees.get( matchee_matches = current_state.get_user_matches(matchee.id)
str(matchee.id), {}).get("matches", {}) relevant_matches = [int(id) for id, ts
relevant_matches = list(int(id) for id, ts in matchee_matches.items() in matchee_matches.items()
if state.ts_to_datetime(ts) >= oldest_relevant_ts) if state.ts_to_datetime(ts) >= oldest_relevant_ts]
# Try every single group from the current group onwards # Try every single group from the current group onwards
# Progressing through the groups like this ensures we slowly fill them up with compatible people # Progressing through the groups like this ensures we slowly fill them up with compatible people
scores: list[tuple[int, int]] = [] scores: list[tuple[int, float]] = []
for group in groups: for group in groups:
score = get_member_group_eligibility_score( score = get_member_group_eligibility_score(
matchee, group, relevant_matches, num_groups) matchee, group, relevant_matches, per_group)
# If the score isn't too high, consider this group # If the score isn't too high, consider this group
if score <= _SCORE_UPPER_THRESHOLD: if score <= _ScoreFactors.UPPER_THRESHOLD:
scores.append((group, score)) scores.append((group, score))
# Optimisation: # Optimisation:
@ -143,31 +150,41 @@ def datetime_range(start_time: datetime, increment: timedelta, end: datetime):
current += increment current += increment
def iterate_all_shifts(list: list):
"""Yields each shifted variation of the input list"""
yield list
for _ in range(len(list)-1):
list = list[1:] + [list[0]]
yield list
def members_to_groups(matchees: list[Member], def members_to_groups(matchees: list[Member],
hist: state.State = state.State(), hist: state.State = state.State(),
per_group: int = 3, per_group: int = 3,
allow_fallback: bool = False) -> list[list[Member]]: allow_fallback: bool = False) -> list[list[Member]]:
"""Generate the groups from the set of matchees""" """Generate the groups from the set of matchees"""
attempts = 0 # Tracking for logging purposes attempts = 0 # Tracking for logging purposes
rand = random.Random(117) # Some stable randomness num_groups = len(matchees)//per_group
# Bail early if there's no-one to match
if not matchees:
return []
# Grab the oldest timestamp # Grab the oldest timestamp
history_start = hist.oldest_history() or datetime.now() history_start = hist.get_oldest_timestamp() or datetime.now()
# Walk from the start of time until now using the timestep increment # Walk from the start of time until now using the timestep increment
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()): for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
# Have a few attempts before stepping forward in time # Attempt with each starting matchee
for _ in range(_ATTEMPTS_PER_TIMESTEP): for shifted_matchees in iterate_all_shifts(matchees):
rand.shuffle(matchees) # Shuffle the matchees each attempt
attempts += 1 attempts += 1
groups = attempt_create_groups( groups = attempt_create_groups(
matchees, hist, oldest_relevant_datetime, per_group) shifted_matchees, hist, oldest_relevant_datetime, per_group)
# Fail the match if our groups aren't big enough # Fail the match if our groups aren't big enough
if (len(matchees)//per_group) <= 1 or (groups and all(len(g) >= per_group for g in groups)): if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
logger.info("Matched groups after %s attempt(s)", attempts) logger.info("Matched groups after %s attempt(s)", attempts)
return groups return groups
@ -176,6 +193,10 @@ def members_to_groups(matchees: list[Member],
logger.info("Fell back to simple groups after %s attempt(s)", attempts) logger.info("Fell back to simple groups after %s attempt(s)", attempts)
return members_to_groups_simple(matchees, per_group) return members_to_groups_simple(matchees, per_group)
# Simply assert false, this should never happen
# And should be caught by tests
assert False
def group_to_message(group: list[Member]) -> str: def group_to_message(group: list[Member]) -> str:
"""Get the message to send for each group""" """Get the message to send for each group"""
@ -185,8 +206,3 @@ def group_to_message(group: list[Member]) -> str:
else: else:
mentions = mentions[0] mentions = mentions[0]
return f"Matched up {mentions}!" return f"Matched up {mentions}!"
def get_role_from_guild(guild: Guild, role: str) -> Role:
"""Find a role in a guild"""
return next((r for r in guild.roles if r.name == role), None)

View file

@ -1,5 +1,5 @@
""" """
Test functions for Matchy Test functions for the matching module
""" """
import discord import discord
import pytest import pytest
@ -30,6 +30,7 @@ class Role():
class Member(): class Member():
def __init__(self, id: int, roles: list[Role] = []): def __init__(self, id: int, roles: list[Role] = []):
self._id = id self._id = id
self._roles = roles
@property @property
def mention(self) -> str: def mention(self) -> str:
@ -37,7 +38,7 @@ class Member():
@property @property
def roles(self) -> list[Role]: def roles(self) -> list[Role]:
return [] return self._roles
@property @property
def id(self) -> int: def id(self) -> int:
@ -153,7 +154,59 @@ def items_found_in_lists(list_of_lists, items):
# Nothing specific to validate # Nothing specific to validate
] ]
), ),
], ids=['simple_history', 'fallback']) # Specific test pulled out of the stress test
(
[
{
"ts": datetime.now() - timedelta(days=4),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12, 13, 14, 15]]
]
},
{
"ts": datetime.now() - timedelta(days=5),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
]
}
],
[Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]],
3,
[
# Nothing specific to validate
]
),
# Silly example that failued due to bad role logic
(
[
# No history
],
[
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
Member(i, [Role(r) for r in roles]) for (i, roles) in
[
(4, [1, 2, 3, 4, 5, 6, 7, 8]),
(8, [1]),
(9, [1, 2, 3, 4, 5]),
(6, [1, 2, 3]),
(11, [1, 2, 3]),
(7, [1, 2, 3, 4, 5, 6, 7]),
(1, [1, 2, 3, 4]),
(5, [1, 2, 3, 4, 5]),
(12, [1, 2, 3, 4]),
(10, [1]),
(13, [1, 2, 3, 4, 5, 6]),
(2, [1, 2, 3, 4, 5, 6]),
(3, [1, 2, 3, 4, 5, 6, 7])
]
],
2,
[
# Nothing else
]
)
], ids=['simple_history', 'fallback', 'example_1', 'example_2'])
def test_members_to_groups_with_history(history_data, matchees, per_group, checks): def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
"""Test more advanced group matching works""" """Test more advanced group matching works"""
tmp_state = state.State() tmp_state = state.State()
@ -180,8 +233,8 @@ def test_members_to_groups_stress_test():
# Slowly ramp a randomized shuffled list of members with randomised roles # Slowly ramp a randomized shuffled list of members with randomised roles
for num_members in range(1, 5): for num_members in range(1, 5):
matchees = list(Member(i, list(Role(i) for i in range(1, rand.randint(2, num_members*2 + 1)))) matchees = [Member(i, [Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))])
for i in range(1, rand.randint(2, num_members*10 + 1))) for i in range(1, rand.randint(2, num_members*10 + 1))]
rand.shuffle(matchees) rand.shuffle(matchees)
for num_history in range(8): for num_history in range(8):
@ -190,14 +243,14 @@ def test_members_to_groups_stress_test():
# Start some time from now to the past # Start some time from now to the past
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5)) time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
history_data = [] history_data = []
for x in range(0, num_history): for _ in range(0, num_history):
run = { run = {
"ts": time "ts": time
} }
groups = [] groups = []
for y in range(1, num_history): for y in range(1, num_history):
groups.append(list(Member(i) groups.append([Member(i)
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1))))) for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))])
run["groups"] = groups run["groups"] = groups
history_data.append(run) history_data.append(run)
@ -212,4 +265,32 @@ def test_members_to_groups_stress_test():
for d in history_data: for d in history_data:
tmp_state.log_groups(d["groups"], d["ts"]) tmp_state.log_groups(d["groups"], d["ts"])
inner_validate_members_to_groups(matchees, tmp_state, per_group) inner_validate_members_to_groups(
matchees, tmp_state, per_group)
def test_auth_scopes():
tmp_state = state.State()
id = "1"
tmp_state.set_user_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
id = "2"
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
tmp_state.validate()
def test_iterate_all_shifts():
original = [1, 2, 3, 4]
lists = [val for val in matching.iterate_all_shifts(original)]
assert lists == [
[1, 2, 3, 4],
[2, 3, 4, 1],
[3, 4, 1, 2],
[4, 1, 2, 3],
]

130
matchy.py
View file

@ -11,8 +11,11 @@ import config
import re import re
Config = config.load() STATE_FILE = "state.json"
State = state.load() CONFIG_FILE = "config.json"
Config = config.load_from_file(CONFIG_FILE)
State = state.load_from_file(STATE_FILE)
logger = logging.getLogger("matchy") logger = logging.getLogger("matchy")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -39,7 +42,7 @@ async def on_ready():
def owner_only(ctx: commands.Context) -> bool: def owner_only(ctx: commands.Context) -> bool:
"""Checks the author is an owner""" """Checks the author is an owner"""
return ctx.message.author.id in Config.owners return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER)
@bot.command() @bot.command()
@ -47,9 +50,10 @@ def owner_only(ctx: commands.Context) -> bool:
@commands.check(owner_only) @commands.check(owner_only)
async def sync(ctx: commands.Context): async def sync(ctx: commands.Context):
"""Handle sync command""" """Handle sync command"""
msg = await ctx.reply("Reloading config...", ephemeral=True) msg = await ctx.reply("Reloading state...", ephemeral=True)
Config.reload() global State
logger.info("Reloaded config") State = state.load_from_file(STATE_FILE)
logger.info("Reloaded state")
await msg.edit(content="Syncing commands...") await msg.edit(content="Syncing commands...")
synced = await bot.tree.sync() synced = await bot.tree.sync()
@ -68,96 +72,112 @@ async def close(ctx: commands.Context):
await bot.close() await bot.close()
@bot.tree.command(description="Join the matchees for this channel")
@commands.guild_only()
async def join(interaction: discord.Interaction):
State.set_use_active_in_channel(
interaction.user.id, interaction.channel.id)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"Roger roger {interaction.user.mention}!\n"
+ f"Added you to {interaction.channel.mention}!",
ephemeral=True, silent=True)
@bot.tree.command(description="Leave the matchees for this channel")
@commands.guild_only()
async def leave(interaction: discord.Interaction):
State.set_use_active_in_channel(
interaction.user.id, interaction.channel.id, False)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@bot.tree.command(description="List the matchees for this channel")
@commands.guild_only()
async def list(interaction: discord.Interaction):
matchees = get_matchees_in_channel(interaction.channel)
mentions = [m.mention for m in matchees]
msg = "Current matchees in this channel:\n" + \
f"{', '.join(mentions[:-1])} and {mentions[-1]}"
await interaction.response.send_message(msg, ephemeral=True, silent=True)
@bot.tree.command(description="Match up matchees") @bot.tree.command(description="Match up matchees")
@commands.guild_only() @commands.guild_only()
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)", @app_commands.describe(members_min="Minimum matchees per match (defaults to 3)")
matchee_role="Role for matchees (defaults to @Matchee)") async def match(interaction: discord.Interaction, members_min: int = None):
async def match(interaction: discord.Interaction, members_min: int = None, matchee_role: str = None):
"""Match groups of channel members""" """Match groups of channel members"""
logger.info("Handling request '/match group_min=%s matchee_role=%s'", logger.info("Handling request '/match group_min=%s", members_min)
members_min, matchee_role)
logger.info("User %s from %s in #%s", interaction.user, logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name) interaction.guild.name, interaction.channel.name)
# Sort out the defaults, if not specified they'll come in as None # Sort out the defaults, if not specified they'll come in as None
if not members_min: if not members_min:
members_min = 3 members_min = 3
if not matchee_role:
matchee_role = "Matchee"
# Grab the roles and verify the given role # Grab the groups
matcher = matching.get_role_from_guild(interaction.guild, "Matcher") groups = active_members_to_groups(interaction.channel, members_min)
matcher = matcher and matcher in interaction.user.roles
matchee = matching.get_role_from_guild(interaction.guild, matchee_role) # Let the user know when there's nobody to match
if not matchee: if not groups:
await interaction.response.send_message(f"Server is missing '{matchee_role}' role :(", ephemeral=True) await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True)
return return
# Create some example groups to show the user
matchees = list(
m for m in interaction.channel.members if matchee in m.roles)
groups = matching.members_to_groups(
matchees, State, members_min, allow_fallback=True)
# Post about all the groups with a button to send to the channel # Post about all the groups with a button to send to the channel
groups_list = '\n'.join(matching.group_to_message(g) for g in groups) groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}" msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
view = discord.utils.MISSING view = discord.utils.MISSING
if not matcher: if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER):
# Let a non-matcher know why they don't have the button # Let a non-matcher know why they don't have the button
msg += "\n\nYou'll need the 'Matcher' role to post this to the channel, sorry!" msg += f"\n\nYou'll need the {state.AuthScope.MATCHER} scope to post this to the channel, sorry!"
else: else:
# Otherwise set up the button # Otherwise set up the button
msg += "\n\nClick the button to match up groups and send them to the channel.\n" msg += "\n\nClick the button to match up groups and send them to the channel.\n"
view = discord.ui.View(timeout=None) view = discord.ui.View(timeout=None)
view.add_item(DynamicGroupButton(members_min, matchee_role)) view.add_item(DynamicGroupButton(members_min))
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view) await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
logger.info("Done.") logger.info("Done.")
# Increment when adjusting the custom_id so we don't confuse old users
_BUTTON_CUSTOM_ID_VERSION = 1
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
template=r'match:min:(?P<min>[0-9]+):role:(?P<role>[@\w\s]+)'): template=f'match:v{_BUTTON_CUSTOM_ID_VERSION}:' + r'min:(?P<min>[0-9]+)'):
def __init__(self, min: int, role: str) -> None: def __init__(self, min: int) -> None:
super().__init__( super().__init__(
discord.ui.Button( discord.ui.Button(
label='Match Groups!', label='Match Groups!',
style=discord.ButtonStyle.blurple, style=discord.ButtonStyle.blurple,
custom_id=f'match:min:{min}:role:{role}', custom_id=f'match:min:{min}',
) )
) )
self.min: int = min self.min: int = min
self.role: int = role
# This is called when the button is clicked and the custom_id matches the template. # This is called when the button is clicked and the custom_id matches the template.
@classmethod @classmethod
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /): async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
min = int(match['min']) min = int(match['min'])
role = str(match['role']) return cls(min)
return cls(min, role)
async def callback(self, interaction: discord.Interaction) -> None: async def callback(self, interaction: discord.Interaction) -> None:
"""Match up people when the button is pressed""" """Match up people when the button is pressed"""
logger.info("Handling button press min=%s role=%s'", logger.info("Handling button press min=%s", self.min)
self.min, self.role)
logger.info("User %s from %s in #%s", interaction.user, logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name) interaction.guild.name, interaction.channel.name)
# Let the user know we've recieved the message # Let the user know we've recieved the message
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True) await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
# Grab the role groups = active_members_to_groups(interaction.channel, self.min)
matchee = matching.get_role_from_guild(interaction.guild, self.role)
# Create our groups!
matchees = list(
m for m in interaction.channel.members if matchee in m.roles)
groups = matching.members_to_groups(
matchees, State, self.min, allow_fallback=True)
# Send the groups # Send the groups
for msg in (matching.group_to_message(g) for g in groups): for msg in (matching.group_to_message(g) for g in groups):
@ -167,10 +187,26 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!") await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
# Save the groups to the history # Save the groups to the history
State.save_groups(groups) State.log_groups(groups)
state.save_to_file(State, STATE_FILE)
logger.info("Done. Matched %s matchees into %s groups.", logger.info("Done! Matched into %s groups.", len(groups))
len(matchees), len(groups))
def get_matchees_in_channel(channel: discord.channel):
"""Fetches the matchees in a channel"""
# Gather up the prospective matchees
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
def active_members_to_groups(channel: discord.channel, min_members: int):
"""Helper to create groups from channel members"""
# Gather up the prospective matchees
matchees = get_matchees_in_channel(channel)
# Create our groups!
return matching.members_to_groups(matchees, State, min_members, allow_fallback=True)
if __name__ == "__main__": if __name__ == "__main__":

238
state.py
View file

@ -5,22 +5,65 @@ from schema import Schema, And, Use, Optional
from typing import Protocol from typing import Protocol
import files import files
import copy import copy
import logging
logger = logging.getLogger("state")
logger.setLevel(logging.INFO)
_FILE = "state.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility # Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_DEFAULT_DICT = { _VERSION = 1
"history": {},
"matchees": {}
} def _migrate_to_v1(d: dict):
logger.info("Renaming %s to %s", _Key.MATCHEES, _Key.USERS)
d[_Key.USERS] = d[_Key.MATCHEES]
del d[_Key.MATCHEES]
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1
]
class AuthScope(str):
"""Various auth scopes"""
OWNER = "owner"
MATCHER = "matcher"
class _Key(str):
"""Various keys used in the schema"""
HISTORY = "history"
GROUPS = "groups"
MEMBERS = "members"
USERS = "users"
SCOPES = "scopes"
MATCHES = "matches"
ACTIVE = "active"
CHANNELS = "channels"
REACTIVATE = "reactivate"
VERSION = "version"
# Unused
MATCHEES = "matchees"
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y" _TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
_SCHEMA = Schema( _SCHEMA = Schema(
{ {
Optional("history"): { # The current version
Optional(str): { # a datetime _Key.VERSION: And(Use(int)),
"groups": [
Optional(_Key.HISTORY): {
# A datetime
Optional(str): {
_Key.GROUPS: [
{ {
"members": [ _Key.MEMBERS: [
# The ID of each matchee in the match # The ID of each matchee in the match
And(Use(int)) And(Use(int))
] ]
@ -28,17 +71,33 @@ _SCHEMA = Schema(
] ]
} }
}, },
Optional("matchees"): { Optional(_Key.USERS): {
Optional(str): { Optional(str): {
Optional("matches"): { Optional(_Key.SCOPES): And(Use(list[str])),
Optional(_Key.MATCHES): {
# Matchee ID and Datetime pair # Matchee ID and Datetime pair
Optional(str): And(Use(str)) Optional(str): And(Use(str))
},
Optional(_Key.CHANNELS): {
# The channel ID
Optional(str): {
# Whether the user is signed up in this channel
_Key.ACTIVE: And(Use(bool)),
} }
} }
} }
},
} }
) )
# Empty but schema-valid internal dict
_EMPTY_DICT = {
_Key.HISTORY: {},
_Key.USERS: {},
_Key.VERSION: _VERSION
}
assert _SCHEMA.validate(_EMPTY_DICT)
class Member(Protocol): class Member(Protocol):
@property @property
@ -51,75 +110,148 @@ def ts_to_datetime(ts: str) -> datetime:
return datetime.strptime(ts, _TIME_FORMAT) return datetime.strptime(ts, _TIME_FORMAT)
def validate(dict: dict): class State():
def __init__(self, data: dict = _EMPTY_DICT):
"""Initialise and validate the state""" """Initialise and validate the state"""
self.validate(data)
self._dict = copy.deepcopy(data)
@property
def _history(self) -> dict[str]:
return self._dict[_Key.HISTORY]
@property
def _users(self) -> dict[str]:
return self._dict[_Key.USERS]
def validate(self, dict: dict = None):
"""Initialise and validate a state dict"""
if not dict:
dict = self._dict
_SCHEMA.validate(dict) _SCHEMA.validate(dict)
def get_oldest_timestamp(self) -> datetime:
class State():
def __init__(self, data: dict = _DEFAULT_DICT):
"""Initialise and validate the state"""
validate(data)
self.__dict__ = copy.deepcopy(data)
@property
def history(self) -> list[dict]:
return self.__dict__["history"]
@property
def matchees(self) -> dict[str, dict]:
return self.__dict__["matchees"]
def save(self) -> None:
"""Save out the state"""
files.save(_FILE, self.__dict__)
def oldest_history(self) -> datetime:
"""Grab the oldest timestamp in history""" """Grab the oldest timestamp in history"""
if not self.history: times = (ts_to_datetime(dt) for dt in self._history.keys())
return None return next(times, None)
times = (ts_to_datetime(dt) for dt in self.history.keys())
return sorted(times)[0] def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None: def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
"""Log the groups""" """Log the groups"""
tmp_state = State(self.__dict__) tmp_state = State(self._dict)
ts = datetime.strftime(ts, _TIME_FORMAT) ts = datetime.strftime(ts, _TIME_FORMAT)
# Grab or create the hitory item for this set of groups # Grab or create the hitory item for this set of groups
history_item = {} history_item = {}
tmp_state.history[ts] = history_item tmp_state._history[ts] = history_item
history_item_groups = [] history_item_groups = []
history_item["groups"] = history_item_groups history_item[_Key.GROUPS] = history_item_groups
for group in groups: for group in groups:
# Add the group data # Add the group data
history_item_groups.append({ history_item_groups.append({
"members": list(m.id for m in group) _Key.MEMBERS: [m.id for m in group]
}) })
# Update the matchee data with the matches # Update the matchee data with the matches
for m in group: for m in group:
matchee = tmp_state.matchees.get(str(m.id), {}) matchee = tmp_state._users.get(str(m.id), {})
matchee_matches = matchee.get("matches", {}) matchee_matches = matchee.get(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id): for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts matchee_matches[str(o.id)] = ts
matchee["matches"] = matchee_matches matchee[_Key.MATCHES] = matchee_matches
tmp_state.matchees[str(m.id)] = matchee tmp_state._users[str(m.id)] = matchee
# Validate before storing the result # Validate before storing the result
validate(self.__dict__) tmp_state.validate()
self.__dict__ = tmp_state.__dict__ self._dict = tmp_state._dict
def save_groups(self, groups: list[list[Member]]) -> None: def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Save out the groups to the state file""" """Add an auth scope to a user"""
self.log_groups(groups) # Dive in
self.save() user = self._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
# Set the value
if value and scope not in scopes:
scopes.append(scope)
elif not value and scope in scopes:
scopes.remove(scope)
# Roll out
user[_Key.SCOPES] = scopes
self._users[id] = user
def get_user_has_scope(self, id: str, scope: str) -> bool:
"""
Check if a user has an auth scope
"owner" users have all scopes
"""
user = self._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
return AuthScope.OWNER in scopes or scope in scopes
def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True):
"""Set a user as active (or not) on a given channel"""
# Dive in
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
# Set the value
channel[_Key.ACTIVE] = active
# Unroll
channels[str(channel_id)] = channel
user[_Key.CHANNELS] = channels
self._users[str(id)] = user
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
"""Get a users active channels"""
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
@property
def dict_internal(self) -> dict:
"""Only to be used to get the internal dict as a copy"""
return copy.deepcopy(self._dict)
def load() -> State: def _migrate(dict: dict):
"""Load the state""" """Migrate a dict through versions"""
return State(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT) version = dict.get("version", 0)
for i in range(version, _VERSION):
logger.info("Migrating from v%s to v%s", version, version+1)
_MIGRATIONS[i](dict)
dict[_Key.VERSION] = _VERSION
def load_from_file(file: str) -> State:
"""
Load the state from a file
Apply any required migrations
"""
loaded = _EMPTY_DICT
# If there's a file load it and try to migrate
if os.path.isfile(file):
loaded = files.load(file)
_migrate(loaded)
st = State(loaded)
# Save out the migrated (or new) file
files.save(file, st._dict)
return st
def save_to_file(state: State, file: str):
"""Saves the state out to a file"""
files.save(file, state.dict_internal)

65
state_test.py Normal file
View file

@ -0,0 +1,65 @@
"""
Test functions for the state module
"""
import state
import tempfile
import os
def test_basic_state():
"""Simple validate basic state load"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
state.load_from_file(path)
def test_simple_load_reload():
"""Test a basic load, save, reload"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
def test_authscope():
"""Test setting and getting an auth scope"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
st = state.load_from_file(path)
st.set_user_scope(1, state.AuthScope.MATCHER)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
st.set_user_scope(1, state.AuthScope.MATCHER, False)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
def test_channeljoin():
"""Test setting and getting an active channel"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path)
st.set_use_active_in_channel(1, "2", True)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2")
st.set_use_active_in_channel(1, "2", False)
assert not st.get_user_active_in_channel(1, "2")