Significant set of changes

* Use /join and /leave instead of roles
* Use scopes to check for user rights rather than using the config file
* Add /list to show the set of current people signed up
* Add a bunch more testing for various things
* Version both the config and the state
This commit is contained in:
Marc Di Luzio 2024-08-11 17:53:37 +01:00
parent 78834f5319
commit d3a22ff090
8 changed files with 537 additions and 295 deletions

View file

@ -20,18 +20,16 @@ Only usable by `OWNER` users, reloads the config and syncs commands, or closes d
Matchy is configured by a `config.json` file that takes this format:
```
{
"version": 1,
"token": "<<github bot token>>",
"owners": [
<<owner id>>
]
}
```
User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user.
## TODO
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
* Implement /pause to pause a user for a little while
* Move more constants to the config
* Add scheduling functionality
* Version the config and history files
* Implement /signup rather than using roles
* Implement authorisation scopes instead of just OWNER values
* Fix logging in some sub files
* Improve the weirdo

View file

@ -1,38 +1,77 @@
"""Very simple config loading library"""
from schema import Schema, And, Use
import files
import os
import logging
logger = logging.getLogger("config")
logger.setLevel(logging.INFO)
_FILE = "config.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 1
class _Keys():
TOKEN = "token"
VERSION = "version"
# Removed
OWNERS = "owners"
_SCHEMA = Schema(
{
# Discord bot token
"token": And(Use(str)),
# The current version
_Keys.VERSION: And(Use(int)),
# ids of owners authorised to use owner-only commands
"owners": And(Use(list[int])),
# Discord bot token
_Keys.TOKEN: And(Use(str)),
}
)
def _migrate_to_v1(d: dict):
# Owners moved to History in v1
# Note: owners will be required to be re-added to the state.json
owners = d.pop(_Keys.OWNERS)
logger.warn(
"Migration removed owners from config, these must be re-added to the state.json")
logger.warn("Owners: %s", owners)
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1
]
class Config():
def __init__(self, data: dict):
"""Initialise and validate the config"""
_SCHEMA.validate(data)
self.__dict__ = data
self._dict = data
@property
def token(self) -> str:
return self.__dict__["token"]
@property
def owners(self) -> list[int]:
return self.__dict__["owners"]
def reload(self) -> None:
"""Reload the config back into the dict"""
self.__dict__ = load().__dict__
return self._dict["token"]
def load() -> Config:
"""Load the config"""
return Config(files.load(_FILE))
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
_MIGRATIONS[i](dict)
dict["version"] = _VERSION
def load_from_file(file: str = _FILE) -> Config:
"""
Load the state from a file
Apply any required migrations
"""
assert os.path.isfile(file)
loaded = files.load(file)
_migrate(loaded)
return Config(loaded)

View file

