Significant set of changes
* Use /join and /leave instead of roles * Use scopes to check for user rights rather than using the config file * Add /list to show the set of current people signed up * Add a bunch more testing for various things * Version both the config and the state
This commit is contained in:
parent
78834f5319
commit
d3a22ff090
8 changed files with 537 additions and 295 deletions
10
README.md
10
README.md
|
@ -20,18 +20,16 @@ Only usable by `OWNER` users, reloads the config and syncs commands, or closes d
|
|||
Matchy is configured by a `config.json` file that takes this format:
|
||||
```
|
||||
{
|
||||
"version": 1,
|
||||
"token": "<<github bot token>>",
|
||||
"owners": [
|
||||
<<owner id>>
|
||||
]
|
||||
}
|
||||
```
|
||||
User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user.
|
||||
|
||||
## TODO
|
||||
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
|
||||
* Implement /pause to pause a user for a little while
|
||||
* Move more constants to the config
|
||||
* Add scheduling functionality
|
||||
* Version the config and history files
|
||||
* Implement /signup rather than using roles
|
||||
* Implement authorisation scopes instead of just OWNER values
|
||||
* Fix logging in some sub files
|
||||
* Improve the weirdo
|
73
config.py
73
config.py
|
@ -1,38 +1,77 @@
|
|||
"""Very simple config loading library"""
|
||||
from schema import Schema, And, Use
|
||||
import files
|
||||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("config")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
_FILE = "config.json"
|
||||
|
||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||
_VERSION = 1
|
||||
|
||||
|
||||
class _Keys():
|
||||
TOKEN = "token"
|
||||
VERSION = "version"
|
||||
|
||||
# Removed
|
||||
OWNERS = "owners"
|
||||
|
||||
|
||||
_SCHEMA = Schema(
|
||||
{
|
||||
# Discord bot token
|
||||
"token": And(Use(str)),
|
||||
# The current version
|
||||
_Keys.VERSION: And(Use(int)),
|
||||
|
||||
# ids of owners authorised to use owner-only commands
|
||||
"owners": And(Use(list[int])),
|
||||
# Discord bot token
|
||||
_Keys.TOKEN: And(Use(str)),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _migrate_to_v1(d: dict):
|
||||
# Owners moved to History in v1
|
||||
# Note: owners will be required to be re-added to the state.json
|
||||
owners = d.pop(_Keys.OWNERS)
|
||||
logger.warn(
|
||||
"Migration removed owners from config, these must be re-added to the state.json")
|
||||
logger.warn("Owners: %s", owners)
|
||||
|
||||
|
||||
# Set of migration functions to apply
|
||||
_MIGRATIONS = [
|
||||
_migrate_to_v1
|
||||
]
|
||||
|
||||
|
||||
class Config():
|
||||
def __init__(self, data: dict):
|
||||
"""Initialise and validate the config"""
|
||||
_SCHEMA.validate(data)
|
||||
self.__dict__ = data
|
||||
self._dict = data
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return self.__dict__["token"]
|
||||
|
||||
@property
|
||||
def owners(self) -> list[int]:
|
||||
return self.__dict__["owners"]
|
||||
|
||||
def reload(self) -> None:
|
||||
"""Reload the config back into the dict"""
|
||||
self.__dict__ = load().__dict__
|
||||
return self._dict["token"]
|
||||
|
||||
|
||||
def load() -> Config:
|
||||
"""Load the config"""
|
||||
return Config(files.load(_FILE))
|
||||
def _migrate(dict: dict):
|
||||
"""Migrate a dict through versions"""
|
||||
version = dict.get("version", 0)
|
||||
for i in range(version, _VERSION):
|
||||
_MIGRATIONS[i](dict)
|
||||
dict["version"] = _VERSION
|
||||
|
||||
|
||||
def load_from_file(file: str = _FILE) -> Config:
|
||||
"""
|
||||
Load the state from a file
|
||||
Apply any required migrations
|
||||
"""
|
||||
assert os.path.isfile(file)
|
||||
loaded = files.load(file)
|
||||
_migrate(loaded)
|
||||
return Config(loaded)
|
||||
|
|
125
history.py
125
history.py
|
@ -1,125 +0,0 @@
|
|||
"""Store matching history"""
|
||||
import os
|
||||
from datetime import datetime
|
||||
from schema import Schema, And, Use, Optional
|
||||
from typing import Protocol
|
||||
import files
|
||||
import copy
|
||||
|
||||
_FILE = "history.json"
|
||||
|
||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||
_DEFAULT_DICT = {
|
||||
"history": {},
|
||||
"matchees": {}
|
||||
}
|
||||
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
|
||||
_SCHEMA = Schema(
|
||||
{
|
||||
Optional("history"): {
|
||||
Optional(str): { # a datetime
|
||||
"groups": [
|
||||
{
|
||||
"members": [
|
||||
# The ID of each matchee in the match
|
||||
And(Use(int))
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
Optional("matchees"): {
|
||||
Optional(str): {
|
||||
Optional("matches"): {
|
||||
# Matchee ID and Datetime pair
|
||||
Optional(str): And(Use(str))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Member(Protocol):
|
||||
@property
|
||||
def id(self) -> int:
|
||||
pass
|
||||
|
||||
|
||||
def ts_to_datetime(ts: str) -> datetime:
|
||||
"""Convert a ts to datetime using the history format"""
|
||||
return datetime.strptime(ts, _TIME_FORMAT)
|
||||
|
||||
|
||||
def validate(dict: dict):
|
||||
"""Initialise and validate the history"""
|
||||
_SCHEMA.validate(dict)
|
||||
|
||||
|
||||
class History():
|
||||
def __init__(self, data: dict = _DEFAULT_DICT):
|
||||
"""Initialise and validate the history"""
|
||||
validate(data)
|
||||
self.__dict__ = copy.deepcopy(data)
|
||||
|
||||
@property
|
||||
def history(self) -> list[dict]:
|
||||
return self.__dict__["history"]
|
||||
|
||||
@property
|
||||
def matchees(self) -> dict[str, dict]:
|
||||
return self.__dict__["matchees"]
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save out the history"""
|
||||
files.save(_FILE, self.__dict__)
|
||||
|
||||
def oldest(self) -> datetime:
|
||||
"""Grab the oldest timestamp in history"""
|
||||
if not self.history:
|
||||
return None
|
||||
times = (ts_to_datetime(dt) for dt in self.history.keys())
|
||||
return sorted(times)[0]
|
||||
|
||||
def log_groups_to_history(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
|
||||
"""Log the groups"""
|
||||
tmp_history = History(self.__dict__)
|
||||
ts = datetime.strftime(ts, _TIME_FORMAT)
|
||||
|
||||
# Grab or create the hitory item for this set of groups
|
||||
history_item = {}
|
||||
tmp_history.history[ts] = history_item
|
||||
history_item_groups = []
|
||||
history_item["groups"] = history_item_groups
|
||||
|
||||
for group in groups:
|
||||
|
||||
# Add the group data
|
||||
history_item_groups.append({
|
||||
"members": list(m.id for m in group)
|
||||
})
|
||||
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
matchee = tmp_history.matchees.get(str(m.id), {})
|
||||
matchee_matches = matchee.get("matches", {})
|
||||
|
||||
for o in (o for o in group if o.id != m.id):
|
||||
matchee_matches[str(o.id)] = ts
|
||||
|
||||
matchee["matches"] = matchee_matches
|
||||
tmp_history.matchees[str(m.id)] = matchee
|
||||
|
||||
# Validate before storing the result
|
||||
validate(self.__dict__)
|
||||
self.__dict__ = tmp_history.__dict__
|
||||
|
||||
def save_groups_to_history(self, groups: list[list[Member]]) -> None:
|
||||
"""Save out the groups to the history file"""
|
||||
self.log_groups_to_history(groups)
|
||||
self.save()
|
||||
|
||||
|
||||
def load() -> History:
|
||||
"""Load the history"""
|
||||
return History(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)
|
98
matching.py
98
matching.py
|
@ -1,6 +1,5 @@
|
|||
"""Utility functions for matchy"""
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Protocol, runtime_checkable
|
||||
import state
|
||||
|
@ -9,17 +8,16 @@ import state
|
|||
# Number of days to step forward from the start of history for each match attempt
|
||||
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
|
||||
|
||||
# Attempts for each of those time periods
|
||||
_ATTEMPTS_PER_TIMESTEP = 3
|
||||
|
||||
# Various eligability scoring factors for group meetups
|
||||
_SCORE_CURRENT_MEMBERS = 2**1
|
||||
_SCORE_REPEAT_ROLE = 2**2
|
||||
_SCORE_REPEAT_MATCH = 2**3
|
||||
_SCORE_EXTRA_MEMBERS = 2**4
|
||||
class _ScoreFactors(int):
|
||||
"""Various eligability scoring factors for group meetups"""
|
||||
REPEAT_ROLE = 2**2
|
||||
REPEAT_MATCH = 2**3
|
||||
EXTRA_MEMBER = 2**5
|
||||
|
||||
# Scores higher than this are fully rejected
|
||||
_SCORE_UPPER_THRESHOLD = 2**6
|
||||
UPPER_THRESHOLD = 2**6
|
||||
|
||||
|
||||
logger = logging.getLogger("matching")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -69,33 +67,42 @@ def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bo
|
|||
|
||||
def get_member_group_eligibility_score(member: Member,
|
||||
group: list[Member],
|
||||
relevant_matches: list[int],
|
||||
per_group: int) -> int:
|
||||
prior_matches: list[int],
|
||||
per_group: int) -> float:
|
||||
"""Rates a member against a group"""
|
||||
rating = len(group) * _SCORE_CURRENT_MEMBERS
|
||||
# An empty group is a "perfect" score atomatically
|
||||
rating = 0
|
||||
if not group:
|
||||
return rating
|
||||
|
||||
repeat_meetings = sum(m.id in relevant_matches for m in group)
|
||||
rating += repeat_meetings * _SCORE_REPEAT_MATCH
|
||||
# Add score based on prior matchups of this user
|
||||
rating += sum(m.id in prior_matches for m in group) * \
|
||||
_ScoreFactors.REPEAT_MATCH
|
||||
|
||||
repeat_roles = sum(r in member.roles for r in (m.roles for m in group))
|
||||
rating += (repeat_roles * _SCORE_REPEAT_ROLE)
|
||||
# Calculate the number of roles that match
|
||||
all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr)
|
||||
member_role_ids = [r.id for r in member.roles]
|
||||
repeat_roles = sum(id in member_role_ids for id in all_role_ids)
|
||||
rating += repeat_roles * _ScoreFactors.REPEAT_ROLE
|
||||
|
||||
extra_members = len(group) - per_group
|
||||
if extra_members > 0:
|
||||
rating += extra_members * _SCORE_EXTRA_MEMBERS
|
||||
# Add score based on the number of extra members
|
||||
# Calculate the member offset (+1 for this user)
|
||||
extra_members = (len(group) - per_group) + 1
|
||||
if extra_members >= 0:
|
||||
rating += extra_members * _ScoreFactors.EXTRA_MEMBER
|
||||
|
||||
return rating
|
||||
|
||||
|
||||
def attempt_create_groups(matchees: list[Member],
|
||||
hist: state.State,
|
||||
current_state: state.State,
|
||||
oldest_relevant_ts: datetime,
|
||||
per_group: int) -> tuple[bool, list[list[Member]]]:
|
||||
"""History aware group matching"""
|
||||
num_groups = max(len(matchees)//per_group, 1)
|
||||
|
||||
# Set up the groups in place
|
||||
groups = list([] for _ in range(num_groups))
|
||||
groups = [[] for _ in range(num_groups)]
|
||||
|
||||
matchees_left = matchees.copy()
|
||||
|
||||
|
@ -103,21 +110,21 @@ def attempt_create_groups(matchees: list[Member],
|
|||
while matchees_left:
|
||||
# Get the next matchee to place
|
||||
matchee = matchees_left.pop()
|
||||
matchee_matches = hist.matchees.get(
|
||||
str(matchee.id), {}).get("matches", {})
|
||||
relevant_matches = list(int(id) for id, ts in matchee_matches.items()
|
||||
if state.ts_to_datetime(ts) >= oldest_relevant_ts)
|
||||
matchee_matches = current_state.get_user_matches(matchee.id)
|
||||
relevant_matches = [int(id) for id, ts
|
||||
in matchee_matches.items()
|
||||
if state.ts_to_datetime(ts) >= oldest_relevant_ts]
|
||||
|
||||
# Try every single group from the current group onwards
|
||||
# Progressing through the groups like this ensures we slowly fill them up with compatible people
|
||||
scores: list[tuple[int, int]] = []
|
||||
scores: list[tuple[int, float]] = []
|
||||
for group in groups:
|
||||
|
||||
score = get_member_group_eligibility_score(
|
||||
matchee, group, relevant_matches, num_groups)
|
||||
matchee, group, relevant_matches, per_group)
|
||||
|
||||
# If the score isn't too high, consider this group
|
||||
if score <= _SCORE_UPPER_THRESHOLD:
|
||||
if score <= _ScoreFactors.UPPER_THRESHOLD:
|
||||
scores.append((group, score))
|
||||
|
||||
# Optimisation:
|
||||
|
@ -143,31 +150,41 @@ def datetime_range(start_time: datetime, increment: timedelta, end: datetime):
|
|||
current += increment
|
||||
|
||||
|
||||
def iterate_all_shifts(list: list):
|
||||
"""Yields each shifted variation of the input list"""
|
||||
yield list
|
||||
for _ in range(len(list)-1):
|
||||
list = list[1:] + [list[0]]
|
||||
yield list
|
||||
|
||||
|
||||
def members_to_groups(matchees: list[Member],
|
||||
hist: state.State = state.State(),
|
||||
per_group: int = 3,
|
||||
allow_fallback: bool = False) -> list[list[Member]]:
|
||||
"""Generate the groups from the set of matchees"""
|
||||
attempts = 0 # Tracking for logging purposes
|
||||
rand = random.Random(117) # Some stable randomness
|
||||
num_groups = len(matchees)//per_group
|
||||
|
||||
# Bail early if there's no-one to match
|
||||
if not matchees:
|
||||
return []
|
||||
|
||||
# Grab the oldest timestamp
|
||||
history_start = hist.oldest_history() or datetime.now()
|
||||
history_start = hist.get_oldest_timestamp() or datetime.now()
|
||||
|
||||
# Walk from the start of time until now using the timestep increment
|
||||
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
|
||||
|
||||
# Have a few attempts before stepping forward in time
|
||||
for _ in range(_ATTEMPTS_PER_TIMESTEP):
|
||||
|
||||
rand.shuffle(matchees) # Shuffle the matchees each attempt
|
||||
# Attempt with each starting matchee
|
||||
for shifted_matchees in iterate_all_shifts(matchees):
|
||||
|
||||
attempts += 1
|
||||
groups = attempt_create_groups(
|
||||
matchees, hist, oldest_relevant_datetime, per_group)
|
||||
shifted_matchees, hist, oldest_relevant_datetime, per_group)
|
||||
|
||||
# Fail the match if our groups aren't big enough
|
||||
if (len(matchees)//per_group) <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
||||
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
||||
logger.info("Matched groups after %s attempt(s)", attempts)
|
||||
return groups
|
||||
|
||||
|
@ -176,6 +193,10 @@ def members_to_groups(matchees: list[Member],
|
|||
logger.info("Fell back to simple groups after %s attempt(s)", attempts)
|
||||
return members_to_groups_simple(matchees, per_group)
|
||||
|
||||
# Simply assert false, this should never happen
|
||||
# And should be caught by tests
|
||||
assert False
|
||||
|
||||
|
||||
def group_to_message(group: list[Member]) -> str:
|
||||
"""Get the message to send for each group"""
|
||||
|
@ -185,8 +206,3 @@ def group_to_message(group: list[Member]) -> str:
|
|||
else:
|
||||
mentions = mentions[0]
|
||||
return f"Matched up {mentions}!"
|
||||
|
||||
|
||||
def get_role_from_guild(guild: Guild, role: str) -> Role:
|
||||
"""Find a role in a guild"""
|
||||
return next((r for r in guild.roles if r.name == role), None)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Test functions for Matchy
|
||||
Test functions for the matching module
|
||||
"""
|
||||
import discord
|
||||
import pytest
|
||||
|
@ -30,6 +30,7 @@ class Role():
|
|||
class Member():
|
||||
def __init__(self, id: int, roles: list[Role] = []):
|
||||
self._id = id
|
||||
self._roles = roles
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
|
@ -37,7 +38,7 @@ class Member():
|
|||
|
||||
@property
|
||||
def roles(self) -> list[Role]:
|
||||
return []
|
||||
return self._roles
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
|
@ -153,7 +154,59 @@ def items_found_in_lists(list_of_lists, items):
|
|||
# Nothing specific to validate
|
||||
]
|
||||
),
|
||||
], ids=['simple_history', 'fallback'])
|
||||
# Specific test pulled out of the stress test
|
||||
(
|
||||
[
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=4),
|
||||
"groups": [
|
||||
[Member(i) for i in [1, 2, 3, 4, 5, 6,
|
||||
7, 8, 9, 10, 11, 12, 13, 14, 15]]
|
||||
]
|
||||
},
|
||||
{
|
||||
"ts": datetime.now() - timedelta(days=5),
|
||||
"groups": [
|
||||
[Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
|
||||
]
|
||||
}
|
||||
],
|
||||
[Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]],
|
||||
3,
|
||||
[
|
||||
# Nothing specific to validate
|
||||
]
|
||||
),
|
||||
# Silly example that failued due to bad role logic
|
||||
(
|
||||
[
|
||||
# No history
|
||||
],
|
||||
[
|
||||
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
|
||||
Member(i, [Role(r) for r in roles]) for (i, roles) in
|
||||
[
|
||||
(4, [1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
(8, [1]),
|
||||
(9, [1, 2, 3, 4, 5]),
|
||||
(6, [1, 2, 3]),
|
||||
(11, [1, 2, 3]),
|
||||
(7, [1, 2, 3, 4, 5, 6, 7]),
|
||||
(1, [1, 2, 3, 4]),
|
||||
(5, [1, 2, 3, 4, 5]),
|
||||
(12, [1, 2, 3, 4]),
|
||||
(10, [1]),
|
||||
(13, [1, 2, 3, 4, 5, 6]),
|
||||
(2, [1, 2, 3, 4, 5, 6]),
|
||||
(3, [1, 2, 3, 4, 5, 6, 7])
|
||||
]
|
||||
],
|
||||
2,
|
||||
[
|
||||
# Nothing else
|
||||
]
|
||||
)
|
||||
], ids=['simple_history', 'fallback', 'example_1', 'example_2'])
|
||||
def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
|
||||
"""Test more advanced group matching works"""
|
||||
tmp_state = state.State()
|
||||
|
@ -180,8 +233,8 @@ def test_members_to_groups_stress_test():
|
|||
|
||||
# Slowly ramp a randomized shuffled list of members with randomised roles
|
||||
for num_members in range(1, 5):
|
||||
matchees = list(Member(i, list(Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))))
|
||||
for i in range(1, rand.randint(2, num_members*10 + 1)))
|
||||
matchees = [Member(i, [Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))])
|
||||
for i in range(1, rand.randint(2, num_members*10 + 1))]
|
||||
rand.shuffle(matchees)
|
||||
|
||||
for num_history in range(8):
|
||||
|
@ -190,14 +243,14 @@ def test_members_to_groups_stress_test():
|
|||
# Start some time from now to the past
|
||||
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
|
||||
history_data = []
|
||||
for x in range(0, num_history):
|
||||
for _ in range(0, num_history):
|
||||
run = {
|
||||
"ts": time
|
||||
}
|
||||
groups = []
|
||||
for y in range(1, num_history):
|
||||
groups.append(list(Member(i)
|
||||
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))))
|
||||
groups.append([Member(i)
|
||||
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))])
|
||||
run["groups"] = groups
|
||||
history_data.append(run)
|
||||
|
||||
|
@ -212,4 +265,32 @@ def test_members_to_groups_stress_test():
|
|||
for d in history_data:
|
||||
tmp_state.log_groups(d["groups"], d["ts"])
|
||||
|
||||
inner_validate_members_to_groups(matchees, tmp_state, per_group)
|
||||
inner_validate_members_to_groups(
|
||||
matchees, tmp_state, per_group)
|
||||
|
||||
|
||||
def test_auth_scopes():
|
||||
tmp_state = state.State()
|
||||
|
||||
id = "1"
|
||||
tmp_state.set_user_scope(id, state.AuthScope.OWNER)
|
||||
assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
|
||||
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||
|
||||
id = "2"
|
||||
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
|
||||
assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
|
||||
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||
|
||||
tmp_state.validate()
|
||||
|
||||
|
||||
def test_iterate_all_shifts():
|
||||
original = [1, 2, 3, 4]
|
||||
lists = [val for val in matching.iterate_all_shifts(original)]
|
||||
assert lists == [
|
||||
[1, 2, 3, 4],
|
||||
[2, 3, 4, 1],
|
||||
[3, 4, 1, 2],
|
||||
[4, 1, 2, 3],
|
||||
]
|
||||
|
|
130
matchy.py
130
matchy.py
|
@ -11,8 +11,11 @@ import config
|
|||
import re
|
||||
|
||||
|
||||
Config = config.load()
|
||||
State = state.load()
|
||||
STATE_FILE = "state.json"
|
||||
CONFIG_FILE = "config.json"
|
||||
|
||||
Config = config.load_from_file(CONFIG_FILE)
|
||||
State = state.load_from_file(STATE_FILE)
|
||||
|
||||
logger = logging.getLogger("matchy")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -39,7 +42,7 @@ async def on_ready():
|
|||
|
||||
def owner_only(ctx: commands.Context) -> bool:
|
||||
"""Checks the author is an owner"""
|
||||
return ctx.message.author.id in Config.owners
|
||||
return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER)
|
||||
|
||||
|
||||
@bot.command()
|
||||
|
@ -47,9 +50,10 @@ def owner_only(ctx: commands.Context) -> bool:
|
|||
@commands.check(owner_only)
|
||||
async def sync(ctx: commands.Context):
|
||||
"""Handle sync command"""
|
||||
msg = await ctx.reply("Reloading config...", ephemeral=True)
|
||||
Config.reload()
|
||||
logger.info("Reloaded config")
|
||||
msg = await ctx.reply("Reloading state...", ephemeral=True)
|
||||
global State
|
||||
State = state.load_from_file(STATE_FILE)
|
||||
logger.info("Reloaded state")
|
||||
|
||||
await msg.edit(content="Syncing commands...")
|
||||
synced = await bot.tree.sync()
|
||||
|
@ -68,96 +72,112 @@ async def close(ctx: commands.Context):
|
|||
await bot.close()
|
||||
|
||||
|
||||
@bot.tree.command(description="Join the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def join(interaction: discord.Interaction):
|
||||
State.set_use_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
f"Roger roger {interaction.user.mention}!\n"
|
||||
+ f"Added you to {interaction.channel.mention}!",
|
||||
ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="Leave the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def leave(interaction: discord.Interaction):
|
||||
State.set_use_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id, False)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="List the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def list(interaction: discord.Interaction):
|
||||
matchees = get_matchees_in_channel(interaction.channel)
|
||||
mentions = [m.mention for m in matchees]
|
||||
msg = "Current matchees in this channel:\n" + \
|
||||
f"{', '.join(mentions[:-1])} and {mentions[-1]}"
|
||||
await interaction.response.send_message(msg, ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="Match up matchees")
|
||||
@commands.guild_only()
|
||||
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)",
|
||||
matchee_role="Role for matchees (defaults to @Matchee)")
|
||||
async def match(interaction: discord.Interaction, members_min: int = None, matchee_role: str = None):
|
||||
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)")
|
||||
async def match(interaction: discord.Interaction, members_min: int = None):
|
||||
"""Match groups of channel members"""
|
||||
|
||||
logger.info("Handling request '/match group_min=%s matchee_role=%s'",
|
||||
members_min, matchee_role)
|
||||
logger.info("Handling request '/match group_min=%s", members_min)
|
||||
logger.info("User %s from %s in #%s", interaction.user,
|
||||
interaction.guild.name, interaction.channel.name)
|
||||
|
||||
# Sort out the defaults, if not specified they'll come in as None
|
||||
if not members_min:
|
||||
members_min = 3
|
||||
if not matchee_role:
|
||||
matchee_role = "Matchee"
|
||||
|
||||
# Grab the roles and verify the given role
|
||||
matcher = matching.get_role_from_guild(interaction.guild, "Matcher")
|
||||
matcher = matcher and matcher in interaction.user.roles
|
||||
matchee = matching.get_role_from_guild(interaction.guild, matchee_role)
|
||||
if not matchee:
|
||||
await interaction.response.send_message(f"Server is missing '{matchee_role}' role :(", ephemeral=True)
|
||||
# Grab the groups
|
||||
groups = active_members_to_groups(interaction.channel, members_min)
|
||||
|
||||
# Let the user know when there's nobody to match
|
||||
if not groups:
|
||||
await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True)
|
||||
return
|
||||
|
||||
# Create some example groups to show the user
|
||||
matchees = list(
|
||||
m for m in interaction.channel.members if matchee in m.roles)
|
||||
groups = matching.members_to_groups(
|
||||
matchees, State, members_min, allow_fallback=True)
|
||||
|
||||
# Post about all the groups with a button to send to the channel
|
||||
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
|
||||
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
|
||||
view = discord.utils.MISSING
|
||||
|
||||
if not matcher:
|
||||
if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER):
|
||||
# Let a non-matcher know why they don't have the button
|
||||
msg += "\n\nYou'll need the 'Matcher' role to post this to the channel, sorry!"
|
||||
msg += f"\n\nYou'll need the {state.AuthScope.MATCHER} scope to post this to the channel, sorry!"
|
||||
else:
|
||||
# Otherwise set up the button
|
||||
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
|
||||
view = discord.ui.View(timeout=None)
|
||||
view.add_item(DynamicGroupButton(members_min, matchee_role))
|
||||
view.add_item(DynamicGroupButton(members_min))
|
||||
|
||||
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
|
||||
|
||||
logger.info("Done.")
|
||||
|
||||
|
||||
# Increment when adjusting the custom_id so we don't confuse old users
|
||||
_BUTTON_CUSTOM_ID_VERSION = 1
|
||||
|
||||
|
||||
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
||||
template=r'match:min:(?P<min>[0-9]+):role:(?P<role>[@\w\s]+)'):
|
||||
def __init__(self, min: int, role: str) -> None:
|
||||
template=f'match:v{_BUTTON_CUSTOM_ID_VERSION}:' + r'min:(?P<min>[0-9]+)'):
|
||||
def __init__(self, min: int) -> None:
|
||||
super().__init__(
|
||||
discord.ui.Button(
|
||||
label='Match Groups!',
|
||||
style=discord.ButtonStyle.blurple,
|
||||
custom_id=f'match:min:{min}:role:{role}',
|
||||
custom_id=f'match:min:{min}',
|
||||
)
|
||||
)
|
||||
self.min: int = min
|
||||
self.role: int = role
|
||||
|
||||
# This is called when the button is clicked and the custom_id matches the template.
|
||||
@classmethod
|
||||
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
|
||||
min = int(match['min'])
|
||||
role = str(match['role'])
|
||||
return cls(min, role)
|
||||
return cls(min)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
"""Match up people when the button is pressed"""
|
||||
|
||||
logger.info("Handling button press min=%s role=%s'",
|
||||
self.min, self.role)
|
||||
logger.info("Handling button press min=%s", self.min)
|
||||
logger.info("User %s from %s in #%s", interaction.user,
|
||||
interaction.guild.name, interaction.channel.name)
|
||||
|
||||
# Let the user know we've recieved the message
|
||||
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
|
||||
|
||||
# Grab the role
|
||||
matchee = matching.get_role_from_guild(interaction.guild, self.role)
|
||||
|
||||
# Create our groups!
|
||||
matchees = list(
|
||||
m for m in interaction.channel.members if matchee in m.roles)
|
||||
groups = matching.members_to_groups(
|
||||
matchees, State, self.min, allow_fallback=True)
|
||||
groups = active_members_to_groups(interaction.channel, self.min)
|
||||
|
||||
# Send the groups
|
||||
for msg in (matching.group_to_message(g) for g in groups):
|
||||
|
@ -167,10 +187,26 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
|||
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
|
||||
|
||||
# Save the groups to the history
|
||||
State.save_groups(groups)
|
||||
State.log_groups(groups)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
|
||||
logger.info("Done. Matched %s matchees into %s groups.",
|
||||
len(matchees), len(groups))
|
||||
logger.info("Done! Matched into %s groups.", len(groups))
|
||||
|
||||
|
||||
def get_matchees_in_channel(channel: discord.channel):
|
||||
"""Fetches the matchees in a channel"""
|
||||
# Gather up the prospective matchees
|
||||
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
|
||||
|
||||
|
||||
def active_members_to_groups(channel: discord.channel, min_members: int):
|
||||
"""Helper to create groups from channel members"""
|
||||
|
||||
# Gather up the prospective matchees
|
||||
matchees = get_matchees_in_channel(channel)
|
||||
|
||||
# Create our groups!
|
||||
return matching.members_to_groups(matchees, State, min_members, allow_fallback=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
238
state.py
238
state.py
|
@ -5,22 +5,65 @@ from schema import Schema, And, Use, Optional
|
|||
from typing import Protocol
|
||||
import files
|
||||
import copy
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("state")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
_FILE = "state.json"
|
||||
|
||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||
_DEFAULT_DICT = {
|
||||
"history": {},
|
||||
"matchees": {}
|
||||
}
|
||||
_VERSION = 1
|
||||
|
||||
|
||||
def _migrate_to_v1(d: dict):
|
||||
logger.info("Renaming %s to %s", _Key.MATCHEES, _Key.USERS)
|
||||
d[_Key.USERS] = d[_Key.MATCHEES]
|
||||
del d[_Key.MATCHEES]
|
||||
|
||||
|
||||
# Set of migration functions to apply
|
||||
_MIGRATIONS = [
|
||||
_migrate_to_v1
|
||||
]
|
||||
|
||||
|
||||
class AuthScope(str):
|
||||
"""Various auth scopes"""
|
||||
OWNER = "owner"
|
||||
MATCHER = "matcher"
|
||||
|
||||
|
||||
class _Key(str):
|
||||
"""Various keys used in the schema"""
|
||||
HISTORY = "history"
|
||||
GROUPS = "groups"
|
||||
MEMBERS = "members"
|
||||
USERS = "users"
|
||||
SCOPES = "scopes"
|
||||
MATCHES = "matches"
|
||||
ACTIVE = "active"
|
||||
CHANNELS = "channels"
|
||||
REACTIVATE = "reactivate"
|
||||
VERSION = "version"
|
||||
|
||||
# Unused
|
||||
MATCHEES = "matchees"
|
||||
|
||||
|
||||
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
|
||||
|
||||
|
||||
_SCHEMA = Schema(
|
||||
{
|
||||
Optional("history"): {
|
||||
Optional(str): { # a datetime
|
||||
"groups": [
|
||||
# The current version
|
||||
_Key.VERSION: And(Use(int)),
|
||||
|
||||
Optional(_Key.HISTORY): {
|
||||
# A datetime
|
||||
Optional(str): {
|
||||
_Key.GROUPS: [
|
||||
{
|
||||
"members": [
|
||||
_Key.MEMBERS: [
|
||||
# The ID of each matchee in the match
|
||||
And(Use(int))
|
||||
]
|
||||
|
@ -28,17 +71,33 @@ _SCHEMA = Schema(
|
|||
]
|
||||
}
|
||||
},
|
||||
Optional("matchees"): {
|
||||
Optional(_Key.USERS): {
|
||||
Optional(str): {
|
||||
Optional("matches"): {
|
||||
Optional(_Key.SCOPES): And(Use(list[str])),
|
||||
Optional(_Key.MATCHES): {
|
||||
# Matchee ID and Datetime pair
|
||||
Optional(str): And(Use(str))
|
||||
},
|
||||
Optional(_Key.CHANNELS): {
|
||||
# The channel ID
|
||||
Optional(str): {
|
||||
# Whether the user is signed up in this channel
|
||||
_Key.ACTIVE: And(Use(bool)),
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Empty but schema-valid internal dict
|
||||
_EMPTY_DICT = {
|
||||
_Key.HISTORY: {},
|
||||
_Key.USERS: {},
|
||||
_Key.VERSION: _VERSION
|
||||
}
|
||||
assert _SCHEMA.validate(_EMPTY_DICT)
|
||||
|
||||
|
||||
class Member(Protocol):
|
||||
@property
|
||||
|
@ -51,75 +110,148 @@ def ts_to_datetime(ts: str) -> datetime:
|
|||
return datetime.strptime(ts, _TIME_FORMAT)
|
||||
|
||||
|
||||
def validate(dict: dict):
|
||||
class State():
|
||||
def __init__(self, data: dict = _EMPTY_DICT):
|
||||
"""Initialise and validate the state"""
|
||||
self.validate(data)
|
||||
self._dict = copy.deepcopy(data)
|
||||
|
||||
@property
|
||||
def _history(self) -> dict[str]:
|
||||
return self._dict[_Key.HISTORY]
|
||||
|
||||
@property
|
||||
def _users(self) -> dict[str]:
|
||||
return self._dict[_Key.USERS]
|
||||
|
||||
def validate(self, dict: dict = None):
|
||||
"""Initialise and validate a state dict"""
|
||||
if not dict:
|
||||
dict = self._dict
|
||||
_SCHEMA.validate(dict)
|
||||
|
||||
|
||||
class State():
|
||||
def __init__(self, data: dict = _DEFAULT_DICT):
|
||||
"""Initialise and validate the state"""
|
||||
validate(data)
|
||||
self.__dict__ = copy.deepcopy(data)
|
||||
|
||||
@property
|
||||
def history(self) -> list[dict]:
|
||||
return self.__dict__["history"]
|
||||
|
||||
@property
|
||||
def matchees(self) -> dict[str, dict]:
|
||||
return self.__dict__["matchees"]
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save out the state"""
|
||||
files.save(_FILE, self.__dict__)
|
||||
|
||||
def oldest_history(self) -> datetime:
|
||||
def get_oldest_timestamp(self) -> datetime:
|
||||
"""Grab the oldest timestamp in history"""
|
||||
if not self.history:
|
||||
return None
|
||||
times = (ts_to_datetime(dt) for dt in self.history.keys())
|
||||
return sorted(times)[0]
|
||||
times = (ts_to_datetime(dt) for dt in self._history.keys())
|
||||
return next(times, None)
|
||||
|
||||
def get_user_matches(self, id: int) -> list[int]:
|
||||
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
|
||||
|
||||
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
|
||||
"""Log the groups"""
|
||||
tmp_state = State(self.__dict__)
|
||||
tmp_state = State(self._dict)
|
||||
ts = datetime.strftime(ts, _TIME_FORMAT)
|
||||
|
||||
# Grab or create the hitory item for this set of groups
|
||||
history_item = {}
|
||||
tmp_state.history[ts] = history_item
|
||||
tmp_state._history[ts] = history_item
|
||||
history_item_groups = []
|
||||
history_item["groups"] = history_item_groups
|
||||
history_item[_Key.GROUPS] = history_item_groups
|
||||
|
||||
for group in groups:
|
||||
|
||||
# Add the group data
|
||||
history_item_groups.append({
|
||||
"members": list(m.id for m in group)
|
||||
_Key.MEMBERS: [m.id for m in group]
|
||||
})
|
||||
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
matchee = tmp_state.matchees.get(str(m.id), {})
|
||||
matchee_matches = matchee.get("matches", {})
|
||||
matchee = tmp_state._users.get(str(m.id), {})
|
||||
matchee_matches = matchee.get(_Key.MATCHES, {})
|
||||
|
||||
for o in (o for o in group if o.id != m.id):
|
||||
matchee_matches[str(o.id)] = ts
|
||||
|
||||
matchee["matches"] = matchee_matches
|
||||
tmp_state.matchees[str(m.id)] = matchee
|
||||
matchee[_Key.MATCHES] = matchee_matches
|
||||
tmp_state._users[str(m.id)] = matchee
|
||||
|
||||
# Validate before storing the result
|
||||
validate(self.__dict__)
|
||||
self.__dict__ = tmp_state.__dict__
|
||||
tmp_state.validate()
|
||||
self._dict = tmp_state._dict
|
||||
|
||||
def save_groups(self, groups: list[list[Member]]) -> None:
|
||||
"""Save out the groups to the state file"""
|
||||
self.log_groups(groups)
|
||||
self.save()
|
||||
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||
"""Add an auth scope to a user"""
|
||||
# Dive in
|
||||
user = self._users.get(str(id), {})
|
||||
scopes = user.get(_Key.SCOPES, [])
|
||||
|
||||
# Set the value
|
||||
if value and scope not in scopes:
|
||||
scopes.append(scope)
|
||||
elif not value and scope in scopes:
|
||||
scopes.remove(scope)
|
||||
|
||||
# Roll out
|
||||
user[_Key.SCOPES] = scopes
|
||||
self._users[id] = user
|
||||
|
||||
def get_user_has_scope(self, id: str, scope: str) -> bool:
|
||||
"""
|
||||
Check if a user has an auth scope
|
||||
"owner" users have all scopes
|
||||
"""
|
||||
user = self._users.get(str(id), {})
|
||||
scopes = user.get(_Key.SCOPES, [])
|
||||
return AuthScope.OWNER in scopes or scope in scopes
|
||||
|
||||
def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True):
|
||||
"""Set a user as active (or not) on a given channel"""
|
||||
# Dive in
|
||||
user = self._users.get(str(id), {})
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
|
||||
# Set the value
|
||||
channel[_Key.ACTIVE] = active
|
||||
|
||||
# Unroll
|
||||
channels[str(channel_id)] = channel
|
||||
user[_Key.CHANNELS] = channels
|
||||
self._users[str(id)] = user
|
||||
|
||||
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
|
||||
"""Get a users active channels"""
|
||||
user = self._users.get(str(id), {})
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
|
||||
|
||||
@property
|
||||
def dict_internal(self) -> dict:
|
||||
"""Only to be used to get the internal dict as a copy"""
|
||||
return copy.deepcopy(self._dict)
|
||||
|
||||
|
||||
def load() -> State:
|
||||
"""Load the state"""
|
||||
return State(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)
|
||||
def _migrate(dict: dict):
|
||||
"""Migrate a dict through versions"""
|
||||
version = dict.get("version", 0)
|
||||
for i in range(version, _VERSION):
|
||||
logger.info("Migrating from v%s to v%s", version, version+1)
|
||||
_MIGRATIONS[i](dict)
|
||||
dict[_Key.VERSION] = _VERSION
|
||||
|
||||
|
||||
def load_from_file(file: str) -> State:
|
||||
"""
|
||||
Load the state from a file
|
||||
Apply any required migrations
|
||||
"""
|
||||
loaded = _EMPTY_DICT
|
||||
|
||||
# If there's a file load it and try to migrate
|
||||
if os.path.isfile(file):
|
||||
loaded = files.load(file)
|
||||
_migrate(loaded)
|
||||
|
||||
st = State(loaded)
|
||||
|
||||
# Save out the migrated (or new) file
|
||||
files.save(file, st._dict)
|
||||
|
||||
return st
|
||||
|
||||
|
||||
def save_to_file(state: State, file: str):
|
||||
"""Saves the state out to a file"""
|
||||
files.save(file, state.dict_internal)
|
||||
|
|
65
state_test.py
Normal file
65
state_test.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
"""
|
||||
Test functions for the state module
|
||||
"""
|
||||
import state
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
def test_basic_state():
|
||||
"""Simple validate basic state load"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
state.load_from_file(path)
|
||||
|
||||
|
||||
def test_simple_load_reload():
|
||||
"""Test a basic load, save, reload"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st = state.load_from_file(path)
|
||||
|
||||
|
||||
def test_authscope():
|
||||
"""Test setting and getting an auth scope"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_user_scope(1, state.AuthScope.MATCHER)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
st.set_user_scope(1, state.AuthScope.MATCHER, False)
|
||||
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
|
||||
def test_channeljoin():
|
||||
"""Test setting and getting an active channel"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
assert not st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_use_active_in_channel(1, "2", True)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st.set_use_active_in_channel(1, "2", False)
|
||||
assert not st.get_user_active_in_channel(1, "2")
|
Loading…
Add table
Reference in a new issue