commit
db00c9f7c1
19 changed files with 1280 additions and 644 deletions
7
.github/workflows/test.yml
vendored
7
.github/workflows/test.yml
vendored
|
@ -18,12 +18,9 @@ jobs:
|
|||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r requirements.txt
|
||||
- name: Analysing the code with flake8
|
||||
- name: Run tests
|
||||
run: |
|
||||
flake8 --max-line-length 120 $(git ls-files '*.py')
|
||||
- name: Run tests with pytest
|
||||
run: |
|
||||
pytest
|
||||
bash bin/test.sh
|
||||
- name: Update release branch
|
||||
if: github.ref == 'refs/heads/main'
|
||||
run: |
|
||||
|
|
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,4 +1,4 @@
|
|||
__pycache__
|
||||
config.json
|
||||
history.json
|
||||
state.json
|
||||
.venv
|
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
|
@ -8,7 +8,7 @@
|
|||
"name": "Python Debugger: Matchy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "matchy.py",
|
||||
"program": "py/matchy.py",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
|
|
32
README.md
32
README.md
|
@ -6,8 +6,14 @@ Matchy matches matchees.
|
|||
Matchy is a Discord bot that groups up users for fun and vibes. Matchy can be installed by clicking [here](https://discord.com/oauth2/authorize?client_id=1270849346987884696).
|
||||
|
||||
## Commands
|
||||
### /match [group_min: int(3)] [matchee_role: str(@Matchee)]
|
||||
Matches groups of users with a given role and posts those groups to the channel. Tracks historical matches and attempts to match users to make new connections with people with divergent roles, in an attempt to maximise diversity.
|
||||
### /match [group_min: int(3)]
|
||||
Matches groups of users in a channel and offers a button to pose those groups to the channel to users with `matcher` auth scope. Tracks historical matches and attempts to match users to make new connections with people with divergent roles, in an attempt to maximise diversity.
|
||||
|
||||
### /join and /leave
|
||||
Allows users to sign up and leave the group matching in the channel the command is used
|
||||
|
||||
### /pause [days: int(7)]
|
||||
Allows users to pause their matching in a channel for a given number of days. Users can use `/join` to re-join before the end of that time.
|
||||
|
||||
### $sync and $close
|
||||
Only usable by `OWNER` users, reloads the config and syncs commands, or closes down the bot. Only usable in DMs with the bot user.
|
||||
|
@ -18,20 +24,24 @@ Only usable by `OWNER` users, reloads the config and syncs commands, or closes d
|
|||
|
||||
## Config
|
||||
Matchy is configured by a `config.json` file that takes this format:
|
||||
```
|
||||
```json
|
||||
{
|
||||
"version" : 1,
|
||||
"token" : "<<github bot token>>",
|
||||
"owners": [
|
||||
<<owner id>>
|
||||
]
|
||||
|
||||
"match" : {
|
||||
"score_factors": {
|
||||
"repeat_role" : 4,
|
||||
"repeat_match" : 8,
|
||||
"extra_member" : 32,
|
||||
"upper_threshold" : 64
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user.
|
||||
Only token and version are required. See [`py/config.py`](py/config.py) for explanations for any of these.
|
||||
|
||||
## TODO
|
||||
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
|
||||
* Add scheduling functionality
|
||||
* Version the config and history files
|
||||
* Implement /signup rather than using roles
|
||||
* Implement authorisation scopes instead of just OWNER values
|
||||
* Write integration tests (maybe with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)?)
|
||||
* Improve the weirdo
|
5
bin/coverage.sh
Executable file
5
bin/coverage.sh
Executable file
|
@ -0,0 +1,5 @@
|
|||
#!/usr/bin/env bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
pytest --cov=. --cov-report=html
|
|
@ -8,4 +8,4 @@ if [ ! -d .venv ]; then
|
|||
fi
|
||||
source .venv/bin/activate
|
||||
python -m pip install -r requirements.txt
|
||||
python matchy.py
|
||||
python py/matchy.py
|
9
bin/test.sh
Executable file
9
bin/test.sh
Executable file
|
@ -0,0 +1,9 @@
|
|||
#!/usr/bin/env bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
# Check formatting and linting
|
||||
flake8 --max-line-length 120 $(git ls-files '*.py')
|
||||
|
||||
# Run pytest
|
||||
pytest
|
38
config.py
38
config.py
|
@ -1,38 +0,0 @@
|
|||
"""Very simple config loading library"""
|
||||
from schema import Schema, And, Use
|
||||
import files
|
||||
|
||||
_FILE = "config.json"
|
||||
_SCHEMA = Schema(
|
||||
{
|
||||
# Discord bot token
|
||||
"token": And(Use(str)),
|
||||
|
||||
# ids of owners authorised to use owner-only commands
|
||||
"owners": And(Use(list[int])),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Config():
|
||||
def __init__(self, data: dict):
|
||||
"""Initialise and validate the config"""
|
||||
_SCHEMA.validate(data)
|
||||
self.__dict__ = data
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return self.__dict__["token"]
|
||||
|
||||
@property
|
||||
def owners(self) -> list[int]:
|
||||
return self.__dict__["owners"]
|
||||
|
||||
def reload(self) -> None:
|
||||
"""Reload the config back into the dict"""
|
||||
self.__dict__ = load().__dict__
|
||||
|
||||
|
||||
def load() -> Config:
|
||||
"""Load the config"""
|
||||
return Config(files.load(_FILE))
|
|
@ -1,2 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
pytest --cov=. --cov-report=html
|
125
history.py
125
history.py
|
@ -1,125 +0,0 @@
|
|||
"""Store matching history"""
|
||||
import os
|
||||
from datetime import datetime
|
||||
from schema import Schema, And, Use, Optional
|
||||
from typing import Protocol
|
||||
import files
|
||||
import copy
|
||||
|
||||
_FILE = "history.json"
|
||||
|
||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||
_DEFAULT_DICT = {
|
||||
"history": {},
|
||||
"matchees": {}
|
||||
}
|
||||
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
|
||||
_SCHEMA = Schema(
|
||||
{
|
||||
Optional("history"): {
|
||||
Optional(str): { # a datetime
|
||||
"groups": [
|
||||
{
|
||||
"members": [
|
||||
# The ID of each matchee in the match
|
||||
And(Use(int))
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
Optional("matchees"): {
|
||||
Optional(str): {
|
||||
Optional("matches"): {
|
||||
# Matchee ID and Datetime pair
|
||||
Optional(str): And(Use(str))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Member(Protocol):
|
||||
@property
|
||||
def id(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
def ts_to_datetime(ts: str) -> datetime:
|
||||
"""Convert a ts to datetime using the history format"""
|
||||
return datetime.strptime(ts, _TIME_FORMAT)
|
||||
|
||||
|
||||
def validate(dict: dict):
|
||||
"""Initialise and validate the history"""
|
||||
_SCHEMA.validate(dict)
|
||||
|
||||
|
||||
class History():
|
||||
def __init__(self, data: dict = _DEFAULT_DICT):
|
||||
"""Initialise and validate the history"""
|
||||
validate(data)
|
||||
self.__dict__ = copy.deepcopy(data)
|
||||
|
||||
@property
|
||||
def history(self) -> list[dict]:
|
||||
return self.__dict__["history"]
|
||||
|
||||
@property
|
||||
def matchees(self) -> dict[str, dict]:
|
||||
return self.__dict__["matchees"]
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save out the history"""
|
||||
files.save(_FILE, self.__dict__)
|
||||
|
||||
def oldest(self) -> datetime:
|
||||
"""Grab the oldest timestamp in history"""
|
||||
if not self.history:
|
||||
return None
|
||||
times = (ts_to_datetime(dt) for dt in self.history.keys())
|
||||
return sorted(times)[0]
|
||||
|
||||
def log_groups_to_history(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
|
||||
"""Log the groups"""
|
||||
tmp_history = History(self.__dict__)
|
||||
ts = datetime.strftime(ts, _TIME_FORMAT)
|
||||
|
||||
# Grab or create the hitory item for this set of groups
|
||||
history_item = {}
|
||||
tmp_history.history[ts] = history_item
|
||||
history_item_groups = []
|
||||
history_item["groups"] = history_item_groups
|
||||
|
||||
for group in groups:
|
||||
|
||||
# Add the group data
|
||||
history_item_groups.append({
|
||||
"members": list(m.id for m in group)
|
||||
})
|
||||
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
matchee = tmp_history.matchees.get(str(m.id), {})
|
||||
matchee_matches = matchee.get("matches", {})
|
||||
|
||||
for o in (o for o in group if o.id != m.id):
|
||||
matchee_matches[str(o.id)] = ts
|
||||
|
||||
matchee["matches"] = matchee_matches
|
||||
tmp_history.matchees[str(m.id)] = matchee
|
||||
|
||||
# Validate before storing the result
|
||||
validate(self.__dict__)
|
||||
self.__dict__ = tmp_history.__dict__
|
||||
|
||||
def save_groups_to_history(self, groups: list[list[Member]]) -> None:
|
||||
"""Save out the groups to the history file"""
|
||||
self.log_groups_to_history(groups)
|
||||
self.save()
|
||||
|
||||
|
||||
def load() -> History:
|
||||
"""Load the history"""
|
||||
return History(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)
|
215
matching_test.py
215
matching_test.py
|
@ -1,215 +0,0 @@
|
|||
"""
|
||||
Test functions for Matchy
|
||||
"""
|
||||
import discord
|
||||
import pytest
|
||||
import random
|
||||
import matching
|
||||
import history
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def test_protocols():
|
||||
"""Verify the protocols we're using match the discord ones"""
|
||||
assert isinstance(discord.Member, matching.Member)
|
||||
assert isinstance(discord.Guild, matching.Guild)
|
||||
assert isinstance(discord.Role, matching.Role)
|
||||
assert isinstance(Member, matching.Member)
|
||||
# assert isinstance(Role, matching.Role)
|
||||
|
||||
|
||||
class Role():
|
||||
def __init__(self, id: int):
|
||||
self._id = id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
|
||||
class Member():
|
||||
def __init__(self, id: int, roles: list[Role] = []):
|
||||
self._id = id
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
return f"<@{self._id}>"
|
||||
|
||||
@property
|
||||
def roles(self) -> list[Role]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
|
||||
def inner_validate_members_to_groups(matchees: list[Member], hist: history.History, per_group: int):
|
||||
"""Inner function to validate the main output of the groups function"""
|
||||
groups = matching.members_to_groups(matchees, hist, per_group)
|
||||
|
||||
# We should always have one group
|
||||
assert len(groups)
|
||||
|
||||
# Log the groups to history
|
||||
# This will validate the internals
|
||||
hist.log_groups_to_history(groups)
|
||||
|
||||
# Ensure each group contains within the bounds of expected members
|
||||
for group in groups:
|
||||
if len(matchees) >= per_group:
|
||||
assert len(group) >= per_group
|
||||
else:
|
||||
assert len(group) == len(matchees)
|
||||
assert len(group) < per_group*2 # TODO: We could be more strict here
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
@pytest.mark.parametrize("matchees, per_group", [
|
||||
# Simplest test possible
|
||||
([Member(1)], 1),
|
||||
|
||||
# More requested than we have
|
||||
([Member(1)], 2),
|
||||
|
||||
# A selection of hyper-simple checks to validate core functionality
|
||||
([Member(1)] * 100, 3),
|
||||
([Member(1)] * 12, 5),
|
||||
([Member(1)] * 11, 2),
|
||||
([Member(1)] * 356, 8),
|
||||
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
|
||||
def test_members_to_groups_no_history(matchees, per_group):
|
||||
"""Test simple group matching works"""
|
||||
hist = history.History()
|
||||
inner_validate_members_to_groups(matchees, hist, per_group)
|
||||
|
||||
|
||||
def items_found_in_lists(list_of_lists, items):
|
||||
"""validates if any sets of items are found in individual lists"""
|
||||
for sublist in list_of_lists:
|
||||
if all(item in sublist for item in items):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("history_data, matchees, per_group, checks", [
|
||||
# Slightly more difficult test
|
||||
# Describe a history where we previously matched up some people and ensure they don't get rematched
|
||||
(
|
||||
[
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=1),
|
||||
"groups": [
|
||||
[Member(1), Member(2)],
|
||||
[Member(3), Member(4)],
|
||||
]
|
||||
}
|
||||
],
|
||||
[
|
||||
Member(1),
|
||||
Member(2),
|
||||
Member(3),
|
||||
Member(4),
|
||||
],
|
||||
2,
|
||||
[
|
||||
lambda groups: not items_found_in_lists(
|
||||
groups, [Member(1), Member(2)]),
|
||||
lambda groups: not items_found_in_lists(
|
||||
groups, [Member(3), Member(4)])
|
||||
]
|
||||
),
|
||||
# Feed the system an "impossible" test
|
||||
# The function should fall back to ignoring history and still give us something
|
||||
(
|
||||
[
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=1),
|
||||
"groups": [
|
||||
[
|
||||
Member(1),
|
||||
Member(2),
|
||||
Member(3)
|
||||
],
|
||||
[
|
||||
Member(4),
|
||||
Member(5),
|
||||
Member(6)
|
||||
],
|
||||
]
|
||||
}
|
||||
],
|
||||
[
|
||||
Member(1, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(2, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(3, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(4, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(5, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(6, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
],
|
||||
3,
|
||||
[
|
||||
# Nothing specific to validate
|
||||
]
|
||||
),
|
||||
], ids=['simple_history', 'fallback'])
|
||||
def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
|
||||
"""Test more advanced group matching works"""
|
||||
hist = history.History()
|
||||
|
||||
# Replay the history
|
||||
for d in history_data:
|
||||
hist.log_groups_to_history(d["groups"], d["ts"])
|
||||
|
||||
groups = inner_validate_members_to_groups(matchees, hist, per_group)
|
||||
|
||||
# Run the custom validate functions
|
||||
for check in checks:
|
||||
assert check(groups)
|
||||
|
||||
|
||||
def test_members_to_groups_stress_test():
|
||||
"""stress test firing significant random data at the code"""
|
||||
|
||||
# Use a stable rand, feel free to adjust this if needed but this lets the test be stable
|
||||
rand = random.Random(123)
|
||||
|
||||
# Slowly ramp up the group size
|
||||
for per_group in range(2, 6):
|
||||
|
||||
# Slowly ramp a randomized shuffled list of members with randomised roles
|
||||
for num_members in range(1, 5):
|
||||
matchees = list(Member(i, list(Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))))
|
||||
for i in range(1, rand.randint(2, num_members*10 + 1)))
|
||||
rand.shuffle(matchees)
|
||||
|
||||
for num_history in range(8):
|
||||
|
||||
# Generate some super random history
|
||||
# Start some time from now to the past
|
||||
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
|
||||
history_data = []
|
||||
for x in range(0, num_history):
|
||||
run = {
|
||||
"ts": time
|
||||
}
|
||||
groups = []
|
||||
for y in range(1, num_history):
|
||||
groups.append(list(Member(i)
|
||||
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))))
|
||||
run["groups"] = groups
|
||||
history_data.append(run)
|
||||
|
||||
# Step some time backwards in time
|
||||
time -= timedelta(days=rand.randint(1, num_history))
|
||||
|
||||
# No guarantees on history data order so make it a little harder for matchy
|
||||
rand.shuffle(history_data)
|
||||
|
||||
# Replay the history
|
||||
hist = history.History()
|
||||
for d in history_data:
|
||||
hist.log_groups_to_history(d["groups"], d["ts"])
|
||||
|
||||
inner_validate_members_to_groups(matchees, hist, per_group)
|
194
matchy.py
194
matchy.py
|
@ -1,194 +0,0 @@
|
|||
"""
|
||||
matchy.py - Discord bot that matches people into groups
|
||||
"""
|
||||
import logging
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
import matching
|
||||
import history
|
||||
import config
|
||||
import re
|
||||
|
||||
|
||||
Config = config.load()
|
||||
History = history.load()
|
||||
|
||||
logger = logging.getLogger("matchy")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.members = True
|
||||
bot = commands.Bot(command_prefix='$',
|
||||
description="Matchy matches matchees", intents=intents)
|
||||
|
||||
|
||||
@bot.event
|
||||
async def setup_hook():
|
||||
bot.add_dynamic_items(DynamicGroupButton)
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
"""Bot is ready and connected"""
|
||||
logger.info("Bot is up and ready!")
|
||||
activity = discord.Game("/match")
|
||||
await bot.change_presence(status=discord.Status.online, activity=activity)
|
||||
|
||||
|
||||
def owner_only(ctx: commands.Context) -> bool:
|
||||
"""Checks the author is an owner"""
|
||||
return ctx.message.author.id in Config.owners
|
||||
|
||||
|
||||
@bot.command()
|
||||
@commands.dm_only()
|
||||
@commands.check(owner_only)
|
||||
async def sync(ctx: commands.Context):
|
||||
"""Handle sync command"""
|
||||
msg = await ctx.reply("Reloading config...", ephemeral=True)
|
||||
Config.reload()
|
||||
logger.info("Reloaded config")
|
||||
|
||||
await msg.edit(content="Syncing commands...")
|
||||
synced = await bot.tree.sync()
|
||||
logger.info("Synced %s command(s)", len(synced))
|
||||
|
||||
await msg.edit(content="Done!")
|
||||
|
||||
|
||||
@bot.command()
|
||||
@commands.dm_only()
|
||||
@commands.check(owner_only)
|
||||
async def close(ctx: commands.Context):
|
||||
"""Handle restart command"""
|
||||
await ctx.reply("Closing bot...", ephemeral=True)
|
||||
logger.info("Closing down the bot")
|
||||
await bot.close()
|
||||
|
||||
|
||||
# @bot.tree.command(description="Sign up as a matchee in this server")
|
||||
# @commands.guild_only()
|
||||
# async def join(interaction: discord.Interaction):
|
||||
# # TODO: Sign up
|
||||
# await interaction.response.send_message(
|
||||
# f"Awesome, great to have you on board {interaction.user.mention}!", ephemeral=True)
|
||||
|
||||
|
||||
# @bot.tree.command(description="Leave the matchee list in this server")
|
||||
# @commands.guild_only()
|
||||
# async def leave(interaction: discord.Interaction):
|
||||
# # TODO: Remove the user
|
||||
# await interaction.response.send_message(
|
||||
# f"No worries, see you soon {interaction.user.mention}!", ephemeral=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="Match up matchees")
|
||||
@commands.guild_only()
|
||||
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)",
|
||||
matchee_role="Role for matchees (defaults to @Matchee)")
|
||||
async def match(interaction: discord.Interaction, members_min: int = None, matchee_role: str = None):
|
||||
"""Match groups of channel members"""
|
||||
|
||||
logger.info("Handling request '/match group_min=%s matchee_role=%s'",
|
||||
members_min, matchee_role)
|
||||
logger.info("User %s from %s in #%s", interaction.user,
|
||||
interaction.guild.name, interaction.channel.name)
|
||||
|
||||
# Sort out the defaults, if not specified they'll come in as None
|
||||
if not members_min:
|
||||
members_min = 3
|
||||
if not matchee_role:
|
||||
matchee_role = "Matchee"
|
||||
|
||||
# Grab the roles and verify the given role
|
||||
matcher = matching.get_role_from_guild(interaction.guild, "Matcher")
|
||||
matcher = matcher and matcher in interaction.user.roles
|
||||
matchee = matching.get_role_from_guild(interaction.guild, matchee_role)
|
||||
if not matchee:
|
||||
await interaction.response.send_message(f"Server is missing '{matchee_role}' role :(", ephemeral=True)
|
||||
return
|
||||
|
||||
# Create some example groups to show the user
|
||||
matchees = list(
|
||||
m for m in interaction.channel.members if matchee in m.roles)
|
||||
groups = matching.members_to_groups(
|
||||
matchees, History, members_min, allow_fallback=True)
|
||||
|
||||
# Post about all the groups with a button to send to the channel
|
||||
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
|
||||
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
|
||||
view = discord.utils.MISSING
|
||||
|
||||
if not matcher:
|
||||
# Let a non-matcher know why they don't have the button
|
||||
msg += "\n\nYou'll need the 'Matcher' role to post this to the channel, sorry!"
|
||||
else:
|
||||
# Otherwise set up the button
|
||||
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
|
||||
view = discord.ui.View(timeout=None)
|
||||
view.add_item(DynamicGroupButton(members_min, matchee_role))
|
||||
|
||||
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
|
||||
|
||||
logger.info("Done.")
|
||||
|
||||
|
||||
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
||||
template=r'match:min:(?P<min>[0-9]+):role:(?P<role>[@\w\s]+)'):
|
||||
def __init__(self, min: int, role: str) -> None:
|
||||
super().__init__(
|
||||
discord.ui.Button(
|
||||
label='Match Groups!',
|
||||
style=discord.ButtonStyle.blurple,
|
||||
custom_id=f'match:min:{min}:role:{role}',
|
||||
)
|
||||
)
|
||||
self.min: int = min
|
||||
self.role: int = role
|
||||
|
||||
# This is called when the button is clicked and the custom_id matches the template.
|
||||
@classmethod
|
||||
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
|
||||
min = int(match['min'])
|
||||
role = str(match['role'])
|
||||
return cls(min, role)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
"""Match up people when the button is pressed"""
|
||||
|
||||
logger.info("Handling button press min=%s role=%s'",
|
||||
self.min, self.role)
|
||||
logger.info("User %s from %s in #%s", interaction.user,
|
||||
interaction.guild.name, interaction.channel.name)
|
||||
|
||||
# Let the user know we've recieved the message
|
||||
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
|
||||
|
||||
# Grab the role
|
||||
matchee = matching.get_role_from_guild(interaction.guild, self.role)
|
||||
|
||||
# Create our groups!
|
||||
matchees = list(
|
||||
m for m in interaction.channel.members if matchee in m.roles)
|
||||
groups = matching.members_to_groups(
|
||||
matchees, History, self.min, allow_fallback=True)
|
||||
|
||||
# Send the groups
|
||||
for msg in (matching.group_to_message(g) for g in groups):
|
||||
await interaction.channel.send(msg)
|
||||
|
||||
# Close off with a message
|
||||
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
|
||||
|
||||
# Save the groups to the history
|
||||
History.save_groups_to_history(groups)
|
||||
|
||||
logger.info("Done. Matched %s matchees into %s groups.",
|
||||
len(matchees), len(groups))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
handler = logging.StreamHandler()
|
||||
bot.run(Config.token, log_handler=handler, root_logger=True)
|
135
py/config.py
Normal file
135
py/config.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
"""Very simple config loading library"""
|
||||
from schema import Schema, And, Use, Optional
|
||||
import files
|
||||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("config")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
_FILE = "config.json"
|
||||
|
||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||
_VERSION = 1
|
||||
|
||||
|
||||
class _Key():
|
||||
TOKEN = "token"
|
||||
VERSION = "version"
|
||||
|
||||
MATCH = "match"
|
||||
|
||||
SCORE_FACTORS = "score_factors"
|
||||
REPEAT_ROLE = "repeat_role"
|
||||
REPEAT_MATCH = "repeat_match"
|
||||
EXTRA_MEMBER = "extra_member"
|
||||
UPPER_THRESHOLD = "upper_threshold"
|
||||
|
||||
# Removed
|
||||
_OWNERS = "owners"
|
||||
|
||||
|
||||
_SCHEMA = Schema(
|
||||
{
|
||||
# The current version
|
||||
_Key.VERSION: And(Use(int)),
|
||||
|
||||
# Discord bot token
|
||||
_Key.TOKEN: And(Use(str)),
|
||||
|
||||
# Settings for the match algorithmn, see matching.py for explanations on usage
|
||||
Optional(_Key.MATCH): {
|
||||
Optional(_Key.SCORE_FACTORS): {
|
||||
|
||||
Optional(_Key.REPEAT_ROLE): And(Use(int)),
|
||||
Optional(_Key.REPEAT_MATCH): And(Use(int)),
|
||||
Optional(_Key.EXTRA_MEMBER): And(Use(int)),
|
||||
Optional(_Key.UPPER_THRESHOLD): And(Use(int)),
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
_EMPTY_DICT = {
|
||||
_Key.TOKEN: "",
|
||||
_Key.VERSION: _VERSION
|
||||
}
|
||||
|
||||
|
||||
def _migrate_to_v1(d: dict):
|
||||
# Owners moved to History in v1
|
||||
# Note: owners will be required to be re-added to the state.json
|
||||
owners = d.pop(_Key._OWNERS)
|
||||
logger.warn(
|
||||
"Migration removed owners from config, these must be re-added to the state.json")
|
||||
logger.warn("Owners: %s", owners)
|
||||
|
||||
|
||||
# Set of migration functions to apply
|
||||
_MIGRATIONS = [
|
||||
_migrate_to_v1
|
||||
]
|
||||
|
||||
|
||||
class _ScoreFactors():
|
||||
def __init__(self, data: dict):
|
||||
"""Initialise and validate the config"""
|
||||
self._dict = data
|
||||
|
||||
@property
|
||||
def repeat_role(self) -> int:
|
||||
return self._dict.get(_Key.REPEAT_ROLE, None)
|
||||
|
||||
@property
|
||||
def repeat_match(self) -> int:
|
||||
return self._dict.get(_Key.REPEAT_MATCH, None)
|
||||
|
||||
@property
|
||||
def extra_member(self) -> int:
|
||||
return self._dict.get(_Key.EXTRA_MEMBER, None)
|
||||
|
||||
@property
|
||||
def upper_threshold(self) -> int:
|
||||
return self._dict.get(_Key.UPPER_THRESHOLD, None)
|
||||
|
||||
|
||||
class _Config():
|
||||
def __init__(self, data: dict):
|
||||
"""Initialise and validate the config"""
|
||||
_SCHEMA.validate(data)
|
||||
self._dict = data
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return self._dict["token"]
|
||||
|
||||
@property
|
||||
def score_factors(self) -> _ScoreFactors:
|
||||
return _ScoreFactors(self._dict.get(_Key.SCORE_FACTORS, {}))
|
||||
|
||||
|
||||
def _migrate(dict: dict):
|
||||
"""Migrate a dict through versions"""
|
||||
version = dict.get("version", 0)
|
||||
for i in range(version, _VERSION):
|
||||
_MIGRATIONS[i](dict)
|
||||
dict["version"] = _VERSION
|
||||
|
||||
|
||||
def _load_from_file(file: str = _FILE) -> _Config:
|
||||
"""
|
||||
Load the state from a file
|
||||
Apply any required migrations
|
||||
"""
|
||||
loaded = _EMPTY_DICT
|
||||
if os.path.isfile(file):
|
||||
loaded = files.load(file)
|
||||
_migrate(loaded)
|
||||
else:
|
||||
logger.warn("No %s file found, bot cannot run!", file)
|
||||
return _Config(loaded)
|
||||
|
||||
|
||||
# Core config for users to use
|
||||
# Singleton as there should only be one, and it's global
|
||||
Config = _load_from_file()
|
|
@ -1,25 +1,27 @@
|
|||
"""Utility functions for matchy"""
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Protocol, runtime_checkable
|
||||
import history
|
||||
import state
|
||||
import config
|
||||
|
||||
|
||||
# Number of days to step forward from the start of history for each match attempt
|
||||
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
|
||||
class _ScoreFactors(int):
|
||||
"""
|
||||
Score factors used when trying to build up "best fit" groups
|
||||
Matchees are sequentially placed into the lowest scoring available group
|
||||
"""
|
||||
|
||||
# Attempts for each of those time periods
|
||||
_ATTEMPTS_PER_TIMESTEP = 3
|
||||
# Added for each role the matchee has that another group member has
|
||||
REPEAT_ROLE = config.Config.score_factors.repeat_role or 2**2
|
||||
# Added for each member in the group that the matchee has already matched with
|
||||
REPEAT_MATCH = config.Config.score_factors.repeat_match or 2**3
|
||||
# Added for each additional member over the set "per group" value
|
||||
EXTRA_MEMBER = config.Config.score_factors.extra_member or 2**5
|
||||
|
||||
# Various eligability scoring factors for group meetups
|
||||
_SCORE_CURRENT_MEMBERS = 2**1
|
||||
_SCORE_REPEAT_ROLE = 2**2
|
||||
_SCORE_REPEAT_MATCH = 2**3
|
||||
_SCORE_EXTRA_MEMBERS = 2**4
|
||||
# Upper threshold, if the user scores higher than this they will not be placed in that group
|
||||
UPPER_THRESHOLD = config.Config.score_factors.upper_threshold or 2**6
|
||||
|
||||
# Scores higher than this are fully rejected
|
||||
_SCORE_UPPER_THRESHOLD = 2**6
|
||||
|
||||
logger = logging.getLogger("matching")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -69,33 +71,42 @@ def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bo
|
|||
|
||||
def get_member_group_eligibility_score(member: Member,
|
||||
group: list[Member],
|
||||
relevant_matches: list[int],
|
||||
per_group: int) -> int:
|
||||
prior_matches: list[int],
|
||||
per_group: int) -> float:
|
||||
"""Rates a member against a group"""
|
||||
rating = len(group) * _SCORE_CURRENT_MEMBERS
|
||||
# An empty group is a "perfect" score atomatically
|
||||
rating = 0
|
||||
if not group:
|
||||
return rating
|
||||
|
||||
repeat_meetings = sum(m.id in relevant_matches for m in group)
|
||||
rating += repeat_meetings * _SCORE_REPEAT_MATCH
|
||||
# Add score based on prior matchups of this user
|
||||
num_prior = sum(m.id in prior_matches for m in group)
|
||||
rating += num_prior * _ScoreFactors.REPEAT_MATCH
|
||||
|
||||
repeat_roles = sum(r in member.roles for r in (m.roles for m in group))
|
||||
rating += (repeat_roles * _SCORE_REPEAT_ROLE)
|
||||
# Calculate the number of roles that match
|
||||
all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr)
|
||||
member_role_ids = [r.id for r in member.roles]
|
||||
repeat_roles = sum(id in member_role_ids for id in all_role_ids)
|
||||
rating += repeat_roles * _ScoreFactors.REPEAT_ROLE
|
||||
|
||||
extra_members = len(group) - per_group
|
||||
if extra_members > 0:
|
||||
rating += extra_members * _SCORE_EXTRA_MEMBERS
|
||||
# Add score based on the number of extra members
|
||||
# Calculate the member offset (+1 for this user)
|
||||
extra_members = (len(group) - per_group) + 1
|
||||
if extra_members >= 0:
|
||||
rating += extra_members * _ScoreFactors.EXTRA_MEMBER
|
||||
|
||||
return rating
|
||||
|
||||
|
||||
def attempt_create_groups(matchees: list[Member],
|
||||
hist: history.History,
|
||||
current_state: state.State,
|
||||
oldest_relevant_ts: datetime,
|
||||
per_group: int) -> tuple[bool, list[list[Member]]]:
|
||||
"""History aware group matching"""
|
||||
num_groups = max(len(matchees)//per_group, 1)
|
||||
|
||||
# Set up the groups in place
|
||||
groups = list([] for _ in range(num_groups))
|
||||
groups = [[] for _ in range(num_groups)]
|
||||
|
||||
matchees_left = matchees.copy()
|
||||
|
||||
|
@ -103,21 +114,21 @@ def attempt_create_groups(matchees: list[Member],
|
|||
while matchees_left:
|
||||
# Get the next matchee to place
|
||||
matchee = matchees_left.pop()
|
||||
matchee_matches = hist.matchees.get(
|
||||
str(matchee.id), {}).get("matches", {})
|
||||
relevant_matches = list(int(id) for id, ts in matchee_matches.items()
|
||||
if history.ts_to_datetime(ts) >= oldest_relevant_ts)
|
||||
matchee_matches = current_state.get_user_matches(matchee.id)
|
||||
relevant_matches = [int(id) for id, ts
|
||||
in matchee_matches.items()
|
||||
if state.ts_to_datetime(ts) >= oldest_relevant_ts]
|
||||
|
||||
# Try every single group from the current group onwards
|
||||
# Progressing through the groups like this ensures we slowly fill them up with compatible people
|
||||
scores: list[tuple[int, int]] = []
|
||||
scores: list[tuple[int, float]] = []
|
||||
for group in groups:
|
||||
|
||||
score = get_member_group_eligibility_score(
|
||||
matchee, group, relevant_matches, num_groups)
|
||||
matchee, group, relevant_matches, per_group)
|
||||
|
||||
# If the score isn't too high, consider this group
|
||||
if score <= _SCORE_UPPER_THRESHOLD:
|
||||
if score <= _ScoreFactors.UPPER_THRESHOLD:
|
||||
scores.append((group, score))
|
||||
|
||||
# Optimisation:
|
||||
|
@ -143,31 +154,38 @@ def datetime_range(start_time: datetime, increment: timedelta, end: datetime):
|
|||
current += increment
|
||||
|
||||
|
||||
def iterate_all_shifts(list: list):
|
||||
"""Yields each shifted variation of the input list"""
|
||||
yield list
|
||||
for _ in range(len(list)-1):
|
||||
list = list[1:] + [list[0]]
|
||||
yield list
|
||||
|
||||
|
||||
def members_to_groups(matchees: list[Member],
|
||||
hist: history.History = history.History(),
|
||||
st: state.State = state.State(),
|
||||
per_group: int = 3,
|
||||
allow_fallback: bool = False) -> list[list[Member]]:
|
||||
"""Generate the groups from the set of matchees"""
|
||||
attempts = 0 # Tracking for logging purposes
|
||||
rand = random.Random(117) # Some stable randomness
|
||||
num_groups = len(matchees)//per_group
|
||||
|
||||
# Grab the oldest timestamp
|
||||
history_start = hist.oldest() or datetime.now()
|
||||
# Bail early if there's no-one to match
|
||||
if not matchees:
|
||||
return []
|
||||
|
||||
# Walk from the start of time until now using the timestep increment
|
||||
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
|
||||
# Walk from the start of history until now trying to match up groups
|
||||
for oldest_relevant_datetime in st.get_history_timestamps() + [datetime.now()]:
|
||||
|
||||
# Have a few attempts before stepping forward in time
|
||||
for _ in range(_ATTEMPTS_PER_TIMESTEP):
|
||||
|
||||
rand.shuffle(matchees) # Shuffle the matchees each attempt
|
||||
# Attempt with each starting matchee
|
||||
for shifted_matchees in iterate_all_shifts(matchees):
|
||||
|
||||
attempts += 1
|
||||
groups = attempt_create_groups(
|
||||
matchees, hist, oldest_relevant_datetime, per_group)
|
||||
shifted_matchees, st, oldest_relevant_datetime, per_group)
|
||||
|
||||
# Fail the match if our groups aren't big enough
|
||||
if (len(matchees)//per_group) <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
||||
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
||||
logger.info("Matched groups after %s attempt(s)", attempts)
|
||||
return groups
|
||||
|
||||
|
@ -176,6 +194,10 @@ def members_to_groups(matchees: list[Member],
|
|||
logger.info("Fell back to simple groups after %s attempt(s)", attempts)
|
||||
return members_to_groups_simple(matchees, per_group)
|
||||
|
||||
# Simply assert false, this should never happen
|
||||
# And should be caught by tests
|
||||
assert False
|
||||
|
||||
|
||||
def group_to_message(group: list[Member]) -> str:
|
||||
"""Get the message to send for each group"""
|
||||
|
@ -185,8 +207,3 @@ def group_to_message(group: list[Member]) -> str:
|
|||
else:
|
||||
mentions = mentions[0]
|
||||
return f"Matched up {mentions}!"
|
||||
|
||||
|
||||
def get_role_from_guild(guild: Guild, role: str) -> Role:
|
||||
"""Find a role in a guild"""
|
||||
return next((r for r in guild.roles if r.name == role), None)
|
412
py/matching_test.py
Normal file
412
py/matching_test.py
Normal file
|
@ -0,0 +1,412 @@
|
|||
"""
|
||||
Test functions for the matching module
|
||||
"""
|
||||
import discord
|
||||
import pytest
|
||||
import random
|
||||
import matching
|
||||
import state
|
||||
import copy
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def test_protocols():
|
||||
"""Verify the protocols we're using match the discord ones"""
|
||||
assert isinstance(discord.Member, matching.Member)
|
||||
assert isinstance(discord.Guild, matching.Guild)
|
||||
assert isinstance(discord.Role, matching.Role)
|
||||
assert isinstance(Member, matching.Member)
|
||||
# assert isinstance(Role, matching.Role)
|
||||
|
||||
|
||||
class Role():
|
||||
def __init__(self, id: int):
|
||||
self._id = id
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
|
||||
class Member():
|
||||
def __init__(self, id: int, roles: list[Role] = []):
|
||||
self._id = id
|
||||
self._roles = roles
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
return f"<@{self._id}>"
|
||||
|
||||
@property
|
||||
def roles(self) -> list[Role]:
|
||||
return self._roles
|
||||
|
||||
@roles.setter
|
||||
def roles(self, roles: list[Role]):
|
||||
self._roles = roles
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
|
||||
def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, per_group: int):
|
||||
"""Inner function to validate the main output of the groups function"""
|
||||
groups = matching.members_to_groups(matchees, tmp_state, per_group)
|
||||
|
||||
# We should always have one group
|
||||
assert len(groups)
|
||||
|
||||
# Log the groups to history
|
||||
# This will validate the internals
|
||||
tmp_state.log_groups(groups)
|
||||
|
||||
# Ensure each group contains within the bounds of expected members
|
||||
for group in groups:
|
||||
if len(matchees) >= per_group:
|
||||
assert len(group) >= per_group
|
||||
else:
|
||||
assert len(group) == len(matchees)
|
||||
assert len(group) < per_group*2 # TODO: We could be more strict here
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
@pytest.mark.parametrize("matchees, per_group", [
|
||||
# Simplest test possible
|
||||
([Member(1)], 1),
|
||||
|
||||
# More requested than we have
|
||||
([Member(1)], 2),
|
||||
|
||||
# A selection of hyper-simple checks to validate core functionality
|
||||
([Member(1)] * 100, 3),
|
||||
([Member(1)] * 12, 5),
|
||||
([Member(1)] * 11, 2),
|
||||
([Member(1)] * 356, 8),
|
||||
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
|
||||
def test_members_to_groups_no_history(matchees, per_group):
|
||||
"""Test simple group matching works"""
|
||||
tmp_state = state.State()
|
||||
members_to_groups_validate(matchees, tmp_state, per_group)
|
||||
|
||||
|
||||
def items_found_in_lists(list_of_lists, items):
|
||||
"""validates if any sets of items are found in individual lists"""
|
||||
for sublist in list_of_lists:
|
||||
if all(item in sublist for item in items):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("history_data, matchees, per_group, checks", [
|
||||
# Slightly more difficult test
|
||||
(
|
||||
# Describe a history where we previously matched up some people and ensure they don't get rematched
|
||||
[
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=1),
|
||||
"groups": [
|
||||
[Member(1), Member(2)],
|
||||
[Member(3), Member(4)],
|
||||
]
|
||||
}
|
||||
],
|
||||
[
|
||||
Member(1),
|
||||
Member(2),
|
||||
Member(3),
|
||||
Member(4),
|
||||
],
|
||||
2,
|
||||
[
|
||||
lambda groups: not items_found_in_lists(
|
||||
groups, [Member(1), Member(2)]),
|
||||
lambda groups: not items_found_in_lists(
|
||||
groups, [Member(3), Member(4)])
|
||||
]
|
||||
),
|
||||
# Feed the system an "impossible" test
|
||||
# The function should fall back to ignoring history and still give us something
|
||||
(
|
||||
[
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=1),
|
||||
"groups": [
|
||||
[
|
||||
Member(1),
|
||||
Member(2),
|
||||
Member(3)
|
||||
],
|
||||
[
|
||||
Member(4),
|
||||
Member(5),
|
||||
Member(6)
|
||||
],
|
||||
]
|
||||
}
|
||||
],
|
||||
[
|
||||
Member(1, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(2, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(3, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(4, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(5, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
Member(6, [Role(1), Role(2), Role(3), Role(4)]),
|
||||
],
|
||||
3,
|
||||
[
|
||||
# Nothing specific to validate
|
||||
]
|
||||
),
|
||||
# Specific test pulled out of the stress test
|
||||
(
|
||||
[
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=4),
|
||||
"groups": [
|
||||
[Member(i) for i in [1, 2, 3, 4, 5, 6,
|
||||
7, 8, 9, 10, 11, 12, 13, 14, 15]]
|
||||
]
|
||||
},
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=5),
|
||||
"groups": [
|
||||
[Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
|
||||
]
|
||||
}
|
||||
],
|
||||
[Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]],
|
||||
3,
|
||||
[
|
||||
# Nothing specific to validate
|
||||
]
|
||||
),
|
||||
# Silly example that failued due to bad role logic
|
||||
(
|
||||
[
|
||||
# No history
|
||||
],
|
||||
[
|
||||
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
|
||||
Member(i, [Role(r) for r in roles]) for (i, roles) in
|
||||
[
|
||||
(4, [1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
(8, [1]),
|
||||
(9, [1, 2, 3, 4, 5]),
|
||||
(6, [1, 2, 3]),
|
||||
(11, [1, 2, 3]),
|
||||
(7, [1, 2, 3, 4, 5, 6, 7]),
|
||||
(1, [1, 2, 3, 4]),
|
||||
(5, [1, 2, 3, 4, 5]),
|
||||
(12, [1, 2, 3, 4]),
|
||||
(10, [1]),
|
||||
(13, [1, 2, 3, 4, 5, 6]),
|
||||
(2, [1, 2, 3, 4, 5, 6]),
|
||||
(3, [1, 2, 3, 4, 5, 6, 7])
|
||||
]
|
||||
],
|
||||
2,
|
||||
[
|
||||
# Nothing else
|
||||
]
|
||||
),
|
||||
# Another weird one pulled out of the stress test
|
||||
(
|
||||
[
|
||||
# print([(str(h["ts"]), [[f"Member({gm.id})" for gm in g] for g in h["groups"]]) for h in history_data])
|
||||
{"ts": datetime.strptime(ts, r"%Y-%m-%d %H:%M:%S.%f"), "groups": [
|
||||
[Member(m) for m in group] for group in groups]}
|
||||
for (ts, groups) in
|
||||
[
|
||||
(
|
||||
'2024-07-07 20:25:56.313993',
|
||||
[
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
[1],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8]
|
||||
]
|
||||
),
|
||||
(
|
||||
'2024-07-13 20:25:56.313993',
|
||||
[
|
||||
[1, 2],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
||||
[1]
|
||||
]
|
||||
),
|
||||
(
|
||||
'2024-06-29 20:25:56.313993',
|
||||
[
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||
11, 12, 13, 14, 15, 16, 17],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20]
|
||||
]
|
||||
),
|
||||
(
|
||||
'2024-06-25 20:25:56.313993',
|
||||
[
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18],
|
||||
[1, 2],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
[1, 2]
|
||||
]
|
||||
),
|
||||
(
|
||||
'2024-07-04 20:25:56.313993',
|
||||
[
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
[1, 2, 3, 4, 5, 6, 7]
|
||||
]
|
||||
),
|
||||
(
|
||||
'2024-07-16 20:25:56.313993',
|
||||
[
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20],
|
||||
[1, 2, 3, 4, 5, 6],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
|
||||
11, 12, 13, 14, 15, 16, 17, 18]
|
||||
]
|
||||
)
|
||||
]
|
||||
],
|
||||
[
|
||||
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
|
||||
Member(i, [Role(r) for r in roles]) for (i, roles) in
|
||||
[
|
||||
(10, [1, 2, 3]),
|
||||
(4, [1, 2, 3]),
|
||||
(5, [1, 2]),
|
||||
(13, [1, 2]),
|
||||
(3, [1, 2, 3, 4]),
|
||||
(14, [1]),
|
||||
(6, [1, 2, 3, 4]),
|
||||
(11, [1]),
|
||||
(9, [1]),
|
||||
(1, [1, 2, 3]),
|
||||
(16, [1, 2]),
|
||||
(15, [1, 2]),
|
||||
(2, [1, 2, 3]),
|
||||
(7, [1, 2, 3]),
|
||||
(12, [1, 2]),
|
||||
(8, [1, 2, 3, 4])
|
||||
]
|
||||
],
|
||||
5,
|
||||
[
|
||||
# Nothing
|
||||
]
|
||||
)
|
||||
], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3'])
|
||||
def test_unique_regressions(history_data, matchees, per_group, checks):
|
||||
"""Test a bunch of unqiue failures that happened in the past"""
|
||||
tmp_state = state.State()
|
||||
|
||||
# Replay the history
|
||||
for d in history_data:
|
||||
tmp_state.log_groups(d["groups"], d["ts"])
|
||||
|
||||
groups = members_to_groups_validate(matchees, tmp_state, per_group)
|
||||
|
||||
# Run the custom validate functions
|
||||
for check in checks:
|
||||
assert check(groups)
|
||||
|
||||
|
||||
def random_chunk(li, min_chunk, max_chunk, rand):
|
||||
"""
|
||||
"Borrowed" from https://stackoverflow.com/questions/21439011/best-way-to-split-a-list-into-randomly-sized-chunks
|
||||
"""
|
||||
it = iter(li)
|
||||
while True:
|
||||
nxt = list(itertools.islice(it, rand.randint(min_chunk, max_chunk)))
|
||||
if nxt:
|
||||
yield nxt
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
# Generate a large set of "interesting" tests that replay a fake history onto random people
|
||||
# Increase these numbers for some extreme programming
|
||||
@pytest.mark.parametrize("per_group, num_members, num_history", (
|
||||
(per_group, num_members, num_history)
|
||||
# Most of the time groups are gonna be from 2 to 5
|
||||
for per_group in range(2, 5)
|
||||
# Going lower than 8 members doesn't give the bot much of a chance
|
||||
# And it will fail to not fall back sometimes
|
||||
# That's probably OK frankly
|
||||
for num_members in range(8, 32, 5)
|
||||
# Throw up to 7 histories at the algorithmn
|
||||
for num_history in range(0, 8)))
|
||||
def test_stess_random_groups(per_group, num_members, num_history):
|
||||
"""Run a randomised test based on the input"""
|
||||
|
||||
# Seed the random based on the inputs paird with primes
|
||||
# Ensures the test has interesting fake data, but is stable
|
||||
rand = random.Random(per_group*3 + num_members*5 + num_history*7)
|
||||
|
||||
# Start with a list of all possible members
|
||||
possible_members = [Member(i) for i in range(num_members*2)]
|
||||
for member in possible_members:
|
||||
# Give each member 3 random roles from 1-7
|
||||
member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)]
|
||||
|
||||
# For each history item match up groups and log those
|
||||
cumulative_state = state.State()
|
||||
for i in range(num_history+1):
|
||||
|
||||
# Grab the num of members and replay
|
||||
rand.shuffle(possible_members)
|
||||
members = copy.deepcopy(possible_members[:num_members])
|
||||
|
||||
groups = members_to_groups_validate(
|
||||
members, cumulative_state, per_group)
|
||||
cumulative_state.log_groups(
|
||||
groups, datetime.now() - timedelta(days=num_history-i))
|
||||
|
||||
|
||||
def test_auth_scopes():
|
||||
tmp_state = state.State()
|
||||
|
||||
id = "1"
|
||||
tmp_state.set_user_scope(id, state.AuthScope.OWNER)
|
||||
assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
|
||||
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||
|
||||
id = "2"
|
||||
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
|
||||
assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
|
||||
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||
|
||||
tmp_state.validate()
|
||||
|
||||
|
||||
def test_iterate_all_shifts():
|
||||
original = [1, 2, 3, 4]
|
||||
lists = [val for val in matching.iterate_all_shifts(original)]
|
||||
assert lists == [
|
||||
[1, 2, 3, 4],
|
||||
[2, 3, 4, 1],
|
||||
[3, 4, 1, 2],
|
||||
[4, 1, 2, 3],
|
||||
]
|
230
py/matchy.py
Executable file
230
py/matchy.py
Executable file
|
@ -0,0 +1,230 @@
|
|||
"""
|
||||
matchy.py - Discord bot that matches people into groups
|
||||
"""
|
||||
import logging
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
import matching
|
||||
import state
|
||||
import config
|
||||
import re
|
||||
|
||||
|
||||
STATE_FILE = "state.json"
|
||||
|
||||
State = state.load_from_file(STATE_FILE)
|
||||
|
||||
logger = logging.getLogger("matchy")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.members = True
|
||||
bot = commands.Bot(command_prefix='$',
|
||||
description="Matchy matches matchees", intents=intents)
|
||||
|
||||
|
||||
@bot.event
|
||||
async def setup_hook():
|
||||
bot.add_dynamic_items(DynamicGroupButton)
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
"""Bot is ready and connected"""
|
||||
logger.info("Bot is up and ready!")
|
||||
activity = discord.Game("/join")
|
||||
await bot.change_presence(status=discord.Status.online, activity=activity)
|
||||
|
||||
|
||||
def owner_only(ctx: commands.Context) -> bool:
|
||||
"""Checks the author is an owner"""
|
||||
return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER)
|
||||
|
||||
|
||||
@bot.command()
|
||||
@commands.dm_only()
|
||||
@commands.check(owner_only)
|
||||
async def sync(ctx: commands.Context):
|
||||
"""Handle sync command"""
|
||||
msg = await ctx.reply("Reloading state...", ephemeral=True)
|
||||
global State
|
||||
State = state.load_from_file(STATE_FILE)
|
||||
logger.info("Reloaded state")
|
||||
|
||||
await msg.edit(content="Syncing commands...")
|
||||
synced = await bot.tree.sync()
|
||||
logger.info("Synced %s command(s)", len(synced))
|
||||
|
||||
await msg.edit(content="Done!")
|
||||
|
||||
|
||||
@bot.command()
|
||||
@commands.dm_only()
|
||||
@commands.check(owner_only)
|
||||
async def close(ctx: commands.Context):
|
||||
"""Handle restart command"""
|
||||
await ctx.reply("Closing bot...", ephemeral=True)
|
||||
logger.info("Closing down the bot")
|
||||
await bot.close()
|
||||
|
||||
|
||||
@bot.tree.command(description="Join the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def join(interaction: discord.Interaction):
|
||||
State.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
f"Roger roger {interaction.user.mention}!\n"
|
||||
+ f"Added you to {interaction.channel.mention}!",
|
||||
ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="Leave the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def leave(interaction: discord.Interaction):
|
||||
State.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id, False)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="Pause your matching in this channel for a number of days")
|
||||
@commands.guild_only()
|
||||
@app_commands.describe(days="Days to pause for (defaults to 7)")
|
||||
async def pause(interaction: discord.Interaction, days: int = None):
|
||||
if not days: # Default to a week
|
||||
days = 7
|
||||
State.set_user_paused_in_channel(
|
||||
interaction.user.id, interaction.channel.id, days)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
f"Sure thing {interaction.user.mention}. Paused you for {days} days!", ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="List the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def list(interaction: discord.Interaction):
|
||||
matchees = get_matchees_in_channel(interaction.channel)
|
||||
mentions = [m.mention for m in matchees]
|
||||
msg = "Current matchees in this channel:\n" + \
|
||||
f"{', '.join(mentions[:-1])} and {mentions[-1]}"
|
||||
await interaction.response.send_message(msg, ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="Match up matchees")
|
||||
@commands.guild_only()
|
||||
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)")
|
||||
async def match(interaction: discord.Interaction, members_min: int = None):
|
||||
"""Match groups of channel members"""
|
||||
|
||||
logger.info("Handling request '/match group_min=%s", members_min)
|
||||
logger.info("User %s from %s in #%s", interaction.user,
|
||||
interaction.guild.name, interaction.channel.name)
|
||||
|
||||
# Sort out the defaults, if not specified they'll come in as None
|
||||
if not members_min:
|
||||
members_min = 3
|
||||
|
||||
# Grab the groups
|
||||
groups = active_members_to_groups(interaction.channel, members_min)
|
||||
|
||||
# Let the user know when there's nobody to match
|
||||
if not groups:
|
||||
await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True)
|
||||
return
|
||||
|
||||
# Post about all the groups with a button to send to the channel
|
||||
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
|
||||
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
|
||||
view = discord.utils.MISSING
|
||||
|
||||
if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER):
|
||||
# Otherwise set up the button
|
||||
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
|
||||
view = discord.ui.View(timeout=None)
|
||||
view.add_item(DynamicGroupButton(members_min))
|
||||
else:
|
||||
# Let a non-matcher know why they don't have the button
|
||||
msg += f"\n\nYou'll need the {state.AuthScope.MATCHER}"
|
||||
+ " scope to post this to the channel, sorry!"
|
||||
|
||||
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
|
||||
|
||||
logger.info("Done.")
|
||||
|
||||
|
||||
# Increment when adjusting the custom_id so we don't confuse old users
|
||||
_MATCH_BUTTON_CUSTOM_ID_VERSION = 1
|
||||
_MATCH_BUTTON_CUSTOM_ID_PREFIX = f'match:v{_MATCH_BUTTON_CUSTOM_ID_VERSION}:'
|
||||
|
||||
|
||||
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
||||
template=_MATCH_BUTTON_CUSTOM_ID_PREFIX + r'min:(?P<min>[0-9]+)'):
|
||||
def __init__(self, min: int) -> None:
|
||||
super().__init__(
|
||||
discord.ui.Button(
|
||||
label='Match Groups!',
|
||||
style=discord.ButtonStyle.blurple,
|
||||
custom_id=_MATCH_BUTTON_CUSTOM_ID_PREFIX + f'min:{min}',
|
||||
)
|
||||
)
|
||||
self.min: int = min
|
||||
|
||||
# This is called when the button is clicked and the custom_id matches the template.
|
||||
@classmethod
|
||||
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
|
||||
min = int(match['min'])
|
||||
return cls(min)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
"""Match up people when the button is pressed"""
|
||||
|
||||
logger.info("Handling button press min=%s", self.min)
|
||||
logger.info("User %s from %s in #%s", interaction.user,
|
||||
interaction.guild.name, interaction.channel.name)
|
||||
|
||||
# Let the user know we've recieved the message
|
||||
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
|
||||
|
||||
groups = active_members_to_groups(interaction.channel, self.min)
|
||||
|
||||
# Send the groups
|
||||
for msg in (matching.group_to_message(g) for g in groups):
|
||||
await interaction.channel.send(msg)
|
||||
|
||||
# Close off with a message
|
||||
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
|
||||
|
||||
# Save the groups to the history
|
||||
State.log_groups(groups)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
|
||||
logger.info("Done! Matched into %s groups.", len(groups))
|
||||
|
||||
|
||||
def get_matchees_in_channel(channel: discord.channel):
|
||||
"""Fetches the matchees in a channel"""
|
||||
# Reactivate any unpaused users
|
||||
State.reactivate_users(channel.id)
|
||||
|
||||
# Gather up the prospective matchees
|
||||
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
|
||||
|
||||
|
||||
def active_members_to_groups(channel: discord.channel, min_members: int):
|
||||
"""Helper to create groups from channel members"""
|
||||
|
||||
# Gather up the prospective matchees
|
||||
matchees = get_matchees_in_channel(channel)
|
||||
|
||||
# Create our groups!
|
||||
return matching.members_to_groups(matchees, State, min_members, allow_fallback=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
handler = logging.StreamHandler()
|
||||
bot.run(config.Config.token, log_handler=handler, root_logger=True)
|
330
py/state.py
Normal file
330
py/state.py
Normal file
|
@ -0,0 +1,330 @@
|
|||
"""Store bot state"""
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from schema import Schema, And, Use, Optional
|
||||
from typing import Protocol
|
||||
import files
|
||||
import copy
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger("state")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||
_VERSION = 2
|
||||
|
||||
|
||||
def _migrate_to_v1(d: dict):
|
||||
"""v1 simply renamed matchees to users"""
|
||||
logger.info("Renaming %s to %s", _Key._MATCHEES, _Key.USERS)
|
||||
d[_Key.USERS] = d[_Key._MATCHEES]
|
||||
del d[_Key._MATCHEES]
|
||||
|
||||
|
||||
def _migrate_to_v2(d: dict):
|
||||
"""v2 swapped the date over to a less silly format"""
|
||||
logger.info("Fixing up date format from %s to %s",
|
||||
_TIME_FORMAT_OLD, _TIME_FORMAT)
|
||||
|
||||
def old_to_new_ts(ts: str) -> str:
|
||||
return datetime.strftime(datetime.strptime(ts, _TIME_FORMAT_OLD), _TIME_FORMAT)
|
||||
|
||||
# Adjust all the history keys
|
||||
d[_Key.HISTORY] = {
|
||||
old_to_new_ts(ts): entry
|
||||
for ts, entry in d[_Key.HISTORY].items()
|
||||
}
|
||||
# Adjust all the user parts
|
||||
for user in d[_Key.USERS].values():
|
||||
# Update the match dates
|
||||
matches = user.get(_Key.MATCHES, {})
|
||||
for id, ts in matches.items():
|
||||
matches[id] = old_to_new_ts(ts)
|
||||
|
||||
# Update any reactivation dates
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
for id, channel in channels.items():
|
||||
old_ts = channel.get(_Key.REACTIVATE, None)
|
||||
if old_ts:
|
||||
channel[_Key.REACTIVATE] = old_to_new_ts(old_ts)
|
||||
|
||||
|
||||
# Set of migration functions to apply
|
||||
_MIGRATIONS = [
|
||||
_migrate_to_v1,
|
||||
_migrate_to_v2
|
||||
]
|
||||
|
||||
|
||||
class AuthScope(str):
|
||||
"""Various auth scopes"""
|
||||
OWNER = "owner"
|
||||
MATCHER = "matcher"
|
||||
|
||||
|
||||
class _Key(str):
|
||||
"""Various keys used in the schema"""
|
||||
HISTORY = "history"
|
||||
GROUPS = "groups"
|
||||
MEMBERS = "members"
|
||||
USERS = "users"
|
||||
SCOPES = "scopes"
|
||||
MATCHES = "matches"
|
||||
ACTIVE = "active"
|
||||
CHANNELS = "channels"
|
||||
REACTIVATE = "reactivate"
|
||||
VERSION = "version"
|
||||
|
||||
# Unused
|
||||
_MATCHEES = "matchees"
|
||||
|
||||
|
||||
_TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"
|
||||
_TIME_FORMAT_OLD = "%a %b %d %H:%M:%S %Y"
|
||||
|
||||
|
||||
_SCHEMA = Schema(
|
||||
{
|
||||
# The current version
|
||||
_Key.VERSION: And(Use(int)),
|
||||
|
||||
Optional(_Key.HISTORY): {
|
||||
# A datetime
|
||||
Optional(str): {
|
||||
_Key.GROUPS: [
|
||||
{
|
||||
_Key.MEMBERS: [
|
||||
# The ID of each matchee in the match
|
||||
And(Use(int))
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
Optional(_Key.USERS): {
|
||||
Optional(str): {
|
||||
Optional(_Key.SCOPES): And(Use(list[str])),
|
||||
Optional(_Key.MATCHES): {
|
||||
# Matchee ID and Datetime pair
|
||||
Optional(str): And(Use(str))
|
||||
},
|
||||
Optional(_Key.CHANNELS): {
|
||||
# The channel ID
|
||||
Optional(str): {
|
||||
# Whether the user is signed up in this channel
|
||||
_Key.ACTIVE: And(Use(bool)),
|
||||
# A timestamp for when to re-activate the user
|
||||
Optional(_Key.REACTIVATE): And(Use(str)),
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Empty but schema-valid internal dict
|
||||
_EMPTY_DICT = {
|
||||
_Key.HISTORY: {},
|
||||
_Key.USERS: {},
|
||||
_Key.VERSION: _VERSION
|
||||
}
|
||||
assert _SCHEMA.validate(_EMPTY_DICT)
|
||||
|
||||
|
||||
class Member(Protocol):
|
||||
@property
|
||||
def id(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
def ts_to_datetime(ts: str) -> datetime:
|
||||
"""Convert a string ts to datetime using the internal format"""
|
||||
return datetime.strptime(ts, _TIME_FORMAT)
|
||||
|
||||
|
||||
def datetime_to_ts(ts: datetime) -> str:
|
||||
"""Convert a datetime to a string ts using the internal format"""
|
||||
return datetime.strftime(ts, _TIME_FORMAT)
|
||||
|
||||
|
||||
class State():
|
||||
def __init__(self, data: dict = _EMPTY_DICT):
|
||||
"""Initialise and validate the state"""
|
||||
self.validate(data)
|
||||
self._dict = copy.deepcopy(data)
|
||||
|
||||
def validate(self, dict: dict = None):
|
||||
"""Initialise and validate a state dict"""
|
||||
if not dict:
|
||||
dict = self._dict
|
||||
_SCHEMA.validate(dict)
|
||||
|
||||
def get_history_timestamps(self) -> list[datetime]:
|
||||
"""Grab all timestamps in the history"""
|
||||
return sorted([ts_to_datetime(dt) for dt in self._history.keys()])
|
||||
|
||||
def get_user_matches(self, id: int) -> list[int]:
|
||||
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
|
||||
|
||||
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
|
||||
"""Log the groups"""
|
||||
ts = datetime_to_ts(ts or datetime.now())
|
||||
with self._safe_wrap() as safe_state:
|
||||
# Grab or create the hitory item for this set of groups
|
||||
history_item = {}
|
||||
safe_state._history[ts] = history_item
|
||||
history_item_groups = []
|
||||
history_item[_Key.GROUPS] = history_item_groups
|
||||
|
||||
for group in groups:
|
||||
|
||||
# Add the group data
|
||||
history_item_groups.append({
|
||||
_Key.MEMBERS: [m.id for m in group]
|
||||
})
|
||||
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
matchee = safe_state._users.get(str(m.id), {})
|
||||
matchee_matches = matchee.get(_Key.MATCHES, {})
|
||||
|
||||
for o in (o for o in group if o.id != m.id):
|
||||
matchee_matches[str(o.id)] = ts
|
||||
|
||||
matchee[_Key.MATCHES] = matchee_matches
|
||||
safe_state._users[str(m.id)] = matchee
|
||||
|
||||
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||
"""Add an auth scope to a user"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.get(str(id), {})
|
||||
scopes = user.get(_Key.SCOPES, [])
|
||||
|
||||
# Set the value
|
||||
if value and scope not in scopes:
|
||||
scopes.append(scope)
|
||||
elif not value and scope in scopes:
|
||||
scopes.remove(scope)
|
||||
|
||||
# Roll out
|
||||
user[_Key.SCOPES] = scopes
|
||||
safe_state._users[str(id)] = user
|
||||
|
||||
def get_user_has_scope(self, id: str, scope: str) -> bool:
|
||||
"""
|
||||
Check if a user has an auth scope
|
||||
"owner" users have all scopes
|
||||
"""
|
||||
user = self._users.get(str(id), {})
|
||||
scopes = user.get(_Key.SCOPES, [])
|
||||
return AuthScope.OWNER in scopes or scope in scopes
|
||||
|
||||
def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True):
|
||||
"""Set a user as active (or not) on a given channel"""
|
||||
self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active)
|
||||
|
||||
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
|
||||
"""Get a users active channels"""
|
||||
user = self._users.get(str(id), {})
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
|
||||
|
||||
def set_user_paused_in_channel(self, id: str, channel_id: str, days: int):
|
||||
"""Sets a user as paused in a channel"""
|
||||
# Deactivate the user in the channel first
|
||||
self.set_user_active_in_channel(id, channel_id, False)
|
||||
|
||||
# Set the reactivate time the number of days in the future
|
||||
ts = datetime.now() + timedelta(days=days)
|
||||
self._set_user_channel_prop(
|
||||
id, channel_id, _Key.REACTIVATE, datetime_to_ts(ts))
|
||||
|
||||
def reactivate_users(self, channel_id: str):
|
||||
"""Reactivate any users who've passed their reactivation time on this channel"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
for user in safe_state._users.values():
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
if channel and not channel[_Key.ACTIVE]:
|
||||
reactivate = channel.get(_Key.REACTIVATE, None)
|
||||
# Check if we've gone past the reactivation time and re-activate
|
||||
if reactivate and datetime.now() > ts_to_datetime(reactivate):
|
||||
channel[_Key.ACTIVE] = True
|
||||
|
||||
@property
|
||||
def dict_internal_copy(self) -> dict:
|
||||
"""Only to be used to get the internal dict as a copy"""
|
||||
return copy.deepcopy(self._dict)
|
||||
|
||||
@property
|
||||
def _history(self) -> dict[str]:
|
||||
return self._dict[_Key.HISTORY]
|
||||
|
||||
@property
|
||||
def _users(self) -> dict[str]:
|
||||
return self._dict[_Key.USERS]
|
||||
|
||||
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
|
||||
"""Set a user channel property helper"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.get(str(id), {})
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
|
||||
# Set the value
|
||||
channel[key] = value
|
||||
|
||||
# Unroll
|
||||
channels[str(channel_id)] = channel
|
||||
user[_Key.CHANNELS] = channels
|
||||
safe_state._users[str(id)] = user
|
||||
|
||||
@contextmanager
|
||||
def _safe_wrap(self):
|
||||
"""Safely run any function wrapped in a validate"""
|
||||
# Wrap in a temporary state to validate first to prevent corruption
|
||||
tmp_state = State(self._dict)
|
||||
try:
|
||||
yield tmp_state
|
||||
finally:
|
||||
# Validate and then overwrite our dict with the new one
|
||||
tmp_state.validate()
|
||||
self._dict = tmp_state._dict
|
||||
|
||||
|
||||
def _migrate(dict: dict):
|
||||
"""Migrate a dict through versions"""
|
||||
version = dict.get("version", 0)
|
||||
for i in range(version, _VERSION):
|
||||
logger.info("Migrating from v%s to v%s", version, version+1)
|
||||
_MIGRATIONS[i](dict)
|
||||
dict[_Key.VERSION] = _VERSION
|
||||
|
||||
|
||||
def load_from_file(file: str) -> State:
|
||||
"""
|
||||
Load the state from a file
|
||||
Apply any required migrations
|
||||
"""
|
||||
loaded = _EMPTY_DICT
|
||||
|
||||
# If there's a file load it and try to migrate
|
||||
if os.path.isfile(file):
|
||||
loaded = files.load(file)
|
||||
_migrate(loaded)
|
||||
|
||||
st = State(loaded)
|
||||
|
||||
# Save out the migrated (or new) file
|
||||
files.save(file, st._dict)
|
||||
|
||||
return st
|
||||
|
||||
|
||||
def save_to_file(state: State, file: str):
|
||||
"""Saves the state out to a file"""
|
||||
files.save(file, state.dict_internal_copy)
|
65
py/state_test.py
Normal file
65
py/state_test.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
"""
|
||||
Test functions for the state module
|
||||
"""
|
||||
import state
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
def test_basic_state():
|
||||
"""Simple validate basic state load"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
state.load_from_file(path)
|
||||
|
||||
|
||||
def test_simple_load_reload():
|
||||
"""Test a basic load, save, reload"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st = state.load_from_file(path)
|
||||
|
||||
|
||||
def test_authscope():
|
||||
"""Test setting and getting an auth scope"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_user_scope(1, state.AuthScope.MATCHER)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
st.set_user_scope(1, state.AuthScope.MATCHER, False)
|
||||
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
|
||||
def test_channeljoin():
|
||||
"""Test setting and getting an active channel"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
assert not st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_user_active_in_channel(1, "2", True)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st.set_user_active_in_channel(1, "2", False)
|
||||
assert not st.get_user_active_in_channel(1, "2")
|
Loading…
Add table
Reference in a new issue