@ -1,125 +0,0 @@
"""Store matching history"""
import os
from datetime import datetime
from schema import Schema, And, Use, Optional
from typing import Protocol
import files
import copy
_FILE = "history.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_DEFAULT_DICT = {
"history": {},
"matchees": {}
}
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
_SCHEMA = Schema(
{
Optional("history"): {
Optional(str): { # a datetime
"groups": [
{
"members": [
# The ID of each matchee in the match
And(Use(int))
]
}
]
}
},
Optional("matchees"): {
Optional(str): {
Optional("matches"): {
# Matchee ID and Datetime pair
Optional(str): And(Use(str))
}
}
}
}
)
class Member(Protocol):
@property
def id(self) -> int:
pass
def ts_to_datetime(ts: str) -> datetime:
"""Convert a ts to datetime using the history format"""
return datetime.strptime(ts, _TIME_FORMAT)
def validate(dict: dict):
"""Initialise and validate the history"""
_SCHEMA.validate(dict)
class History():
def __init__(self, data: dict = _DEFAULT_DICT):
"""Initialise and validate the history"""
validate(data)
self.__dict__ = copy.deepcopy(data)
@property
def history(self) -> list[dict]:
return self.__dict__["history"]
@property
def matchees(self) -> dict[str, dict]:
return self.__dict__["matchees"]
def save(self) -> None:
"""Save out the history"""
files.save(_FILE, self.__dict__)
def oldest(self) -> datetime:
"""Grab the oldest timestamp in history"""
if not self.history:
return None
times = (ts_to_datetime(dt) for dt in self.history.keys())
return sorted(times)[0]
def log_groups_to_history(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
"""Log the groups"""
tmp_history = History(self.__dict__)
ts = datetime.strftime(ts, _TIME_FORMAT)
# Grab or create the hitory item for this set of groups
history_item = {}
tmp_history.history[ts] = history_item
history_item_groups = []
history_item["groups"] = history_item_groups
for group in groups:
# Add the group data
history_item_groups.append({
"members": list(m.id for m in group)
})
# Update the matchee data with the matches
for m in group:
matchee = tmp_history.matchees.get(str(m.id), {})
matchee_matches = matchee.get("matches", {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
matchee["matches"] = matchee_matches
tmp_history.matchees[str(m.id)] = matchee
# Validate before storing the result
validate(self.__dict__)
self.__dict__ = tmp_history.__dict__
def save_groups_to_history(self, groups: list[list[Member]]) -> None:
"""Save out the groups to the history file"""
self.log_groups_to_history(groups)
self.save()
def load() -> History:
"""Load the history"""
return History(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)

View file

@ -1,6 +1,5 @@
"""Utility functions for matchy"""
import logging
import random
from datetime import datetime, timedelta
from typing import Protocol, runtime_checkable
import state
@ -9,17 +8,16 @@ import state
# Number of days to step forward from the start of history for each match attempt
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
# Attempts for each of those time periods
_ATTEMPTS_PER_TIMESTEP = 3
# Various eligability scoring factors for group meetups
_SCORE_CURRENT_MEMBERS = 2**1
_SCORE_REPEAT_ROLE = 2**2
_SCORE_REPEAT_MATCH = 2**3
_SCORE_EXTRA_MEMBERS = 2**4
class _ScoreFactors(int):
"""Various eligability scoring factors for group meetups"""
REPEAT_ROLE = 2**2
REPEAT_MATCH = 2**3
EXTRA_MEMBER = 2**5
# Scores higher than this are fully rejected
UPPER_THRESHOLD = 2**6
# Scores higher than this are fully rejected
_SCORE_UPPER_THRESHOLD = 2**6
logger = logging.getLogger("matching")
logger.setLevel(logging.INFO)
@ -69,33 +67,42 @@ def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bo
def get_member_group_eligibility_score(member: Member,
group: list[Member],
relevant_matches: list[int],
per_group: int) -> int:
prior_matches: list[int],
per_group: int) -> float:
"""Rates a member against a group"""
rating = len(group) * _SCORE_CURRENT_MEMBERS
# An empty group is a "perfect" score atomatically
rating = 0
if not group:
return rating
repeat_meetings = sum(m.id in relevant_matches for m in group)
rating += repeat_meetings * _SCORE_REPEAT_MATCH
# Add score based on prior matchups of this user
rating += sum(m.id in prior_matches for m in group) * \
_ScoreFactors.REPEAT_MATCH
repeat_roles = sum(r in member.roles for r in (m.roles for m in group))
rating += (repeat_roles * _SCORE_REPEAT_ROLE)
# Calculate the number of roles that match
all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr)
member_role_ids = [r.id for r in member.roles]
repeat_roles = sum(id in member_role_ids for id in all_role_ids)
rating += repeat_roles * _ScoreFactors.REPEAT_ROLE
extra_members = len(group) - per_group
if extra_members > 0:
rating += extra_members * _SCORE_EXTRA_MEMBERS
# Add score based on the number of extra members
# Calculate the member offset (+1 for this user)
extra_members = (len(group) - per_group) + 1
if extra_members >= 0:
rating += extra_members * _ScoreFactors.EXTRA_MEMBER
return rating
def attempt_create_groups(matchees: list[Member],
hist: state.State,
current_state: state.State,
oldest_relevant_ts: datetime,
per_group: int) -> tuple[bool, list[list[Member]]]:
"""History aware group matching"""
num_groups = max(len(matchees)//per_group, 1)
# Set up the groups in place
groups = list([] for _ in range(num_groups))
groups = [[] for _ in range(num_groups)]
matchees_left = matchees.copy()
@ -103,21 +110,21 @@ def attempt_create_groups(matchees: list[Member],
while matchees_left:
# Get the next matchee to place
matchee = matchees_left.pop()
matchee_matches = hist.matchees.get(
str(matchee.id), {}).get("matches", {})
relevant_matches = list(int(id) for id, ts in matchee_matches.items()
if state.ts_to_datetime(ts) >= oldest_relevant_ts)
matchee_matches = current_state.get_user_matches(matchee.id)
relevant_matches = [int(id) for id, ts
in matchee_matches.items()
if state.ts_to_datetime(ts) >= oldest_relevant_ts]
# Try every single group from the current group onwards
# Progressing through the groups like this ensures we slowly fill them up with compatible people
scores: list[tuple[int, int]] = []
scores: list[tuple[int, float]] = []
for group in groups:
score = get_member_group_eligibility_score(
matchee, group, relevant_matches, num_groups)
matchee, group, relevant_matches, per_group)
# If the score isn't too high, consider this group
if score <= _SCORE_UPPER_THRESHOLD:
if score <= _ScoreFactors.UPPER_THRESHOLD:
scores.append((group, score))
# Optimisation:
@ -143,31 +150,41 @@ def datetime_range(start_time: datetime, increment: timedelta, end: datetime):
current += increment
def iterate_all_shifts(list: list):
"""Yields each shifted variation of the input list"""
yield list
for _ in range(len(list)-1):
list = list[1:] + [list[0]]
yield list
def members_to_groups(matchees: list[Member],
hist: state.State = state.State(),
per_group: int = 3,
allow_fallback: bool = False) -> list[list[Member]]:
"""Generate the groups from the set of matchees"""
attempts = 0 # Tracking for logging purposes
rand = random.Random(117) # Some stable randomness
num_groups = len(matchees)//per_group
# Bail early if there's no-one to match
if not matchees:
return []
# Grab the oldest timestamp
history_start = hist.oldest_history() or datetime.now()
history_start = hist.get_oldest_timestamp() or datetime.now()
# Walk from the start of time until now using the timestep increment
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
# Have a few attempts before stepping forward in time
for _ in range(_ATTEMPTS_PER_TIMESTEP):
rand.shuffle(matchees) # Shuffle the matchees each attempt
# Attempt with each starting matchee
for shifted_matchees in iterate_all_shifts(matchees):
attempts += 1
groups = attempt_create_groups(
matchees, hist, oldest_relevant_datetime, per_group)
shifted_matchees, hist, oldest_relevant_datetime, per_group)
# Fail the match if our groups aren't big enough
if (len(matchees)//per_group) <= 1 or (groups and all(len(g) >= per_group for g in groups)):
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
logger.info("Matched groups after %s attempt(s)", attempts)
return groups
@ -176,6 +193,10 @@ def members_to_groups(matchees: list[Member],
logger.info("Fell back to simple groups after %s attempt(s)", attempts)
return members_to_groups_simple(matchees, per_group)
# Simply assert false, this should never happen
# And should be caught by tests
assert False
def group_to_message(group: list[Member]) -> str:
"""Get the message to send for each group"""
@ -185,8 +206,3 @@ def group_to_message(group: list[Member]) -> str:
else:
mentions = mentions[0]
return f"Matched up {mentions}!"
def get_role_from_guild(guild: Guild, role: str) -> Role:
"""Find a role in a guild"""
return next((r for r in guild.roles if r.name == role), None)

View file

@ -1,5 +1,5 @@
"""
Test functions for Matchy
Test functions for the matching module
"""
import discord
import pytest
@ -30,6 +30,7 @@ class Role():
class Member():
def __init__(self, id: int, roles: list[Role] = []):
self._id = id
self._roles = roles
@property
def mention(self) -> str:
@ -37,7 +38,7 @@ class Member():
@property
def roles(self) -> list[Role]:
return []
return self._roles
@property
def id(self) -> int:
@ -153,7 +154,59 @@ def items_found_in_lists(list_of_lists, items):
# Nothing specific to validate
]
),
], ids=['simple_history', 'fallback'])
# Specific test pulled out of the stress test
(
[
{
"ts": datetime.now() - timedelta(days=4),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12, 13, 14, 15]]
]
},
{
"ts": datetime.now() - timedelta(days=5),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
]
}
],
[Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]],
3,
[
# Nothing specific to validate
]
),
# Silly example that failued due to bad role logic
(
[
# No history
],
[
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
Member(i, [Role(r) for r in roles]) for (i, roles) in
[
(4, [1, 2, 3, 4, 5, 6, 7, 8]),
(8, [1]),
(9, [1, 2, 3, 4, 5]),
(6, [1, 2, 3]),
(11, [1, 2, 3]),
(7, [1, 2, 3, 4, 5, 6, 7]),
(1, [1, 2, 3, 4]),
(5, [1, 2, 3, 4, 5]),
(12, [1, 2, 3, 4]),
(10, [1]),
(13, [1, 2, 3, 4, 5, 6]),
(2, [1, 2, 3, 4, 5, 6]),
(3, [1, 2, 3, 4, 5, 6, 7])
]
],
2,
[
# Nothing else
]
)
], ids=['simple_history', 'fallback', 'example_1', 'example_2'])
def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
"""Test more advanced group matching works"""
tmp_state = state.State()
@ -180,8 +233,8 @@ def test_members_to_groups_stress_test():
# Slowly ramp a randomized shuffled list of members with randomised roles
for num_members in range(1, 5):
matchees = list(Member(i, list(Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))))
for i in range(1, rand.randint(2, num_members*10 + 1)))
matchees = [Member(i, [Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))])
for i in range(1, rand.randint(2, num_members*10 + 1))]
rand.shuffle(matchees)
for num_history in range(8):
@ -190,14 +243,14 @@ def test_members_to_groups_stress_test():
# Start some time from now to the past
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
history_data = []
for x in range(0, num_history):
for _ in range(0, num_history):
run = {
"ts": time
}
groups = []
for y in range(1, num_history):
groups.append(list(Member(i)
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))))
groups.append([Member(i)
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))])
run["groups"] = groups
history_data.append(run)
@ -212,4 +265,32 @@ def test_members_to_groups_stress_test():
for d in history_data:
tmp_state.log_groups(d["groups"], d["ts"])
inner_validate_members_to_groups(matchees, tmp_state, per_group)
inner_validate_members_to_groups(
matchees, tmp_state, per_group)
def test_auth_scopes():
tmp_state = state.State()
id = "1"
tmp_state.set_user_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
id = "2"
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
tmp_state.validate()
def test_iterate_all_shifts():
original = [1, 2, 3, 4]
lists = [val for val in matching.iterate_all_shifts(original)]
assert lists == [
[1, 2, 3, 4],
[2, 3, 4, 1],
[3, 4, 1, 2],
[4, 1, 2, 3],
]

