Merge pull request #1 from mdiluz/feature-no-roles

First big update
This commit is contained in:
Marc Di Luzio 2024-08-12 09:27:05 +01:00 committed by GitHub
commit db00c9f7c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 1280 additions and 644 deletions

View file

@ -18,12 +18,9 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
- name: Analysing the code with flake8
- name: Run tests
run: |
flake8 --max-line-length 120 $(git ls-files '*.py')
- name: Run tests with pytest
run: |
pytest
bash bin/test.sh
- name: Update release branch
if: github.ref == 'refs/heads/main'
run: |

2
.gitignore vendored
View file

@ -1,4 +1,4 @@
__pycache__
config.json
history.json
state.json
.venv

2
.vscode/launch.json vendored
View file

@ -8,7 +8,7 @@
"name": "Python Debugger: Matchy",
"type": "debugpy",
"request": "launch",
"program": "matchy.py",
"program": "py/matchy.py",
"console": "integratedTerminal"
}
]

View file

@ -6,8 +6,14 @@ Matchy matches matchees.
Matchy is a Discord bot that groups up users for fun and vibes. Matchy can be installed by clicking [here](https://discord.com/oauth2/authorize?client_id=1270849346987884696).
## Commands
### /match [group_min: int(3)] [matchee_role: str(@Matchee)]
Matches groups of users with a given role and posts those groups to the channel. Tracks historical matches and attempts to match users to make new connections with people with divergent roles, in an attempt to maximise diversity.
### /match [group_min: int(3)]
Matches groups of users in a channel and offers a button to pose those groups to the channel to users with `matcher` auth scope. Tracks historical matches and attempts to match users to make new connections with people with divergent roles, in an attempt to maximise diversity.
### /join and /leave
Allows users to sign up and leave the group matching in the channel the command is used
### /pause [days: int(7)]
Allows users to pause their matching in a channel for a given number of days. Users can use `/join` to re-join before the end of that time.
### $sync and $close
Only usable by `OWNER` users, reloads the config and syncs commands, or closes down the bot. Only usable in DMs with the bot user.
@ -18,20 +24,24 @@ Only usable by `OWNER` users, reloads the config and syncs commands, or closes d
## Config
Matchy is configured by a `config.json` file that takes this format:
```
```json
{
"version" : 1,
"token" : "<<github bot token>>",
"owners": [
<<owner id>>
]
"match" : {
"score_factors": {
"repeat_role" : 4,
"repeat_match" : 8,
"extra_member" : 32,
"upper_threshold" : 64
}
}
}
```
User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user.
Only token and version are required. See [`py/config.py`](py/config.py) for explanations for any of these.
## TODO
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
* Add scheduling functionality
* Version the config and history files
* Implement /signup rather than using roles
* Implement authorisation scopes instead of just OWNER values
* Write integration tests (maybe with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)?)
* Improve the weirdo

5
bin/coverage.sh Executable file
View file

@ -0,0 +1,5 @@
#!/usr/bin/env bash
set -x
set -e
pytest --cov=. --cov-report=html

View file

@ -8,4 +8,4 @@ if [ ! -d .venv ]; then
fi
source .venv/bin/activate
python -m pip install -r requirements.txt
python matchy.py
python py/matchy.py

9
bin/test.sh Executable file
View file

@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -x
set -e
# Check formatting and linting
flake8 --max-line-length 120 $(git ls-files '*.py')
# Run pytest
pytest

View file

@ -1,38 +0,0 @@
"""Very simple config loading library"""
from schema import Schema, And, Use
import files
_FILE = "config.json"
_SCHEMA = Schema(
{
# Discord bot token
"token": And(Use(str)),
# ids of owners authorised to use owner-only commands
"owners": And(Use(list[int])),
}
)
class Config():
def __init__(self, data: dict):
"""Initialise and validate the config"""
_SCHEMA.validate(data)
self.__dict__ = data
@property
def token(self) -> str:
return self.__dict__["token"]
@property
def owners(self) -> list[int]:
return self.__dict__["owners"]
def reload(self) -> None:
"""Reload the config back into the dict"""
self.__dict__ = load().__dict__
def load() -> Config:
"""Load the config"""
return Config(files.load(_FILE))

View file

@ -1,2 +0,0 @@
#!/usr/bin/env bash
pytest --cov=. --cov-report=html

View file

@ -1,125 +0,0 @@
"""Store matching history"""
import os
from datetime import datetime
from schema import Schema, And, Use, Optional
from typing import Protocol
import files
import copy
_FILE = "history.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_DEFAULT_DICT = {
"history": {},
"matchees": {}
}
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
_SCHEMA = Schema(
{
Optional("history"): {
Optional(str): { # a datetime
"groups": [
{
"members": [
# The ID of each matchee in the match
And(Use(int))
]
}
]
}
},
Optional("matchees"): {
Optional(str): {
Optional("matches"): {
# Matchee ID and Datetime pair
Optional(str): And(Use(str))
}
}
}
}
)
class Member(Protocol):
@property
def id(self) -> int:
pass
def ts_to_datetime(ts: str) -> datetime:
"""Convert a ts to datetime using the history format"""
return datetime.strptime(ts, _TIME_FORMAT)
def validate(dict: dict):
"""Initialise and validate the history"""
_SCHEMA.validate(dict)
class History():
def __init__(self, data: dict = _DEFAULT_DICT):
"""Initialise and validate the history"""
validate(data)
self.__dict__ = copy.deepcopy(data)
@property
def history(self) -> list[dict]:
return self.__dict__["history"]
@property
def matchees(self) -> dict[str, dict]:
return self.__dict__["matchees"]
def save(self) -> None:
"""Save out the history"""
files.save(_FILE, self.__dict__)
def oldest(self) -> datetime:
"""Grab the oldest timestamp in history"""
if not self.history:
return None
times = (ts_to_datetime(dt) for dt in self.history.keys())
return sorted(times)[0]
def log_groups_to_history(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
"""Log the groups"""
tmp_history = History(self.__dict__)
ts = datetime.strftime(ts, _TIME_FORMAT)
# Grab or create the hitory item for this set of groups
history_item = {}
tmp_history.history[ts] = history_item
history_item_groups = []
history_item["groups"] = history_item_groups
for group in groups:
# Add the group data
history_item_groups.append({
"members": list(m.id for m in group)
})
# Update the matchee data with the matches
for m in group:
matchee = tmp_history.matchees.get(str(m.id), {})
matchee_matches = matchee.get("matches", {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
matchee["matches"] = matchee_matches
tmp_history.matchees[str(m.id)] = matchee
# Validate before storing the result
validate(self.__dict__)
self.__dict__ = tmp_history.__dict__
def save_groups_to_history(self, groups: list[list[Member]]) -> None:
"""Save out the groups to the history file"""
self.log_groups_to_history(groups)
self.save()
def load() -> History:
"""Load the history"""
return History(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)

View file

@ -1,215 +0,0 @@
"""
Test functions for Matchy
"""
import discord
import pytest
import random
import matching
import history
from datetime import datetime, timedelta
def test_protocols():
"""Verify the protocols we're using match the discord ones"""
assert isinstance(discord.Member, matching.Member)
assert isinstance(discord.Guild, matching.Guild)
assert isinstance(discord.Role, matching.Role)
assert isinstance(Member, matching.Member)
# assert isinstance(Role, matching.Role)
class Role():
def __init__(self, id: int):
self._id = id
@property
def id(self) -> int:
return self._id
class Member():
def __init__(self, id: int, roles: list[Role] = []):
self._id = id
@property
def mention(self) -> str:
return f"<@{self._id}>"
@property
def roles(self) -> list[Role]:
return []
@property
def id(self) -> int:
return self._id
def inner_validate_members_to_groups(matchees: list[Member], hist: history.History, per_group: int):
"""Inner function to validate the main output of the groups function"""
groups = matching.members_to_groups(matchees, hist, per_group)
# We should always have one group
assert len(groups)
# Log the groups to history
# This will validate the internals
hist.log_groups_to_history(groups)
# Ensure each group contains within the bounds of expected members
for group in groups:
if len(matchees) >= per_group:
assert len(group) >= per_group
else:
assert len(group) == len(matchees)
assert len(group) < per_group*2 # TODO: We could be more strict here
return groups
@pytest.mark.parametrize("matchees, per_group", [
# Simplest test possible
([Member(1)], 1),
# More requested than we have
([Member(1)], 2),
# A selection of hyper-simple checks to validate core functionality
([Member(1)] * 100, 3),
([Member(1)] * 12, 5),
([Member(1)] * 11, 2),
([Member(1)] * 356, 8),
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
def test_members_to_groups_no_history(matchees, per_group):
"""Test simple group matching works"""
hist = history.History()
inner_validate_members_to_groups(matchees, hist, per_group)
def items_found_in_lists(list_of_lists, items):
"""validates if any sets of items are found in individual lists"""
for sublist in list_of_lists:
if all(item in sublist for item in items):
return True
return False
@pytest.mark.parametrize("history_data, matchees, per_group, checks", [
# Slightly more difficult test
# Describe a history where we previously matched up some people and ensure they don't get rematched
(
[
{
"ts": datetime.now() - timedelta(days=1),
"groups": [
[Member(1), Member(2)],
[Member(3), Member(4)],
]
}
],
[
Member(1),
Member(2),
Member(3),
Member(4),
],
2,
[
lambda groups: not items_found_in_lists(
groups, [Member(1), Member(2)]),
lambda groups: not items_found_in_lists(
groups, [Member(3), Member(4)])
]
),
# Feed the system an "impossible" test
# The function should fall back to ignoring history and still give us something
(
[
{
"ts": datetime.now() - timedelta(days=1),
"groups": [
[
Member(1),
Member(2),
Member(3)
],
[
Member(4),
Member(5),
Member(6)
],
]
}
],
[
Member(1, [Role(1), Role(2), Role(3), Role(4)]),
Member(2, [Role(1), Role(2), Role(3), Role(4)]),
Member(3, [Role(1), Role(2), Role(3), Role(4)]),
Member(4, [Role(1), Role(2), Role(3), Role(4)]),
Member(5, [Role(1), Role(2), Role(3), Role(4)]),
Member(6, [Role(1), Role(2), Role(3), Role(4)]),
],
3,
[
# Nothing specific to validate
]
),
], ids=['simple_history', 'fallback'])
def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
"""Test more advanced group matching works"""
hist = history.History()
# Replay the history
for d in history_data:
hist.log_groups_to_history(d["groups"], d["ts"])
groups = inner_validate_members_to_groups(matchees, hist, per_group)
# Run the custom validate functions
for check in checks:
assert check(groups)
def test_members_to_groups_stress_test():
"""stress test firing significant random data at the code"""
# Use a stable rand, feel free to adjust this if needed but this lets the test be stable
rand = random.Random(123)
# Slowly ramp up the group size
for per_group in range(2, 6):
# Slowly ramp a randomized shuffled list of members with randomised roles
for num_members in range(1, 5):
matchees = list(Member(i, list(Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))))
for i in range(1, rand.randint(2, num_members*10 + 1)))
rand.shuffle(matchees)
for num_history in range(8):
# Generate some super random history
# Start some time from now to the past
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
history_data = []
for x in range(0, num_history):
run = {
"ts": time
}
groups = []
for y in range(1, num_history):
groups.append(list(Member(i)
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))))
run["groups"] = groups
history_data.append(run)
# Step some time backwards in time
time -= timedelta(days=rand.randint(1, num_history))
# No guarantees on history data order so make it a little harder for matchy
rand.shuffle(history_data)
# Replay the history
hist = history.History()
for d in history_data:
hist.log_groups_to_history(d["groups"], d["ts"])
inner_validate_members_to_groups(matchees, hist, per_group)

194
matchy.py
View file

@ -1,194 +0,0 @@
"""
matchy.py - Discord bot that matches people into groups
"""
import logging
import discord
from discord import app_commands
from discord.ext import commands
import matching
import history
import config
import re
Config = config.load()
History = history.load()
logger = logging.getLogger("matchy")
logger.setLevel(logging.INFO)
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
bot = commands.Bot(command_prefix='$',
description="Matchy matches matchees", intents=intents)
@bot.event
async def setup_hook():
bot.add_dynamic_items(DynamicGroupButton)
@bot.event
async def on_ready():
"""Bot is ready and connected"""
logger.info("Bot is up and ready!")
activity = discord.Game("/match")
await bot.change_presence(status=discord.Status.online, activity=activity)
def owner_only(ctx: commands.Context) -> bool:
"""Checks the author is an owner"""
return ctx.message.author.id in Config.owners
@bot.command()
@commands.dm_only()
@commands.check(owner_only)
async def sync(ctx: commands.Context):
"""Handle sync command"""
msg = await ctx.reply("Reloading config...", ephemeral=True)
Config.reload()
logger.info("Reloaded config")
await msg.edit(content="Syncing commands...")
synced = await bot.tree.sync()
logger.info("Synced %s command(s)", len(synced))
await msg.edit(content="Done!")
@bot.command()
@commands.dm_only()
@commands.check(owner_only)
async def close(ctx: commands.Context):
"""Handle restart command"""
await ctx.reply("Closing bot...", ephemeral=True)
logger.info("Closing down the bot")
await bot.close()
# @bot.tree.command(description="Sign up as a matchee in this server")
# @commands.guild_only()
# async def join(interaction: discord.Interaction):
# # TODO: Sign up
# await interaction.response.send_message(
# f"Awesome, great to have you on board {interaction.user.mention}!", ephemeral=True)
# @bot.tree.command(description="Leave the matchee list in this server")
# @commands.guild_only()
# async def leave(interaction: discord.Interaction):
# # TODO: Remove the user
# await interaction.response.send_message(
# f"No worries, see you soon {interaction.user.mention}!", ephemeral=True)
@bot.tree.command(description="Match up matchees")
@commands.guild_only()
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)",
matchee_role="Role for matchees (defaults to @Matchee)")
async def match(interaction: discord.Interaction, members_min: int = None, matchee_role: str = None):
"""Match groups of channel members"""
logger.info("Handling request '/match group_min=%s matchee_role=%s'",
members_min, matchee_role)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Sort out the defaults, if not specified they'll come in as None
if not members_min:
members_min = 3
if not matchee_role:
matchee_role = "Matchee"
# Grab the roles and verify the given role
matcher = matching.get_role_from_guild(interaction.guild, "Matcher")
matcher = matcher and matcher in interaction.user.roles
matchee = matching.get_role_from_guild(interaction.guild, matchee_role)
if not matchee:
await interaction.response.send_message(f"Server is missing '{matchee_role}' role :(", ephemeral=True)
return
# Create some example groups to show the user
matchees = list(
m for m in interaction.channel.members if matchee in m.roles)
groups = matching.members_to_groups(
matchees, History, members_min, allow_fallback=True)
# Post about all the groups with a button to send to the channel
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
view = discord.utils.MISSING
if not matcher:
# Let a non-matcher know why they don't have the button
msg += "\n\nYou'll need the 'Matcher' role to post this to the channel, sorry!"
else:
# Otherwise set up the button
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
view = discord.ui.View(timeout=None)
view.add_item(DynamicGroupButton(members_min, matchee_role))
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
logger.info("Done.")
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
template=r'match:min:(?P<min>[0-9]+):role:(?P<role>[@\w\s]+)'):
def __init__(self, min: int, role: str) -> None:
super().__init__(
discord.ui.Button(
label='Match Groups!',
style=discord.ButtonStyle.blurple,
custom_id=f'match:min:{min}:role:{role}',
)
)
self.min: int = min
self.role: int = role
# This is called when the button is clicked and the custom_id matches the template.
@classmethod
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
min = int(match['min'])
role = str(match['role'])
return cls(min, role)
async def callback(self, interaction: discord.Interaction) -> None:
"""Match up people when the button is pressed"""
logger.info("Handling button press min=%s role=%s'",
self.min, self.role)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Let the user know we've recieved the message
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
# Grab the role
matchee = matching.get_role_from_guild(interaction.guild, self.role)
# Create our groups!
matchees = list(
m for m in interaction.channel.members if matchee in m.roles)
groups = matching.members_to_groups(
matchees, History, self.min, allow_fallback=True)
# Send the groups
for msg in (matching.group_to_message(g) for g in groups):
await interaction.channel.send(msg)
# Close off with a message
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
# Save the groups to the history
History.save_groups_to_history(groups)
logger.info("Done. Matched %s matchees into %s groups.",
len(matchees), len(groups))
if __name__ == "__main__":
handler = logging.StreamHandler()
bot.run(Config.token, log_handler=handler, root_logger=True)

135
py/config.py Normal file
View file

@ -0,0 +1,135 @@
"""Very simple config loading library"""
from schema import Schema, And, Use, Optional
import files
import os
import logging
logger = logging.getLogger("config")
logger.setLevel(logging.INFO)
_FILE = "config.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 1
class _Key():
TOKEN = "token"
VERSION = "version"
MATCH = "match"
SCORE_FACTORS = "score_factors"
REPEAT_ROLE = "repeat_role"
REPEAT_MATCH = "repeat_match"
EXTRA_MEMBER = "extra_member"
UPPER_THRESHOLD = "upper_threshold"
# Removed
_OWNERS = "owners"
_SCHEMA = Schema(
{
# The current version
_Key.VERSION: And(Use(int)),
# Discord bot token
_Key.TOKEN: And(Use(str)),
# Settings for the match algorithmn, see matching.py for explanations on usage
Optional(_Key.MATCH): {
Optional(_Key.SCORE_FACTORS): {
Optional(_Key.REPEAT_ROLE): And(Use(int)),
Optional(_Key.REPEAT_MATCH): And(Use(int)),
Optional(_Key.EXTRA_MEMBER): And(Use(int)),
Optional(_Key.UPPER_THRESHOLD): And(Use(int)),
}
}
}
)
_EMPTY_DICT = {
_Key.TOKEN: "",
_Key.VERSION: _VERSION
}
def _migrate_to_v1(d: dict):
# Owners moved to History in v1
# Note: owners will be required to be re-added to the state.json
owners = d.pop(_Key._OWNERS)
logger.warn(
"Migration removed owners from config, these must be re-added to the state.json")
logger.warn("Owners: %s", owners)
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1
]
class _ScoreFactors():
def __init__(self, data: dict):
"""Initialise and validate the config"""
self._dict = data
@property
def repeat_role(self) -> int:
return self._dict.get(_Key.REPEAT_ROLE, None)
@property
def repeat_match(self) -> int:
return self._dict.get(_Key.REPEAT_MATCH, None)
@property
def extra_member(self) -> int:
return self._dict.get(_Key.EXTRA_MEMBER, None)
@property
def upper_threshold(self) -> int:
return self._dict.get(_Key.UPPER_THRESHOLD, None)
class _Config():
def __init__(self, data: dict):
"""Initialise and validate the config"""
_SCHEMA.validate(data)
self._dict = data
@property
def token(self) -> str:
return self._dict["token"]
@property
def score_factors(self) -> _ScoreFactors:
return _ScoreFactors(self._dict.get(_Key.SCORE_FACTORS, {}))
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
_MIGRATIONS[i](dict)
dict["version"] = _VERSION
def _load_from_file(file: str = _FILE) -> _Config:
"""
Load the state from a file
Apply any required migrations
"""
loaded = _EMPTY_DICT
if os.path.isfile(file):
loaded = files.load(file)
_migrate(loaded)
else:
logger.warn("No %s file found, bot cannot run!", file)
return _Config(loaded)
# Core config for users to use
# Singleton as there should only be one, and it's global
Config = _load_from_file()

View file

@ -1,25 +1,27 @@
"""Utility functions for matchy"""
import logging
import random
from datetime import datetime, timedelta
from typing import Protocol, runtime_checkable
import history
import state
import config
# Number of days to step forward from the start of history for each match attempt
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
class _ScoreFactors(int):
"""
Score factors used when trying to build up "best fit" groups
Matchees are sequentially placed into the lowest scoring available group
"""
# Attempts for each of those time periods
_ATTEMPTS_PER_TIMESTEP = 3
# Added for each role the matchee has that another group member has
REPEAT_ROLE = config.Config.score_factors.repeat_role or 2**2
# Added for each member in the group that the matchee has already matched with
REPEAT_MATCH = config.Config.score_factors.repeat_match or 2**3
# Added for each additional member over the set "per group" value
EXTRA_MEMBER = config.Config.score_factors.extra_member or 2**5
# Various eligability scoring factors for group meetups
_SCORE_CURRENT_MEMBERS = 2**1
_SCORE_REPEAT_ROLE = 2**2
_SCORE_REPEAT_MATCH = 2**3
_SCORE_EXTRA_MEMBERS = 2**4
# Upper threshold, if the user scores higher than this they will not be placed in that group
UPPER_THRESHOLD = config.Config.score_factors.upper_threshold or 2**6
# Scores higher than this are fully rejected
_SCORE_UPPER_THRESHOLD = 2**6
logger = logging.getLogger("matching")
logger.setLevel(logging.INFO)
@ -69,33 +71,42 @@ def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bo
def get_member_group_eligibility_score(member: Member,
group: list[Member],
relevant_matches: list[int],
per_group: int) -> int:
prior_matches: list[int],
per_group: int) -> float:
"""Rates a member against a group"""
rating = len(group) * _SCORE_CURRENT_MEMBERS
# An empty group is a "perfect" score atomatically
rating = 0
if not group:
return rating
repeat_meetings = sum(m.id in relevant_matches for m in group)
rating += repeat_meetings * _SCORE_REPEAT_MATCH
# Add score based on prior matchups of this user
num_prior = sum(m.id in prior_matches for m in group)
rating += num_prior * _ScoreFactors.REPEAT_MATCH
repeat_roles = sum(r in member.roles for r in (m.roles for m in group))
rating += (repeat_roles * _SCORE_REPEAT_ROLE)
# Calculate the number of roles that match
all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr)
member_role_ids = [r.id for r in member.roles]
repeat_roles = sum(id in member_role_ids for id in all_role_ids)
rating += repeat_roles * _ScoreFactors.REPEAT_ROLE
extra_members = len(group) - per_group
if extra_members > 0:
rating += extra_members * _SCORE_EXTRA_MEMBERS
# Add score based on the number of extra members
# Calculate the member offset (+1 for this user)
extra_members = (len(group) - per_group) + 1
if extra_members >= 0:
rating += extra_members * _ScoreFactors.EXTRA_MEMBER
return rating
def attempt_create_groups(matchees: list[Member],
hist: history.History,
current_state: state.State,
oldest_relevant_ts: datetime,
per_group: int) -> tuple[bool, list[list[Member]]]:
"""History aware group matching"""
num_groups = max(len(matchees)//per_group, 1)
# Set up the groups in place
groups = list([] for _ in range(num_groups))
groups = [[] for _ in range(num_groups)]
matchees_left = matchees.copy()
@ -103,21 +114,21 @@ def attempt_create_groups(matchees: list[Member],
while matchees_left:
# Get the next matchee to place
matchee = matchees_left.pop()
matchee_matches = hist.matchees.get(
str(matchee.id), {}).get("matches", {})
relevant_matches = list(int(id) for id, ts in matchee_matches.items()
if history.ts_to_datetime(ts) >= oldest_relevant_ts)
matchee_matches = current_state.get_user_matches(matchee.id)
relevant_matches = [int(id) for id, ts
in matchee_matches.items()
if state.ts_to_datetime(ts) >= oldest_relevant_ts]
# Try every single group from the current group onwards
# Progressing through the groups like this ensures we slowly fill them up with compatible people
scores: list[tuple[int, int]] = []
scores: list[tuple[int, float]] = []
for group in groups:
score = get_member_group_eligibility_score(
matchee, group, relevant_matches, num_groups)
matchee, group, relevant_matches, per_group)
# If the score isn't too high, consider this group
if score <= _SCORE_UPPER_THRESHOLD:
if score <= _ScoreFactors.UPPER_THRESHOLD:
scores.append((group, score))
# Optimisation:
@ -143,31 +154,38 @@ def datetime_range(start_time: datetime, increment: timedelta, end: datetime):
current += increment
def iterate_all_shifts(list: list):
"""Yields each shifted variation of the input list"""
yield list
for _ in range(len(list)-1):
list = list[1:] + [list[0]]
yield list
def members_to_groups(matchees: list[Member],
hist: history.History = history.History(),
st: state.State = state.State(),
per_group: int = 3,
allow_fallback: bool = False) -> list[list[Member]]:
"""Generate the groups from the set of matchees"""
attempts = 0 # Tracking for logging purposes
rand = random.Random(117) # Some stable randomness
num_groups = len(matchees)//per_group
# Grab the oldest timestamp
history_start = hist.oldest() or datetime.now()
# Bail early if there's no-one to match
if not matchees:
return []
# Walk from the start of time until now using the timestep increment
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
# Walk from the start of history until now trying to match up groups
for oldest_relevant_datetime in st.get_history_timestamps() + [datetime.now()]:
# Have a few attempts before stepping forward in time
for _ in range(_ATTEMPTS_PER_TIMESTEP):
rand.shuffle(matchees) # Shuffle the matchees each attempt
# Attempt with each starting matchee
for shifted_matchees in iterate_all_shifts(matchees):
attempts += 1
groups = attempt_create_groups(
matchees, hist, oldest_relevant_datetime, per_group)
shifted_matchees, st, oldest_relevant_datetime, per_group)
# Fail the match if our groups aren't big enough
if (len(matchees)//per_group) <= 1 or (groups and all(len(g) >= per_group for g in groups)):
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
logger.info("Matched groups after %s attempt(s)", attempts)
return groups
@ -176,6 +194,10 @@ def members_to_groups(matchees: list[Member],
logger.info("Fell back to simple groups after %s attempt(s)", attempts)
return members_to_groups_simple(matchees, per_group)
# Simply assert false, this should never happen
# And should be caught by tests
assert False
def group_to_message(group: list[Member]) -> str:
"""Get the message to send for each group"""
@ -185,8 +207,3 @@ def group_to_message(group: list[Member]) -> str:
else:
mentions = mentions[0]
return f"Matched up {mentions}!"
def get_role_from_guild(guild: Guild, role: str) -> Role:
"""Find a role in a guild"""
return next((r for r in guild.roles if r.name == role), None)

412
py/matching_test.py Normal file
View file

@ -0,0 +1,412 @@
"""
Test functions for the matching module
"""
import discord
import pytest
import random
import matching
import state
import copy
import itertools
from datetime import datetime, timedelta
def test_protocols():
"""Verify the protocols we're using match the discord ones"""
assert isinstance(discord.Member, matching.Member)
assert isinstance(discord.Guild, matching.Guild)
assert isinstance(discord.Role, matching.Role)
assert isinstance(Member, matching.Member)
# assert isinstance(Role, matching.Role)
class Role():
def __init__(self, id: int):
self._id = id
@property
def id(self) -> int:
return self._id
class Member():
def __init__(self, id: int, roles: list[Role] = []):
self._id = id
self._roles = roles
@property
def mention(self) -> str:
return f"<@{self._id}>"
@property
def roles(self) -> list[Role]:
return self._roles
@roles.setter
def roles(self, roles: list[Role]):
self._roles = roles
@property
def id(self) -> int:
return self._id
def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, per_group: int):
"""Inner function to validate the main output of the groups function"""
groups = matching.members_to_groups(matchees, tmp_state, per_group)
# We should always have one group
assert len(groups)
# Log the groups to history
# This will validate the internals
tmp_state.log_groups(groups)
# Ensure each group contains within the bounds of expected members
for group in groups:
if len(matchees) >= per_group:
assert len(group) >= per_group
else:
assert len(group) == len(matchees)
assert len(group) < per_group*2 # TODO: We could be more strict here
return groups
@pytest.mark.parametrize("matchees, per_group", [
# Simplest test possible
([Member(1)], 1),
# More requested than we have
([Member(1)], 2),
# A selection of hyper-simple checks to validate core functionality
([Member(1)] * 100, 3),
([Member(1)] * 12, 5),
([Member(1)] * 11, 2),
([Member(1)] * 356, 8),
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
def test_members_to_groups_no_history(matchees, per_group):
"""Test simple group matching works"""
tmp_state = state.State()
members_to_groups_validate(matchees, tmp_state, per_group)
def items_found_in_lists(list_of_lists, items):
"""validates if any sets of items are found in individual lists"""
for sublist in list_of_lists:
if all(item in sublist for item in items):
return True
return False
@pytest.mark.parametrize("history_data, matchees, per_group, checks", [
# Slightly more difficult test
(
# Describe a history where we previously matched up some people and ensure they don't get rematched
[
{
"ts": datetime.now() - timedelta(days=1),
"groups": [
[Member(1), Member(2)],
[Member(3), Member(4)],
]
}
],
[
Member(1),
Member(2),
Member(3),
Member(4),
],
2,
[
lambda groups: not items_found_in_lists(
groups, [Member(1), Member(2)]),
lambda groups: not items_found_in_lists(
groups, [Member(3), Member(4)])
]
),
# Feed the system an "impossible" test
# The function should fall back to ignoring history and still give us something
(
[
{
"ts": datetime.now() - timedelta(days=1),
"groups": [
[
Member(1),
Member(2),
Member(3)
],
[
Member(4),
Member(5),
Member(6)
],
]
}
],
[
Member(1, [Role(1), Role(2), Role(3), Role(4)]),
Member(2, [Role(1), Role(2), Role(3), Role(4)]),
Member(3, [Role(1), Role(2), Role(3), Role(4)]),
Member(4, [Role(1), Role(2), Role(3), Role(4)]),
Member(5, [Role(1), Role(2), Role(3), Role(4)]),
Member(6, [Role(1), Role(2), Role(3), Role(4)]),
],
3,
[
# Nothing specific to validate
]
),
# Specific test pulled out of the stress test
(
[
{
"ts": datetime.now() - timedelta(days=4),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12, 13, 14, 15]]
]
},
{
"ts": datetime.now() - timedelta(days=5),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
]
}
],
[Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]],
3,
[
# Nothing specific to validate
]
),
# Silly example that failued due to bad role logic
(
[
# No history
],
[
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
Member(i, [Role(r) for r in roles]) for (i, roles) in
[
(4, [1, 2, 3, 4, 5, 6, 7, 8]),
(8, [1]),
(9, [1, 2, 3, 4, 5]),
(6, [1, 2, 3]),
(11, [1, 2, 3]),
(7, [1, 2, 3, 4, 5, 6, 7]),
(1, [1, 2, 3, 4]),
(5, [1, 2, 3, 4, 5]),
(12, [1, 2, 3, 4]),
(10, [1]),
(13, [1, 2, 3, 4, 5, 6]),
(2, [1, 2, 3, 4, 5, 6]),
(3, [1, 2, 3, 4, 5, 6, 7])
]
],
2,
[
# Nothing else
]
),
# Another weird one pulled out of the stress test
(
[
# print([(str(h["ts"]), [[f"Member({gm.id})" for gm in g] for g in h["groups"]]) for h in history_data])
{"ts": datetime.strptime(ts, r"%Y-%m-%d %H:%M:%S.%f"), "groups": [
[Member(m) for m in group] for group in groups]}
for (ts, groups) in
[
(
'2024-07-07 20:25:56.313993',
[
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[1],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
[1, 2, 3, 4, 5, 6, 7, 8]
]
),
(
'2024-07-13 20:25:56.313993',
[
[1, 2],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[1]
]
),
(
'2024-06-29 20:25:56.313993',
[
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 6, 7],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20]
]
),
(
'2024-06-25 20:25:56.313993',
[
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18],
[1, 2],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[1, 2]
]
),
(
'2024-07-04 20:25:56.313993',
[
[1, 2, 3, 4, 5],
[1, 2, 3],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[1, 2, 3, 4, 5, 6, 7]
]
),
(
'2024-07-16 20:25:56.313993',
[
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
[1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20],
[1, 2, 3, 4, 5, 6],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18]
]
)
]
],
[
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
Member(i, [Role(r) for r in roles]) for (i, roles) in
[
(10, [1, 2, 3]),
(4, [1, 2, 3]),
(5, [1, 2]),
(13, [1, 2]),
(3, [1, 2, 3, 4]),
(14, [1]),
(6, [1, 2, 3, 4]),
(11, [1]),
(9, [1]),
(1, [1, 2, 3]),
(16, [1, 2]),
(15, [1, 2]),
(2, [1, 2, 3]),
(7, [1, 2, 3]),
(12, [1, 2]),
(8, [1, 2, 3, 4])
]
],
5,
[
# Nothing
]
)
], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3'])
def test_unique_regressions(history_data, matchees, per_group, checks):
"""Test a bunch of unqiue failures that happened in the past"""
tmp_state = state.State()
# Replay the history
for d in history_data:
tmp_state.log_groups(d["groups"], d["ts"])
groups = members_to_groups_validate(matchees, tmp_state, per_group)
# Run the custom validate functions
for check in checks:
assert check(groups)
def random_chunk(li, min_chunk, max_chunk, rand):
"""
"Borrowed" from https://stackoverflow.com/questions/21439011/best-way-to-split-a-list-into-randomly-sized-chunks
"""
it = iter(li)
while True:
nxt = list(itertools.islice(it, rand.randint(min_chunk, max_chunk)))
if nxt:
yield nxt
else:
break
# Generate a large set of "interesting" tests that replay a fake history onto random people
# Increase these numbers for some extreme programming
@pytest.mark.parametrize("per_group, num_members, num_history", (
(per_group, num_members, num_history)
# Most of the time groups are gonna be from 2 to 5
for per_group in range(2, 5)
# Going lower than 8 members doesn't give the bot much of a chance
# And it will fail to not fall back sometimes
# That's probably OK frankly
for num_members in range(8, 32, 5)
# Throw up to 7 histories at the algorithmn
for num_history in range(0, 8)))
def test_stess_random_groups(per_group, num_members, num_history):
"""Run a randomised test based on the input"""
# Seed the random based on the inputs paird with primes
# Ensures the test has interesting fake data, but is stable
rand = random.Random(per_group*3 + num_members*5 + num_history*7)
# Start with a list of all possible members
possible_members = [Member(i) for i in range(num_members*2)]
for member in possible_members:
# Give each member 3 random roles from 1-7
member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)]
# For each history item match up groups and log those
cumulative_state = state.State()
for i in range(num_history+1):
# Grab the num of members and replay
rand.shuffle(possible_members)
members = copy.deepcopy(possible_members[:num_members])
groups = members_to_groups_validate(
members, cumulative_state, per_group)
cumulative_state.log_groups(
groups, datetime.now() - timedelta(days=num_history-i))
def test_auth_scopes():
tmp_state = state.State()
id = "1"
tmp_state.set_user_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
id = "2"
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
tmp_state.validate()
def test_iterate_all_shifts():
original = [1, 2, 3, 4]
lists = [val for val in matching.iterate_all_shifts(original)]
assert lists == [
[1, 2, 3, 4],
[2, 3, 4, 1],
[3, 4, 1, 2],
[4, 1, 2, 3],
]

230
py/matchy.py Executable file
View file

@ -0,0 +1,230 @@
"""
matchy.py - Discord bot that matches people into groups
"""
import logging
import discord
from discord import app_commands
from discord.ext import commands
import matching
import state
import config
import re
STATE_FILE = "state.json"
State = state.load_from_file(STATE_FILE)
logger = logging.getLogger("matchy")
logger.setLevel(logging.INFO)
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
bot = commands.Bot(command_prefix='$',
description="Matchy matches matchees", intents=intents)
@bot.event
async def setup_hook():
bot.add_dynamic_items(DynamicGroupButton)
@bot.event
async def on_ready():
"""Bot is ready and connected"""
logger.info("Bot is up and ready!")
activity = discord.Game("/join")
await bot.change_presence(status=discord.Status.online, activity=activity)
def owner_only(ctx: commands.Context) -> bool:
"""Checks the author is an owner"""
return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER)
@bot.command()
@commands.dm_only()
@commands.check(owner_only)
async def sync(ctx: commands.Context):
"""Handle sync command"""
msg = await ctx.reply("Reloading state...", ephemeral=True)
global State
State = state.load_from_file(STATE_FILE)
logger.info("Reloaded state")
await msg.edit(content="Syncing commands...")
synced = await bot.tree.sync()
logger.info("Synced %s command(s)", len(synced))
await msg.edit(content="Done!")
@bot.command()
@commands.dm_only()
@commands.check(owner_only)
async def close(ctx: commands.Context):
"""Handle restart command"""
await ctx.reply("Closing bot...", ephemeral=True)
logger.info("Closing down the bot")
await bot.close()
@bot.tree.command(description="Join the matchees for this channel")
@commands.guild_only()
async def join(interaction: discord.Interaction):
State.set_user_active_in_channel(
interaction.user.id, interaction.channel.id)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"Roger roger {interaction.user.mention}!\n"
+ f"Added you to {interaction.channel.mention}!",
ephemeral=True, silent=True)
@bot.tree.command(description="Leave the matchees for this channel")
@commands.guild_only()
async def leave(interaction: discord.Interaction):
State.set_user_active_in_channel(
interaction.user.id, interaction.channel.id, False)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@bot.tree.command(description="Pause your matching in this channel for a number of days")
@commands.guild_only()
@app_commands.describe(days="Days to pause for (defaults to 7)")
async def pause(interaction: discord.Interaction, days: int = None):
if not days: # Default to a week
days = 7
State.set_user_paused_in_channel(
interaction.user.id, interaction.channel.id, days)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"Sure thing {interaction.user.mention}. Paused you for {days} days!", ephemeral=True, silent=True)
@bot.tree.command(description="List the matchees for this channel")
@commands.guild_only()
async def list(interaction: discord.Interaction):
matchees = get_matchees_in_channel(interaction.channel)
mentions = [m.mention for m in matchees]
msg = "Current matchees in this channel:\n" + \
f"{', '.join(mentions[:-1])} and {mentions[-1]}"
await interaction.response.send_message(msg, ephemeral=True, silent=True)
@bot.tree.command(description="Match up matchees")
@commands.guild_only()
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)")
async def match(interaction: discord.Interaction, members_min: int = None):
"""Match groups of channel members"""
logger.info("Handling request '/match group_min=%s", members_min)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Sort out the defaults, if not specified they'll come in as None
if not members_min:
members_min = 3
# Grab the groups
groups = active_members_to_groups(interaction.channel, members_min)
# Let the user know when there's nobody to match
if not groups:
await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True)
return
# Post about all the groups with a button to send to the channel
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
view = discord.utils.MISSING
if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER):
# Otherwise set up the button
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
view = discord.ui.View(timeout=None)
view.add_item(DynamicGroupButton(members_min))
else:
# Let a non-matcher know why they don't have the button
msg += f"\n\nYou'll need the {state.AuthScope.MATCHER}"
+ " scope to post this to the channel, sorry!"
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
logger.info("Done.")
# Increment when adjusting the custom_id so we don't confuse old users
_MATCH_BUTTON_CUSTOM_ID_VERSION = 1
_MATCH_BUTTON_CUSTOM_ID_PREFIX = f'match:v{_MATCH_BUTTON_CUSTOM_ID_VERSION}:'
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
template=_MATCH_BUTTON_CUSTOM_ID_PREFIX + r'min:(?P<min>[0-9]+)'):
def __init__(self, min: int) -> None:
super().__init__(
discord.ui.Button(
label='Match Groups!',
style=discord.ButtonStyle.blurple,
custom_id=_MATCH_BUTTON_CUSTOM_ID_PREFIX + f'min:{min}',
)
)
self.min: int = min
# This is called when the button is clicked and the custom_id matches the template.
@classmethod
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
min = int(match['min'])
return cls(min)
async def callback(self, interaction: discord.Interaction) -> None:
"""Match up people when the button is pressed"""
logger.info("Handling button press min=%s", self.min)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Let the user know we've recieved the message
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
groups = active_members_to_groups(interaction.channel, self.min)
# Send the groups
for msg in (matching.group_to_message(g) for g in groups):
await interaction.channel.send(msg)
# Close off with a message
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
# Save the groups to the history
State.log_groups(groups)
state.save_to_file(State, STATE_FILE)
logger.info("Done! Matched into %s groups.", len(groups))
def get_matchees_in_channel(channel: discord.channel):
"""Fetches the matchees in a channel"""
# Reactivate any unpaused users
State.reactivate_users(channel.id)
# Gather up the prospective matchees
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
def active_members_to_groups(channel: discord.channel, min_members: int):
"""Helper to create groups from channel members"""
# Gather up the prospective matchees
matchees = get_matchees_in_channel(channel)
# Create our groups!
return matching.members_to_groups(matchees, State, min_members, allow_fallback=True)
if __name__ == "__main__":
handler = logging.StreamHandler()
bot.run(config.Config.token, log_handler=handler, root_logger=True)

330
py/state.py Normal file
View file

@ -0,0 +1,330 @@
"""Store bot state"""
import os
from datetime import datetime, timedelta
from schema import Schema, And, Use, Optional
from typing import Protocol
import files
import copy
import logging
from contextlib import contextmanager
logger = logging.getLogger("state")
logger.setLevel(logging.INFO)
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 2
def _migrate_to_v1(d: dict):
"""v1 simply renamed matchees to users"""
logger.info("Renaming %s to %s", _Key._MATCHEES, _Key.USERS)
d[_Key.USERS] = d[_Key._MATCHEES]
del d[_Key._MATCHEES]
def _migrate_to_v2(d: dict):
"""v2 swapped the date over to a less silly format"""
logger.info("Fixing up date format from %s to %s",
_TIME_FORMAT_OLD, _TIME_FORMAT)
def old_to_new_ts(ts: str) -> str:
return datetime.strftime(datetime.strptime(ts, _TIME_FORMAT_OLD), _TIME_FORMAT)
# Adjust all the history keys
d[_Key.HISTORY] = {
old_to_new_ts(ts): entry
for ts, entry in d[_Key.HISTORY].items()
}
# Adjust all the user parts
for user in d[_Key.USERS].values():
# Update the match dates
matches = user.get(_Key.MATCHES, {})
for id, ts in matches.items():
matches[id] = old_to_new_ts(ts)
# Update any reactivation dates
channels = user.get(_Key.CHANNELS, {})
for id, channel in channels.items():
old_ts = channel.get(_Key.REACTIVATE, None)
if old_ts:
channel[_Key.REACTIVATE] = old_to_new_ts(old_ts)
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1,
_migrate_to_v2
]
class AuthScope(str):
"""Various auth scopes"""
OWNER = "owner"
MATCHER = "matcher"
class _Key(str):
"""Various keys used in the schema"""
HISTORY = "history"
GROUPS = "groups"
MEMBERS = "members"
USERS = "users"
SCOPES = "scopes"
MATCHES = "matches"
ACTIVE = "active"
CHANNELS = "channels"
REACTIVATE = "reactivate"
VERSION = "version"
# Unused
_MATCHEES = "matchees"
_TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"
_TIME_FORMAT_OLD = "%a %b %d %H:%M:%S %Y"
_SCHEMA = Schema(
{
# The current version
_Key.VERSION: And(Use(int)),
Optional(_Key.HISTORY): {
# A datetime
Optional(str): {
_Key.GROUPS: [
{
_Key.MEMBERS: [
# The ID of each matchee in the match
And(Use(int))
]
}
]
}
},
Optional(_Key.USERS): {
Optional(str): {
Optional(_Key.SCOPES): And(Use(list[str])),
Optional(_Key.MATCHES): {
# Matchee ID and Datetime pair
Optional(str): And(Use(str))
},
Optional(_Key.CHANNELS): {
# The channel ID
Optional(str): {
# Whether the user is signed up in this channel
_Key.ACTIVE: And(Use(bool)),
# A timestamp for when to re-activate the user
Optional(_Key.REACTIVATE): And(Use(str)),
}
}
}
},
}
)
# Empty but schema-valid internal dict
_EMPTY_DICT = {
_Key.HISTORY: {},
_Key.USERS: {},
_Key.VERSION: _VERSION
}
assert _SCHEMA.validate(_EMPTY_DICT)
class Member(Protocol):
@property
def id(self) -> int:
pass
def ts_to_datetime(ts: str) -> datetime:
"""Convert a string ts to datetime using the internal format"""
return datetime.strptime(ts, _TIME_FORMAT)
def datetime_to_ts(ts: datetime) -> str:
"""Convert a datetime to a string ts using the internal format"""
return datetime.strftime(ts, _TIME_FORMAT)
class State():
def __init__(self, data: dict = _EMPTY_DICT):
"""Initialise and validate the state"""
self.validate(data)
self._dict = copy.deepcopy(data)
def validate(self, dict: dict = None):
"""Initialise and validate a state dict"""
if not dict:
dict = self._dict
_SCHEMA.validate(dict)
def get_history_timestamps(self) -> list[datetime]:
"""Grab all timestamps in the history"""
return sorted([ts_to_datetime(dt) for dt in self._history.keys()])
def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups"""
ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap() as safe_state:
# Grab or create the hitory item for this set of groups
history_item = {}
safe_state._history[ts] = history_item
history_item_groups = []
history_item[_Key.GROUPS] = history_item_groups
for group in groups:
# Add the group data
history_item_groups.append({
_Key.MEMBERS: [m.id for m in group]
})
# Update the matchee data with the matches
for m in group:
matchee = safe_state._users.get(str(m.id), {})
matchee_matches = matchee.get(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
matchee[_Key.MATCHES] = matchee_matches
safe_state._users[str(m.id)] = matchee
def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user"""
with self._safe_wrap() as safe_state:
# Dive in
user = safe_state._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
# Set the value
if value and scope not in scopes:
scopes.append(scope)
elif not value and scope in scopes:
scopes.remove(scope)
# Roll out
user[_Key.SCOPES] = scopes
safe_state._users[str(id)] = user
def get_user_has_scope(self, id: str, scope: str) -> bool:
"""
Check if a user has an auth scope
"owner" users have all scopes
"""
user = self._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
return AuthScope.OWNER in scopes or scope in scopes
def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True):
"""Set a user as active (or not) on a given channel"""
self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active)
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
"""Get a users active channels"""
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
def set_user_paused_in_channel(self, id: str, channel_id: str, days: int):
"""Sets a user as paused in a channel"""
# Deactivate the user in the channel first
self.set_user_active_in_channel(id, channel_id, False)
# Set the reactivate time the number of days in the future
ts = datetime.now() + timedelta(days=days)
self._set_user_channel_prop(
id, channel_id, _Key.REACTIVATE, datetime_to_ts(ts))
def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap() as safe_state:
for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
if channel and not channel[_Key.ACTIVE]:
reactivate = channel.get(_Key.REACTIVATE, None)
# Check if we've gone past the reactivation time and re-activate
if reactivate and datetime.now() > ts_to_datetime(reactivate):
channel[_Key.ACTIVE] = True
@property
def dict_internal_copy(self) -> dict:
"""Only to be used to get the internal dict as a copy"""
return copy.deepcopy(self._dict)
@property
def _history(self) -> dict[str]:
return self._dict[_Key.HISTORY]
@property
def _users(self) -> dict[str]:
return self._dict[_Key.USERS]
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper"""
with self._safe_wrap() as safe_state:
# Dive in
user = safe_state._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
# Set the value
channel[key] = value
# Unroll
channels[str(channel_id)] = channel
user[_Key.CHANNELS] = channels
safe_state._users[str(id)] = user
@contextmanager
def _safe_wrap(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
try:
yield tmp_state
finally:
# Validate and then overwrite our dict with the new one
tmp_state.validate()
self._dict = tmp_state._dict
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
logger.info("Migrating from v%s to v%s", version, version+1)
_MIGRATIONS[i](dict)
dict[_Key.VERSION] = _VERSION
def load_from_file(file: str) -> State:
"""
Load the state from a file
Apply any required migrations
"""
loaded = _EMPTY_DICT
# If there's a file load it and try to migrate
if os.path.isfile(file):
loaded = files.load(file)
_migrate(loaded)
st = State(loaded)
# Save out the migrated (or new) file
files.save(file, st._dict)
return st
def save_to_file(state: State, file: str):
"""Saves the state out to a file"""
files.save(file, state.dict_internal_copy)

65
py/state_test.py Normal file
View file

@ -0,0 +1,65 @@
"""
Test functions for the state module
"""
import state
import tempfile
import os
def test_basic_state():
"""Simple validate basic state load"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
state.load_from_file(path)
def test_simple_load_reload():
"""Test a basic load, save, reload"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
def test_authscope():
"""Test setting and getting an auth scope"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
st = state.load_from_file(path)
st.set_user_scope(1, state.AuthScope.MATCHER)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
st.set_user_scope(1, state.AuthScope.MATCHER, False)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
def test_channeljoin():
"""Test setting and getting an active channel"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path)
st.set_user_active_in_channel(1, "2", True)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2")
st.set_user_active_in_channel(1, "2", False)
assert not st.get_user_active_in_channel(1, "2")