130
matchy.py
View file

@ -11,8 +11,11 @@ import config
import re
Config = config.load()
State = state.load()
STATE_FILE = "state.json"
CONFIG_FILE = "config.json"
Config = config.load_from_file(CONFIG_FILE)
State = state.load_from_file(STATE_FILE)
logger = logging.getLogger("matchy")
logger.setLevel(logging.INFO)
@ -39,7 +42,7 @@ async def on_ready():
def owner_only(ctx: commands.Context) -> bool:
"""Checks the author is an owner"""
return ctx.message.author.id in Config.owners
return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER)
@bot.command()
@ -47,9 +50,10 @@ def owner_only(ctx: commands.Context) -> bool:
@commands.check(owner_only)
async def sync(ctx: commands.Context):
"""Handle sync command"""
msg = await ctx.reply("Reloading config...", ephemeral=True)
Config.reload()
logger.info("Reloaded config")
msg = await ctx.reply("Reloading state...", ephemeral=True)
global State
State = state.load_from_file(STATE_FILE)
logger.info("Reloaded state")
await msg.edit(content="Syncing commands...")
synced = await bot.tree.sync()
@ -68,96 +72,112 @@ async def close(ctx: commands.Context):
await bot.close()
@bot.tree.command(description="Join the matchees for this channel")
@commands.guild_only()
async def join(interaction: discord.Interaction):
State.set_use_active_in_channel(
interaction.user.id, interaction.channel.id)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"Roger roger {interaction.user.mention}!\n"
+ f"Added you to {interaction.channel.mention}!",
ephemeral=True, silent=True)
@bot.tree.command(description="Leave the matchees for this channel")
@commands.guild_only()
async def leave(interaction: discord.Interaction):
State.set_use_active_in_channel(
interaction.user.id, interaction.channel.id, False)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@bot.tree.command(description="List the matchees for this channel")
@commands.guild_only()
async def list(interaction: discord.Interaction):
matchees = get_matchees_in_channel(interaction.channel)
mentions = [m.mention for m in matchees]
msg = "Current matchees in this channel:\n" + \
f"{', '.join(mentions[:-1])} and {mentions[-1]}"
await interaction.response.send_message(msg, ephemeral=True, silent=True)
@bot.tree.command(description="Match up matchees")
@commands.guild_only()
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)",
matchee_role="Role for matchees (defaults to @Matchee)")
async def match(interaction: discord.Interaction, members_min: int = None, matchee_role: str = None):
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)")
async def match(interaction: discord.Interaction, members_min: int = None):
"""Match groups of channel members"""
logger.info("Handling request '/match group_min=%s matchee_role=%s'",
members_min, matchee_role)
logger.info("Handling request '/match group_min=%s", members_min)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Sort out the defaults, if not specified they'll come in as None
if not members_min:
members_min = 3
if not matchee_role:
matchee_role = "Matchee"
# Grab the roles and verify the given role
matcher = matching.get_role_from_guild(interaction.guild, "Matcher")
matcher = matcher and matcher in interaction.user.roles
matchee = matching.get_role_from_guild(interaction.guild, matchee_role)
if not matchee:
await interaction.response.send_message(f"Server is missing '{matchee_role}' role :(", ephemeral=True)
# Grab the groups
groups = active_members_to_groups(interaction.channel, members_min)
# Let the user know when there's nobody to match
if not groups:
await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True)
return
# Create some example groups to show the user
matchees = list(
m for m in interaction.channel.members if matchee in m.roles)
groups = matching.members_to_groups(
matchees, State, members_min, allow_fallback=True)
# Post about all the groups with a button to send to the channel
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
view = discord.utils.MISSING
if not matcher:
if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER):
# Let a non-matcher know why they don't have the button
msg += "\n\nYou'll need the 'Matcher' role to post this to the channel, sorry!"
msg += f"\n\nYou'll need the {state.AuthScope.MATCHER} scope to post this to the channel, sorry!"
else:
# Otherwise set up the button
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
view = discord.ui.View(timeout=None)
view.add_item(DynamicGroupButton(members_min, matchee_role))
view.add_item(DynamicGroupButton(members_min))
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
logger.info("Done.")
# Increment when adjusting the custom_id so we don't confuse old users
_BUTTON_CUSTOM_ID_VERSION = 1
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
template=r'match:min:(?P<min>[0-9]+):role:(?P<role>[@\w\s]+)'):
def __init__(self, min: int, role: str) -> None:
template=f'match:v{_BUTTON_CUSTOM_ID_VERSION}:' + r'min:(?P<min>[0-9]+)'):
def __init__(self, min: int) -> None:
super().__init__(
discord.ui.Button(
label='Match Groups!',
style=discord.ButtonStyle.blurple,
custom_id=f'match:min:{min}:role:{role}',
custom_id=f'match:min:{min}',
)
)
self.min: int = min
self.role: int = role
# This is called when the button is clicked and the custom_id matches the template.
@classmethod
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
min = int(match['min'])
role = str(match['role'])
return cls(min, role)
return cls(min)
async def callback(self, interaction: discord.Interaction) -> None:
"""Match up people when the button is pressed"""
logger.info("Handling button press min=%s role=%s'",
self.min, self.role)
logger.info("Handling button press min=%s", self.min)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Let the user know we've recieved the message
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
# Grab the role
matchee = matching.get_role_from_guild(interaction.guild, self.role)
# Create our groups!
matchees = list(
m for m in interaction.channel.members if matchee in m.roles)
groups = matching.members_to_groups(
matchees, State, self.min, allow_fallback=True)
groups = active_members_to_groups(interaction.channel, self.min)
# Send the groups
for msg in (matching.group_to_message(g) for g in groups):
@ -167,10 +187,26 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
# Save the groups to the history
State.save_groups(groups)
State.log_groups(groups)
state.save_to_file(State, STATE_FILE)
logger.info("Done. Matched %s matchees into %s groups.",
len(matchees), len(groups))
logger.info("Done! Matched into %s groups.", len(groups))
def get_matchees_in_channel(channel: discord.channel):
"""Fetches the matchees in a channel"""
# Gather up the prospective matchees
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
def active_members_to_groups(channel: discord.channel, min_members: int):
"""Helper to create groups from channel members"""
# Gather up the prospective matchees
matchees = get_matchees_in_channel(channel)
# Create our groups!
return matching.members_to_groups(matchees, State, min_members, allow_fallback=True)
if __name__ == "__main__":

238
state.py
View file

@ -5,22 +5,65 @@ from schema import Schema, And, Use, Optional
from typing import Protocol
import files
import copy
import logging
logger = logging.getLogger("state")
logger.setLevel(logging.INFO)
_FILE = "state.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_DEFAULT_DICT = {
"history": {},
"matchees": {}
}
_VERSION = 1
def _migrate_to_v1(d: dict):
logger.info("Renaming %s to %s", _Key.MATCHEES, _Key.USERS)
d[_Key.USERS] = d[_Key.MATCHEES]
del d[_Key.MATCHEES]
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1
]
class AuthScope(str):
"""Various auth scopes"""
OWNER = "owner"
MATCHER = "matcher"
class _Key(str):
"""Various keys used in the schema"""
HISTORY = "history"
GROUPS = "groups"
MEMBERS = "members"
USERS = "users"
SCOPES = "scopes"
MATCHES = "matches"
ACTIVE = "active"
CHANNELS = "channels"
REACTIVATE = "reactivate"
VERSION = "version"
# Unused
MATCHEES = "matchees"
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
_SCHEMA = Schema(
{
Optional("history"): {
Optional(str): { # a datetime
"groups": [
# The current version
_Key.VERSION: And(Use(int)),
Optional(_Key.HISTORY): {
# A datetime
Optional(str): {
_Key.GROUPS: [
{
"members": [
_Key.MEMBERS: [
# The ID of each matchee in the match
And(Use(int))
]
@ -28,17 +71,33 @@ _SCHEMA = Schema(
]
}
},
Optional("matchees"): {
Optional(_Key.USERS): {
Optional(str): {
Optional("matches"): {
Optional(_Key.SCOPES): And(Use(list[str])),
Optional(_Key.MATCHES): {
# Matchee ID and Datetime pair
Optional(str): And(Use(str))
},
Optional(_Key.CHANNELS): {
# The channel ID
Optional(str): {
# Whether the user is signed up in this channel
_Key.ACTIVE: And(Use(bool)),
}
}
}
},
}
)
# Empty but schema-valid internal dict
_EMPTY_DICT = {
_Key.HISTORY: {},
_Key.USERS: {},
_Key.VERSION: _VERSION
}
assert _SCHEMA.validate(_EMPTY_DICT)
class Member(Protocol):
@property
@ -51,75 +110,148 @@ def ts_to_datetime(ts: str) -> datetime:
return datetime.strptime(ts, _TIME_FORMAT)
def validate(dict: dict):
class State():
def __init__(self, data: dict = _EMPTY_DICT):
"""Initialise and validate the state"""
self.validate(data)
self._dict = copy.deepcopy(data)
@property
def _history(self) -> dict[str]:
return self._dict[_Key.HISTORY]
@property
def _users(self) -> dict[str]:
return self._dict[_Key.USERS]
def validate(self, dict: dict = None):
"""Initialise and validate a state dict"""
if not dict:
dict = self._dict
_SCHEMA.validate(dict)
class State():
def __init__(self, data: dict = _DEFAULT_DICT):
"""Initialise and validate the state"""
validate(data)
self.__dict__ = copy.deepcopy(data)
@property
def history(self) -> list[dict]:
return self.__dict__["history"]
@property
def matchees(self) -> dict[str, dict]:
return self.__dict__["matchees"]
def save(self) -> None:
"""Save out the state"""
files.save(_FILE, self.__dict__)
def oldest_history(self) -> datetime:
def get_oldest_timestamp(self) -> datetime:
"""Grab the oldest timestamp in history"""
if not self.history:
return None
times = (ts_to_datetime(dt) for dt in self.history.keys())
return sorted(times)[0]
times = (ts_to_datetime(dt) for dt in self._history.keys())
return next(times, None)
def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
"""Log the groups"""
tmp_state = State(self.__dict__)
tmp_state = State(self._dict)
ts = datetime.strftime(ts, _TIME_FORMAT)
# Grab or create the hitory item for this set of groups
history_item = {}
tmp_state.history[ts] = history_item
tmp_state._history[ts] = history_item
history_item_groups = []
history_item["groups"] = history_item_groups
history_item[_Key.GROUPS] = history_item_groups
for group in groups:
# Add the group data
history_item_groups.append({
"members": list(m.id for m in group)
_Key.MEMBERS: [m.id for m in group]
})
# Update the matchee data with the matches
for m in group:
matchee = tmp_state.matchees.get(str(m.id), {})
matchee_matches = matchee.get("matches", {})
matchee = tmp_state._users.get(str(m.id), {})
matchee_matches = matchee.get(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
matchee["matches"] = matchee_matches
tmp_state.matchees[str(m.id)] = matchee
matchee[_Key.MATCHES] = matchee_matches
tmp_state._users[str(m.id)] = matchee
# Validate before storing the result
validate(self.__dict__)
self.__dict__ = tmp_state.__dict__
tmp_state.validate()
self._dict = tmp_state._dict
def save_groups(self, groups: list[list[Member]]) -> None:
"""Save out the groups to the state file"""
self.log_groups(groups)
self.save()
def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user"""
# Dive in
user = self._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
# Set the value
if value and scope not in scopes:
scopes.append(scope)
elif not value and scope in scopes:
scopes.remove(scope)
# Roll out
user[_Key.SCOPES] = scopes
self._users[id] = user
def get_user_has_scope(self, id: str, scope: str) -> bool:
"""
Check if a user has an auth scope
"owner" users have all scopes
"""
user = self._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
return AuthScope.OWNER in scopes or scope in scopes
def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True):
"""Set a user as active (or not) on a given channel"""
# Dive in
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
# Set the value
channel[_Key.ACTIVE] = active
# Unroll
channels[str(channel_id)] = channel
user[_Key.CHANNELS] = channels
self._users[str(id)] = user
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
"""Get a users active channels"""
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
@property
def dict_internal(self) -> dict:
"""Only to be used to get the internal dict as a copy"""
return copy.deepcopy(self._dict)
def load() -> State:
"""Load the state"""
return State(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
logger.info("Migrating from v%s to v%s", version, version+1)
_MIGRATIONS[i](dict)
dict[_Key.VERSION] = _VERSION
def load_from_file(file: str) -> State:
"""
Load the state from a file
Apply any required migrations
"""
loaded = _EMPTY_DICT
# If there's a file load it and try to migrate
if os.path.isfile(file):
loaded = files.load(file)
_migrate(loaded)
st = State(loaded)
# Save out the migrated (or new) file
files.save(file, st._dict)
return st
def save_to_file(state: State, file: str):
"""Saves the state out to a file"""
files.save(file, state.dict_internal)

65
state_test.py Normal file
View file

@ -0,0 +1,65 @@
"""
Test functions for the state module
"""
import state
import tempfile
import os
def test_basic_state():
"""Simple validate basic state load"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
state.load_from_file(path)
def test_simple_load_reload():
"""Test a basic load, save, reload"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
def test_authscope():
"""Test setting and getting an auth scope"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
st = state.load_from_file(path)
st.set_user_scope(1, state.AuthScope.MATCHER)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
st.set_user_scope(1, state.AuthScope.MATCHER, False)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
def test_channeljoin():
"""Test setting and getting an active channel"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path)
st.set_use_active_in_channel(1, "2", True)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2")
st.set_use_active_in_channel(1, "2", False)
assert not st.get_user_active_in_channel(1, "2")