Significant set of changes
* Use /join and /leave instead of roles * Use scopes to check for user rights rather than using the config file * Add /list to show the set of current people signed up * Add a bunch more testing for various things * Version both the config and the state
This commit is contained in:
parent
78834f5319
commit
d3a22ff090
8 changed files with 537 additions and 295 deletions
10
README.md
10
README.md
|
@ -20,18 +20,16 @@ Only usable by `OWNER` users, reloads the config and syncs commands, or closes d
|
||||||
Matchy is configured by a `config.json` file that takes this format:
|
Matchy is configured by a `config.json` file that takes this format:
|
||||||
```
|
```
|
||||||
{
|
{
|
||||||
|
"version": 1,
|
||||||
"token": "<<github bot token>>",
|
"token": "<<github bot token>>",
|
||||||
"owners": [
|
|
||||||
<<owner id>>
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user.
|
User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user.
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
|
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
|
||||||
|
* Implement /pause to pause a user for a little while
|
||||||
|
* Move more constants to the config
|
||||||
* Add scheduling functionality
|
* Add scheduling functionality
|
||||||
* Version the config and history files
|
* Fix logging in some sub files
|
||||||
* Implement /signup rather than using roles
|
|
||||||
* Implement authorisation scopes instead of just OWNER values
|
|
||||||
* Improve the weirdo
|
* Improve the weirdo
|
73
config.py
73
config.py
|
@ -1,38 +1,77 @@
|
||||||
"""Very simple config loading library"""
|
"""Very simple config loading library"""
|
||||||
from schema import Schema, And, Use
|
from schema import Schema, And, Use
|
||||||
import files
|
import files
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("config")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
_FILE = "config.json"
|
_FILE = "config.json"
|
||||||
|
|
||||||
|
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||||
|
_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
class _Keys():
|
||||||
|
TOKEN = "token"
|
||||||
|
VERSION = "version"
|
||||||
|
|
||||||
|
# Removed
|
||||||
|
OWNERS = "owners"
|
||||||
|
|
||||||
|
|
||||||
_SCHEMA = Schema(
|
_SCHEMA = Schema(
|
||||||
{
|
{
|
||||||
# Discord bot token
|
# The current version
|
||||||
"token": And(Use(str)),
|
_Keys.VERSION: And(Use(int)),
|
||||||
|
|
||||||
# ids of owners authorised to use owner-only commands
|
# Discord bot token
|
||||||
"owners": And(Use(list[int])),
|
_Keys.TOKEN: And(Use(str)),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_to_v1(d: dict):
|
||||||
|
# Owners moved to History in v1
|
||||||
|
# Note: owners will be required to be re-added to the state.json
|
||||||
|
owners = d.pop(_Keys.OWNERS)
|
||||||
|
logger.warn(
|
||||||
|
"Migration removed owners from config, these must be re-added to the state.json")
|
||||||
|
logger.warn("Owners: %s", owners)
|
||||||
|
|
||||||
|
|
||||||
|
# Set of migration functions to apply
|
||||||
|
_MIGRATIONS = [
|
||||||
|
_migrate_to_v1
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Config():
|
class Config():
|
||||||
def __init__(self, data: dict):
|
def __init__(self, data: dict):
|
||||||
"""Initialise and validate the config"""
|
"""Initialise and validate the config"""
|
||||||
_SCHEMA.validate(data)
|
_SCHEMA.validate(data)
|
||||||
self.__dict__ = data
|
self._dict = data
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def token(self) -> str:
|
def token(self) -> str:
|
||||||
return self.__dict__["token"]
|
return self._dict["token"]
|
||||||
|
|
||||||
@property
|
|
||||||
def owners(self) -> list[int]:
|
|
||||||
return self.__dict__["owners"]
|
|
||||||
|
|
||||||
def reload(self) -> None:
|
|
||||||
"""Reload the config back into the dict"""
|
|
||||||
self.__dict__ = load().__dict__
|
|
||||||
|
|
||||||
|
|
||||||
def load() -> Config:
|
def _migrate(dict: dict):
|
||||||
"""Load the config"""
|
"""Migrate a dict through versions"""
|
||||||
return Config(files.load(_FILE))
|
version = dict.get("version", 0)
|
||||||
|
for i in range(version, _VERSION):
|
||||||
|
_MIGRATIONS[i](dict)
|
||||||
|
dict["version"] = _VERSION
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_file(file: str = _FILE) -> Config:
|
||||||
|
"""
|
||||||
|
Load the state from a file
|
||||||
|
Apply any required migrations
|
||||||
|
"""
|
||||||
|
assert os.path.isfile(file)
|
||||||
|
loaded = files.load(file)
|
||||||
|
_migrate(loaded)
|
||||||
|
return Config(loaded)
|
||||||
|
|
125
history.py
125
history.py
|
@ -1,125 +0,0 @@
|
||||||
"""Store matching history"""
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
from schema import Schema, And, Use, Optional
|
|
||||||
from typing import Protocol
|
|
||||||
import files
|
|
||||||
import copy
|
|
||||||
|
|
||||||
_FILE = "history.json"
|
|
||||||
|
|
||||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
|
||||||
_DEFAULT_DICT = {
|
|
||||||
"history": {},
|
|
||||||
"matchees": {}
|
|
||||||
}
|
|
||||||
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
|
|
||||||
_SCHEMA = Schema(
|
|
||||||
{
|
|
||||||
Optional("history"): {
|
|
||||||
Optional(str): { # a datetime
|
|
||||||
"groups": [
|
|
||||||
{
|
|
||||||
"members": [
|
|
||||||
# The ID of each matchee in the match
|
|
||||||
And(Use(int))
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Optional("matchees"): {
|
|
||||||
Optional(str): {
|
|
||||||
Optional("matches"): {
|
|
||||||
# Matchee ID and Datetime pair
|
|
||||||
Optional(str): And(Use(str))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Member(Protocol):
|
|
||||||
@property
|
|
||||||
def id(self) -> int:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def ts_to_datetime(ts: str) -> datetime:
|
|
||||||
"""Convert a ts to datetime using the history format"""
|
|
||||||
return datetime.strptime(ts, _TIME_FORMAT)
|
|
||||||
|
|
||||||
|
|
||||||
def validate(dict: dict):
|
|
||||||
"""Initialise and validate the history"""
|
|
||||||
_SCHEMA.validate(dict)
|
|
||||||
|
|
||||||
|
|
||||||
class History():
|
|
||||||
def __init__(self, data: dict = _DEFAULT_DICT):
|
|
||||||
"""Initialise and validate the history"""
|
|
||||||
validate(data)
|
|
||||||
self.__dict__ = copy.deepcopy(data)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def history(self) -> list[dict]:
|
|
||||||
return self.__dict__["history"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def matchees(self) -> dict[str, dict]:
|
|
||||||
return self.__dict__["matchees"]
|
|
||||||
|
|
||||||
def save(self) -> None:
|
|
||||||
"""Save out the history"""
|
|
||||||
files.save(_FILE, self.__dict__)
|
|
||||||
|
|
||||||
def oldest(self) -> datetime:
|
|
||||||
"""Grab the oldest timestamp in history"""
|
|
||||||
if not self.history:
|
|
||||||
return None
|
|
||||||
times = (ts_to_datetime(dt) for dt in self.history.keys())
|
|
||||||
return sorted(times)[0]
|
|
||||||
|
|
||||||
def log_groups_to_history(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
|
|
||||||
"""Log the groups"""
|
|
||||||
tmp_history = History(self.__dict__)
|
|
||||||
ts = datetime.strftime(ts, _TIME_FORMAT)
|
|
||||||
|
|
||||||
# Grab or create the hitory item for this set of groups
|
|
||||||
history_item = {}
|
|
||||||
tmp_history.history[ts] = history_item
|
|
||||||
history_item_groups = []
|
|
||||||
history_item["groups"] = history_item_groups
|
|
||||||
|
|
||||||
for group in groups:
|
|
||||||
|
|
||||||
# Add the group data
|
|
||||||
history_item_groups.append({
|
|
||||||
"members": list(m.id for m in group)
|
|
||||||
})
|
|
||||||
|
|
||||||
# Update the matchee data with the matches
|
|
||||||
for m in group:
|
|
||||||
matchee = tmp_history.matchees.get(str(m.id), {})
|
|
||||||
matchee_matches = matchee.get("matches", {})
|
|
||||||
|
|
||||||
for o in (o for o in group if o.id != m.id):
|
|
||||||
matchee_matches[str(o.id)] = ts
|
|
||||||
|
|
||||||
matchee["matches"] = matchee_matches
|
|
||||||
tmp_history.matchees[str(m.id)] = matchee
|
|
||||||
|
|
||||||
# Validate before storing the result
|
|
||||||
validate(self.__dict__)
|
|
||||||
self.__dict__ = tmp_history.__dict__
|
|
||||||
|
|
||||||
def save_groups_to_history(self, groups: list[list[Member]]) -> None:
|
|
||||||
"""Save out the groups to the history file"""
|
|
||||||
self.log_groups_to_history(groups)
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
|
|
||||||
def load() -> History:
|
|
||||||
"""Load the history"""
|
|
||||||
return History(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)
|
|
100
matching.py
100
matching.py
|
@ -1,6 +1,5 @@
|
||||||
"""Utility functions for matchy"""
|
"""Utility functions for matchy"""
|
||||||
import logging
|
import logging
|
||||||
import random
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
import state
|
import state
|
||||||
|
@ -9,17 +8,16 @@ import state
|
||||||
# Number of days to step forward from the start of history for each match attempt
|
# Number of days to step forward from the start of history for each match attempt
|
||||||
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
|
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
|
||||||
|
|
||||||
# Attempts for each of those time periods
|
|
||||||
_ATTEMPTS_PER_TIMESTEP = 3
|
|
||||||
|
|
||||||
# Various eligability scoring factors for group meetups
|
class _ScoreFactors(int):
|
||||||
_SCORE_CURRENT_MEMBERS = 2**1
|
"""Various eligability scoring factors for group meetups"""
|
||||||
_SCORE_REPEAT_ROLE = 2**2
|
REPEAT_ROLE = 2**2
|
||||||
_SCORE_REPEAT_MATCH = 2**3
|
REPEAT_MATCH = 2**3
|
||||||
_SCORE_EXTRA_MEMBERS = 2**4
|
EXTRA_MEMBER = 2**5
|
||||||
|
|
||||||
|
# Scores higher than this are fully rejected
|
||||||
|
UPPER_THRESHOLD = 2**6
|
||||||
|
|
||||||
# Scores higher than this are fully rejected
|
|
||||||
_SCORE_UPPER_THRESHOLD = 2**6
|
|
||||||
|
|
||||||
logger = logging.getLogger("matching")
|
logger = logging.getLogger("matching")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
@ -69,33 +67,42 @@ def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bo
|
||||||
|
|
||||||
def get_member_group_eligibility_score(member: Member,
|
def get_member_group_eligibility_score(member: Member,
|
||||||
group: list[Member],
|
group: list[Member],
|
||||||
relevant_matches: list[int],
|
prior_matches: list[int],
|
||||||
per_group: int) -> int:
|
per_group: int) -> float:
|
||||||
"""Rates a member against a group"""
|
"""Rates a member against a group"""
|
||||||
rating = len(group) * _SCORE_CURRENT_MEMBERS
|
# An empty group is a "perfect" score atomatically
|
||||||
|
rating = 0
|
||||||
|
if not group:
|
||||||
|
return rating
|
||||||
|
|
||||||
repeat_meetings = sum(m.id in relevant_matches for m in group)
|
# Add score based on prior matchups of this user
|
||||||
rating += repeat_meetings * _SCORE_REPEAT_MATCH
|
rating += sum(m.id in prior_matches for m in group) * \
|
||||||
|
_ScoreFactors.REPEAT_MATCH
|
||||||
|
|
||||||
repeat_roles = sum(r in member.roles for r in (m.roles for m in group))
|
# Calculate the number of roles that match
|
||||||
rating += (repeat_roles * _SCORE_REPEAT_ROLE)
|
all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr)
|
||||||
|
member_role_ids = [r.id for r in member.roles]
|
||||||
|
repeat_roles = sum(id in member_role_ids for id in all_role_ids)
|
||||||
|
rating += repeat_roles * _ScoreFactors.REPEAT_ROLE
|
||||||
|
|
||||||
extra_members = len(group) - per_group
|
# Add score based on the number of extra members
|
||||||
if extra_members > 0:
|
# Calculate the member offset (+1 for this user)
|
||||||
rating += extra_members * _SCORE_EXTRA_MEMBERS
|
extra_members = (len(group) - per_group) + 1
|
||||||
|
if extra_members >= 0:
|
||||||
|
rating += extra_members * _ScoreFactors.EXTRA_MEMBER
|
||||||
|
|
||||||
return rating
|
return rating
|
||||||
|
|
||||||
|
|
||||||
def attempt_create_groups(matchees: list[Member],
|
def attempt_create_groups(matchees: list[Member],
|
||||||
hist: state.State,
|
current_state: state.State,
|
||||||
oldest_relevant_ts: datetime,
|
oldest_relevant_ts: datetime,
|
||||||
per_group: int) -> tuple[bool, list[list[Member]]]:
|
per_group: int) -> tuple[bool, list[list[Member]]]:
|
||||||
"""History aware group matching"""
|
"""History aware group matching"""
|
||||||
num_groups = max(len(matchees)//per_group, 1)
|
num_groups = max(len(matchees)//per_group, 1)
|
||||||
|
|
||||||
# Set up the groups in place
|
# Set up the groups in place
|
||||||
groups = list([] for _ in range(num_groups))
|
groups = [[] for _ in range(num_groups)]
|
||||||
|
|
||||||
matchees_left = matchees.copy()
|
matchees_left = matchees.copy()
|
||||||
|
|
||||||
|
@ -103,21 +110,21 @@ def attempt_create_groups(matchees: list[Member],
|
||||||
while matchees_left:
|
while matchees_left:
|
||||||
# Get the next matchee to place
|
# Get the next matchee to place
|
||||||
matchee = matchees_left.pop()
|
matchee = matchees_left.pop()
|
||||||
matchee_matches = hist.matchees.get(
|
matchee_matches = current_state.get_user_matches(matchee.id)
|
||||||
str(matchee.id), {}).get("matches", {})
|
relevant_matches = [int(id) for id, ts
|
||||||
relevant_matches = list(int(id) for id, ts in matchee_matches.items()
|
in matchee_matches.items()
|
||||||
if state.ts_to_datetime(ts) >= oldest_relevant_ts)
|
if state.ts_to_datetime(ts) >= oldest_relevant_ts]
|
||||||
|
|
||||||
# Try every single group from the current group onwards
|
# Try every single group from the current group onwards
|
||||||
# Progressing through the groups like this ensures we slowly fill them up with compatible people
|
# Progressing through the groups like this ensures we slowly fill them up with compatible people
|
||||||
scores: list[tuple[int, int]] = []
|
scores: list[tuple[int, float]] = []
|
||||||
for group in groups:
|
for group in groups:
|
||||||
|
|
||||||
score = get_member_group_eligibility_score(
|
score = get_member_group_eligibility_score(
|
||||||
matchee, group, relevant_matches, num_groups)
|
matchee, group, relevant_matches, per_group)
|
||||||
|
|
||||||
# If the score isn't too high, consider this group
|
# If the score isn't too high, consider this group
|
||||||
if score <= _SCORE_UPPER_THRESHOLD:
|
if score <= _ScoreFactors.UPPER_THRESHOLD:
|
||||||
scores.append((group, score))
|
scores.append((group, score))
|
||||||
|
|
||||||
# Optimisation:
|
# Optimisation:
|
||||||
|
@ -143,31 +150,41 @@ def datetime_range(start_time: datetime, increment: timedelta, end: datetime):
|
||||||
current += increment
|
current += increment
|
||||||
|
|
||||||
|
|
||||||
|
def iterate_all_shifts(list: list):
|
||||||
|
"""Yields each shifted variation of the input list"""
|
||||||
|
yield list
|
||||||
|
for _ in range(len(list)-1):
|
||||||
|
list = list[1:] + [list[0]]
|
||||||
|
yield list
|
||||||
|
|
||||||
|
|
||||||
def members_to_groups(matchees: list[Member],
|
def members_to_groups(matchees: list[Member],
|
||||||
hist: state.State = state.State(),
|
hist: state.State = state.State(),
|
||||||
per_group: int = 3,
|
per_group: int = 3,
|
||||||
allow_fallback: bool = False) -> list[list[Member]]:
|
allow_fallback: bool = False) -> list[list[Member]]:
|
||||||
"""Generate the groups from the set of matchees"""
|
"""Generate the groups from the set of matchees"""
|
||||||
attempts = 0 # Tracking for logging purposes
|
attempts = 0 # Tracking for logging purposes
|
||||||
rand = random.Random(117) # Some stable randomness
|
num_groups = len(matchees)//per_group
|
||||||
|
|
||||||
|
# Bail early if there's no-one to match
|
||||||
|
if not matchees:
|
||||||
|
return []
|
||||||
|
|
||||||
# Grab the oldest timestamp
|
# Grab the oldest timestamp
|
||||||
history_start = hist.oldest_history() or datetime.now()
|
history_start = hist.get_oldest_timestamp() or datetime.now()
|
||||||
|
|
||||||
# Walk from the start of time until now using the timestep increment
|
# Walk from the start of time until now using the timestep increment
|
||||||
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
|
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
|
||||||
|
|
||||||
# Have a few attempts before stepping forward in time
|
# Attempt with each starting matchee
|
||||||
for _ in range(_ATTEMPTS_PER_TIMESTEP):
|
for shifted_matchees in iterate_all_shifts(matchees):
|
||||||
|
|
||||||
rand.shuffle(matchees) # Shuffle the matchees each attempt
|
|
||||||
|
|
||||||
attempts += 1
|
attempts += 1
|
||||||
groups = attempt_create_groups(
|
groups = attempt_create_groups(
|
||||||
matchees, hist, oldest_relevant_datetime, per_group)
|
shifted_matchees, hist, oldest_relevant_datetime, per_group)
|
||||||
|
|
||||||
# Fail the match if our groups aren't big enough
|
# Fail the match if our groups aren't big enough
|
||||||
if (len(matchees)//per_group) <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
||||||
logger.info("Matched groups after %s attempt(s)", attempts)
|
logger.info("Matched groups after %s attempt(s)", attempts)
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
|
@ -176,6 +193,10 @@ def members_to_groups(matchees: list[Member],
|
||||||
logger.info("Fell back to simple groups after %s attempt(s)", attempts)
|
logger.info("Fell back to simple groups after %s attempt(s)", attempts)
|
||||||
return members_to_groups_simple(matchees, per_group)
|
return members_to_groups_simple(matchees, per_group)
|
||||||
|
|
||||||
|
# Simply assert false, this should never happen
|
||||||
|
# And should be caught by tests
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
def group_to_message(group: list[Member]) -> str:
|
def group_to_message(group: list[Member]) -> str:
|
||||||
"""Get the message to send for each group"""
|
"""Get the message to send for each group"""
|
||||||
|
@ -185,8 +206,3 @@ def group_to_message(group: list[Member]) -> str:
|
||||||
else:
|
else:
|
||||||
mentions = mentions[0]
|
mentions = mentions[0]
|
||||||
return f"Matched up {mentions}!"
|
return f"Matched up {mentions}!"
|
||||||
|
|
||||||
|
|
||||||
def get_role_from_guild(guild: Guild, role: str) -> Role:
|
|
||||||
"""Find a role in a guild"""
|
|
||||||
return next((r for r in guild.roles if r.name == role), None)
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Test functions for Matchy
|
Test functions for the matching module
|
||||||
"""
|
"""
|
||||||
import discord
|
import discord
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -30,6 +30,7 @@ class Role():
|
||||||
class Member():
|
class Member():
|
||||||
def __init__(self, id: int, roles: list[Role] = []):
|
def __init__(self, id: int, roles: list[Role] = []):
|
||||||
self._id = id
|
self._id = id
|
||||||
|
self._roles = roles
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mention(self) -> str:
|
def mention(self) -> str:
|
||||||
|
@ -37,7 +38,7 @@ class Member():
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def roles(self) -> list[Role]:
|
def roles(self) -> list[Role]:
|
||||||
return []
|
return self._roles
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> int:
|
def id(self) -> int:
|
||||||
|
@ -153,7 +154,59 @@ def items_found_in_lists(list_of_lists, items):
|
||||||
# Nothing specific to validate
|
# Nothing specific to validate
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
], ids=['simple_history', 'fallback'])
|
# Specific test pulled out of the stress test
|
||||||
|
(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"ts": datetime.now() - timedelta(days=4),
|
||||||
|
"groups": [
|
||||||
|
[Member(i) for i in [1, 2, 3, 4, 5, 6,
|
||||||
|
7, 8, 9, 10, 11, 12, 13, 14, 15]]
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ts": datetime.now() - timedelta(days=5),
|
||||||
|
"groups": [
|
||||||
|
[Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]],
|
||||||
|
3,
|
||||||
|
[
|
||||||
|
# Nothing specific to validate
|
||||||
|
]
|
||||||
|
),
|
||||||
|
# Silly example that failued due to bad role logic
|
||||||
|
(
|
||||||
|
[
|
||||||
|
# No history
|
||||||
|
],
|
||||||
|
[
|
||||||
|
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
|
||||||
|
Member(i, [Role(r) for r in roles]) for (i, roles) in
|
||||||
|
[
|
||||||
|
(4, [1, 2, 3, 4, 5, 6, 7, 8]),
|
||||||
|
(8, [1]),
|
||||||
|
(9, [1, 2, 3, 4, 5]),
|
||||||
|
(6, [1, 2, 3]),
|
||||||
|
(11, [1, 2, 3]),
|
||||||
|
(7, [1, 2, 3, 4, 5, 6, 7]),
|
||||||
|
(1, [1, 2, 3, 4]),
|
||||||
|
(5, [1, 2, 3, 4, 5]),
|
||||||
|
(12, [1, 2, 3, 4]),
|
||||||
|
(10, [1]),
|
||||||
|
(13, [1, 2, 3, 4, 5, 6]),
|
||||||
|
(2, [1, 2, 3, 4, 5, 6]),
|
||||||
|
(3, [1, 2, 3, 4, 5, 6, 7])
|
||||||
|
]
|
||||||
|
],
|
||||||
|
2,
|
||||||
|
[
|
||||||
|
# Nothing else
|
||||||
|
]
|
||||||
|
)
|
||||||
|
], ids=['simple_history', 'fallback', 'example_1', 'example_2'])
|
||||||
def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
|
def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
|
||||||
"""Test more advanced group matching works"""
|
"""Test more advanced group matching works"""
|
||||||
tmp_state = state.State()
|
tmp_state = state.State()
|
||||||
|
@ -180,8 +233,8 @@ def test_members_to_groups_stress_test():
|
||||||
|
|
||||||
# Slowly ramp a randomized shuffled list of members with randomised roles
|
# Slowly ramp a randomized shuffled list of members with randomised roles
|
||||||
for num_members in range(1, 5):
|
for num_members in range(1, 5):
|
||||||
matchees = list(Member(i, list(Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))))
|
matchees = [Member(i, [Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))])
|
||||||
for i in range(1, rand.randint(2, num_members*10 + 1)))
|
for i in range(1, rand.randint(2, num_members*10 + 1))]
|
||||||
rand.shuffle(matchees)
|
rand.shuffle(matchees)
|
||||||
|
|
||||||
for num_history in range(8):
|
for num_history in range(8):
|
||||||
|
@ -190,14 +243,14 @@ def test_members_to_groups_stress_test():
|
||||||
# Start some time from now to the past
|
# Start some time from now to the past
|
||||||
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
|
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
|
||||||
history_data = []
|
history_data = []
|
||||||
for x in range(0, num_history):
|
for _ in range(0, num_history):
|
||||||
run = {
|
run = {
|
||||||
"ts": time
|
"ts": time
|
||||||
}
|
}
|
||||||
groups = []
|
groups = []
|
||||||
for y in range(1, num_history):
|
for y in range(1, num_history):
|
||||||
groups.append(list(Member(i)
|
groups.append([Member(i)
|
||||||
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))))
|
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))])
|
||||||
run["groups"] = groups
|
run["groups"] = groups
|
||||||
history_data.append(run)
|
history_data.append(run)
|
||||||
|
|
||||||
|
@ -212,4 +265,32 @@ def test_members_to_groups_stress_test():
|
||||||
for d in history_data:
|
for d in history_data:
|
||||||
tmp_state.log_groups(d["groups"], d["ts"])
|
tmp_state.log_groups(d["groups"], d["ts"])
|
||||||
|
|
||||||
inner_validate_members_to_groups(matchees, tmp_state, per_group)
|
inner_validate_members_to_groups(
|
||||||
|
matchees, tmp_state, per_group)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_scopes():
|
||||||
|
tmp_state = state.State()
|
||||||
|
|
||||||
|
id = "1"
|
||||||
|
tmp_state.set_user_scope(id, state.AuthScope.OWNER)
|
||||||
|
assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
|
||||||
|
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||||
|
|
||||||
|
id = "2"
|
||||||
|
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
|
||||||
|
assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
|
||||||
|
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||||
|
|
||||||
|
tmp_state.validate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_iterate_all_shifts():
|
||||||
|
original = [1, 2, 3, 4]
|
||||||
|
lists = [val for val in matching.iterate_all_shifts(original)]
|
||||||
|
assert lists == [
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[2, 3, 4, 1],
|
||||||
|
[3, 4, 1, 2],
|
||||||
|
[4, 1, 2, 3],
|
||||||
|
]
|
||||||
|
|
130
matchy.py
130
matchy.py
|
@ -11,8 +11,11 @@ import config
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
Config = config.load()
|
STATE_FILE = "state.json"
|
||||||
State = state.load()
|
CONFIG_FILE = "config.json"
|
||||||
|
|
||||||
|
Config = config.load_from_file(CONFIG_FILE)
|
||||||
|
State = state.load_from_file(STATE_FILE)
|
||||||
|
|
||||||
logger = logging.getLogger("matchy")
|
logger = logging.getLogger("matchy")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
@ -39,7 +42,7 @@ async def on_ready():
|
||||||
|
|
||||||
def owner_only(ctx: commands.Context) -> bool:
|
def owner_only(ctx: commands.Context) -> bool:
|
||||||
"""Checks the author is an owner"""
|
"""Checks the author is an owner"""
|
||||||
return ctx.message.author.id in Config.owners
|
return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER)
|
||||||
|
|
||||||
|
|
||||||
@bot.command()
|
@bot.command()
|
||||||
|
@ -47,9 +50,10 @@ def owner_only(ctx: commands.Context) -> bool:
|
||||||
@commands.check(owner_only)
|
@commands.check(owner_only)
|
||||||
async def sync(ctx: commands.Context):
|
async def sync(ctx: commands.Context):
|
||||||
"""Handle sync command"""
|
"""Handle sync command"""
|
||||||
msg = await ctx.reply("Reloading config...", ephemeral=True)
|
msg = await ctx.reply("Reloading state...", ephemeral=True)
|
||||||
Config.reload()
|
global State
|
||||||
logger.info("Reloaded config")
|
State = state.load_from_file(STATE_FILE)
|
||||||
|
logger.info("Reloaded state")
|
||||||
|
|
||||||
await msg.edit(content="Syncing commands...")
|
await msg.edit(content="Syncing commands...")
|
||||||
synced = await bot.tree.sync()
|
synced = await bot.tree.sync()
|
||||||
|
@ -68,96 +72,112 @@ async def close(ctx: commands.Context):
|
||||||
await bot.close()
|
await bot.close()
|
||||||
|
|
||||||
|
|
||||||
|
@bot.tree.command(description="Join the matchees for this channel")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def join(interaction: discord.Interaction):
|
||||||
|
State.set_use_active_in_channel(
|
||||||
|
interaction.user.id, interaction.channel.id)
|
||||||
|
state.save_to_file(State, STATE_FILE)
|
||||||
|
await interaction.response.send_message(
|
||||||
|
f"Roger roger {interaction.user.mention}!\n"
|
||||||
|
+ f"Added you to {interaction.channel.mention}!",
|
||||||
|
ephemeral=True, silent=True)
|
||||||
|
|
||||||
|
|
||||||
|
@bot.tree.command(description="Leave the matchees for this channel")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def leave(interaction: discord.Interaction):
|
||||||
|
State.set_use_active_in_channel(
|
||||||
|
interaction.user.id, interaction.channel.id, False)
|
||||||
|
state.save_to_file(State, STATE_FILE)
|
||||||
|
await interaction.response.send_message(
|
||||||
|
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
||||||
|
|
||||||
|
|
||||||
|
@bot.tree.command(description="List the matchees for this channel")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def list(interaction: discord.Interaction):
|
||||||
|
matchees = get_matchees_in_channel(interaction.channel)
|
||||||
|
mentions = [m.mention for m in matchees]
|
||||||
|
msg = "Current matchees in this channel:\n" + \
|
||||||
|
f"{', '.join(mentions[:-1])} and {mentions[-1]}"
|
||||||
|
await interaction.response.send_message(msg, ephemeral=True, silent=True)
|
||||||
|
|
||||||
|
|
||||||
@bot.tree.command(description="Match up matchees")
|
@bot.tree.command(description="Match up matchees")
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)",
|
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)")
|
||||||
matchee_role="Role for matchees (defaults to @Matchee)")
|
async def match(interaction: discord.Interaction, members_min: int = None):
|
||||||
async def match(interaction: discord.Interaction, members_min: int = None, matchee_role: str = None):
|
|
||||||
"""Match groups of channel members"""
|
"""Match groups of channel members"""
|
||||||
|
|
||||||
logger.info("Handling request '/match group_min=%s matchee_role=%s'",
|
logger.info("Handling request '/match group_min=%s", members_min)
|
||||||
members_min, matchee_role)
|
|
||||||
logger.info("User %s from %s in #%s", interaction.user,
|
logger.info("User %s from %s in #%s", interaction.user,
|
||||||
interaction.guild.name, interaction.channel.name)
|
interaction.guild.name, interaction.channel.name)
|
||||||
|
|
||||||
# Sort out the defaults, if not specified they'll come in as None
|
# Sort out the defaults, if not specified they'll come in as None
|
||||||
if not members_min:
|
if not members_min:
|
||||||
members_min = 3
|
members_min = 3
|
||||||
if not matchee_role:
|
|
||||||
matchee_role = "Matchee"
|
|
||||||
|
|
||||||
# Grab the roles and verify the given role
|
# Grab the groups
|
||||||
matcher = matching.get_role_from_guild(interaction.guild, "Matcher")
|
groups = active_members_to_groups(interaction.channel, members_min)
|
||||||
matcher = matcher and matcher in interaction.user.roles
|
|
||||||
matchee = matching.get_role_from_guild(interaction.guild, matchee_role)
|
# Let the user know when there's nobody to match
|
||||||
if not matchee:
|
if not groups:
|
||||||
await interaction.response.send_message(f"Server is missing '{matchee_role}' role :(", ephemeral=True)
|
await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create some example groups to show the user
|
|
||||||
matchees = list(
|
|
||||||
m for m in interaction.channel.members if matchee in m.roles)
|
|
||||||
groups = matching.members_to_groups(
|
|
||||||
matchees, State, members_min, allow_fallback=True)
|
|
||||||
|
|
||||||
# Post about all the groups with a button to send to the channel
|
# Post about all the groups with a button to send to the channel
|
||||||
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
|
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
|
||||||
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
|
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
|
||||||
view = discord.utils.MISSING
|
view = discord.utils.MISSING
|
||||||
|
|
||||||
if not matcher:
|
if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER):
|
||||||
# Let a non-matcher know why they don't have the button
|
# Let a non-matcher know why they don't have the button
|
||||||
msg += "\n\nYou'll need the 'Matcher' role to post this to the channel, sorry!"
|
msg += f"\n\nYou'll need the {state.AuthScope.MATCHER} scope to post this to the channel, sorry!"
|
||||||
else:
|
else:
|
||||||
# Otherwise set up the button
|
# Otherwise set up the button
|
||||||
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
|
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
|
||||||
view = discord.ui.View(timeout=None)
|
view = discord.ui.View(timeout=None)
|
||||||
view.add_item(DynamicGroupButton(members_min, matchee_role))
|
view.add_item(DynamicGroupButton(members_min))
|
||||||
|
|
||||||
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
|
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
|
||||||
|
|
||||||
logger.info("Done.")
|
logger.info("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
# Increment when adjusting the custom_id so we don't confuse old users
|
||||||
|
_BUTTON_CUSTOM_ID_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
||||||
template=r'match:min:(?P<min>[0-9]+):role:(?P<role>[@\w\s]+)'):
|
template=f'match:v{_BUTTON_CUSTOM_ID_VERSION}:' + r'min:(?P<min>[0-9]+)'):
|
||||||
def __init__(self, min: int, role: str) -> None:
|
def __init__(self, min: int) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
discord.ui.Button(
|
discord.ui.Button(
|
||||||
label='Match Groups!',
|
label='Match Groups!',
|
||||||
style=discord.ButtonStyle.blurple,
|
style=discord.ButtonStyle.blurple,
|
||||||
custom_id=f'match:min:{min}:role:{role}',
|
custom_id=f'match:min:{min}',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.min: int = min
|
self.min: int = min
|
||||||
self.role: int = role
|
|
||||||
|
|
||||||
# This is called when the button is clicked and the custom_id matches the template.
|
# This is called when the button is clicked and the custom_id matches the template.
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
|
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
|
||||||
min = int(match['min'])
|
min = int(match['min'])
|
||||||
role = str(match['role'])
|
return cls(min)
|
||||||
return cls(min, role)
|
|
||||||
|
|
||||||
async def callback(self, interaction: discord.Interaction) -> None:
|
async def callback(self, interaction: discord.Interaction) -> None:
|
||||||
"""Match up people when the button is pressed"""
|
"""Match up people when the button is pressed"""
|
||||||
|
|
||||||
logger.info("Handling button press min=%s role=%s'",
|
logger.info("Handling button press min=%s", self.min)
|
||||||
self.min, self.role)
|
|
||||||
logger.info("User %s from %s in #%s", interaction.user,
|
logger.info("User %s from %s in #%s", interaction.user,
|
||||||
interaction.guild.name, interaction.channel.name)
|
interaction.guild.name, interaction.channel.name)
|
||||||
|
|
||||||
# Let the user know we've recieved the message
|
# Let the user know we've recieved the message
|
||||||
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
|
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
|
||||||
|
|
||||||
# Grab the role
|
groups = active_members_to_groups(interaction.channel, self.min)
|
||||||
matchee = matching.get_role_from_guild(interaction.guild, self.role)
|
|
||||||
|
|
||||||
# Create our groups!
|
|
||||||
matchees = list(
|
|
||||||
m for m in interaction.channel.members if matchee in m.roles)
|
|
||||||
groups = matching.members_to_groups(
|
|
||||||
matchees, State, self.min, allow_fallback=True)
|
|
||||||
|
|
||||||
# Send the groups
|
# Send the groups
|
||||||
for msg in (matching.group_to_message(g) for g in groups):
|
for msg in (matching.group_to_message(g) for g in groups):
|
||||||
|
@ -167,10 +187,26 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
||||||
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
|
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
|
||||||
|
|
||||||
# Save the groups to the history
|
# Save the groups to the history
|
||||||
State.save_groups(groups)
|
State.log_groups(groups)
|
||||||
|
state.save_to_file(State, STATE_FILE)
|
||||||
|
|
||||||
logger.info("Done. Matched %s matchees into %s groups.",
|
logger.info("Done! Matched into %s groups.", len(groups))
|
||||||
len(matchees), len(groups))
|
|
||||||
|
|
||||||
|
def get_matchees_in_channel(channel: discord.channel):
|
||||||
|
"""Fetches the matchees in a channel"""
|
||||||
|
# Gather up the prospective matchees
|
||||||
|
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
|
||||||
|
|
||||||
|
|
||||||
|
def active_members_to_groups(channel: discord.channel, min_members: int):
|
||||||
|
"""Helper to create groups from channel members"""
|
||||||
|
|
||||||
|
# Gather up the prospective matchees
|
||||||
|
matchees = get_matchees_in_channel(channel)
|
||||||
|
|
||||||
|
# Create our groups!
|
||||||
|
return matching.members_to_groups(matchees, State, min_members, allow_fallback=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
230
state.py
230
state.py
|
@ -5,22 +5,65 @@ from schema import Schema, And, Use, Optional
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
import files
|
import files
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("state")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
_FILE = "state.json"
|
|
||||||
|
|
||||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||||
_DEFAULT_DICT = {
|
_VERSION = 1
|
||||||
"history": {},
|
|
||||||
"matchees": {}
|
|
||||||
}
|
def _migrate_to_v1(d: dict):
|
||||||
|
logger.info("Renaming %s to %s", _Key.MATCHEES, _Key.USERS)
|
||||||
|
d[_Key.USERS] = d[_Key.MATCHEES]
|
||||||
|
del d[_Key.MATCHEES]
|
||||||
|
|
||||||
|
|
||||||
|
# Set of migration functions to apply
|
||||||
|
_MIGRATIONS = [
|
||||||
|
_migrate_to_v1
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AuthScope(str):
|
||||||
|
"""Various auth scopes"""
|
||||||
|
OWNER = "owner"
|
||||||
|
MATCHER = "matcher"
|
||||||
|
|
||||||
|
|
||||||
|
class _Key(str):
|
||||||
|
"""Various keys used in the schema"""
|
||||||
|
HISTORY = "history"
|
||||||
|
GROUPS = "groups"
|
||||||
|
MEMBERS = "members"
|
||||||
|
USERS = "users"
|
||||||
|
SCOPES = "scopes"
|
||||||
|
MATCHES = "matches"
|
||||||
|
ACTIVE = "active"
|
||||||
|
CHANNELS = "channels"
|
||||||
|
REACTIVATE = "reactivate"
|
||||||
|
VERSION = "version"
|
||||||
|
|
||||||
|
# Unused
|
||||||
|
MATCHEES = "matchees"
|
||||||
|
|
||||||
|
|
||||||
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
|
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
|
||||||
|
|
||||||
|
|
||||||
_SCHEMA = Schema(
|
_SCHEMA = Schema(
|
||||||
{
|
{
|
||||||
Optional("history"): {
|
# The current version
|
||||||
Optional(str): { # a datetime
|
_Key.VERSION: And(Use(int)),
|
||||||
"groups": [
|
|
||||||
|
Optional(_Key.HISTORY): {
|
||||||
|
# A datetime
|
||||||
|
Optional(str): {
|
||||||
|
_Key.GROUPS: [
|
||||||
{
|
{
|
||||||
"members": [
|
_Key.MEMBERS: [
|
||||||
# The ID of each matchee in the match
|
# The ID of each matchee in the match
|
||||||
And(Use(int))
|
And(Use(int))
|
||||||
]
|
]
|
||||||
|
@ -28,17 +71,33 @@ _SCHEMA = Schema(
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Optional("matchees"): {
|
Optional(_Key.USERS): {
|
||||||
Optional(str): {
|
Optional(str): {
|
||||||
Optional("matches"): {
|
Optional(_Key.SCOPES): And(Use(list[str])),
|
||||||
|
Optional(_Key.MATCHES): {
|
||||||
# Matchee ID and Datetime pair
|
# Matchee ID and Datetime pair
|
||||||
Optional(str): And(Use(str))
|
Optional(str): And(Use(str))
|
||||||
|
},
|
||||||
|
Optional(_Key.CHANNELS): {
|
||||||
|
# The channel ID
|
||||||
|
Optional(str): {
|
||||||
|
# Whether the user is signed up in this channel
|
||||||
|
_Key.ACTIVE: And(Use(bool)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Empty but schema-valid internal dict
|
||||||
|
_EMPTY_DICT = {
|
||||||
|
_Key.HISTORY: {},
|
||||||
|
_Key.USERS: {},
|
||||||
|
_Key.VERSION: _VERSION
|
||||||
|
}
|
||||||
|
assert _SCHEMA.validate(_EMPTY_DICT)
|
||||||
|
|
||||||
|
|
||||||
class Member(Protocol):
|
class Member(Protocol):
|
||||||
@property
|
@property
|
||||||
|
@ -51,75 +110,148 @@ def ts_to_datetime(ts: str) -> datetime:
|
||||||
return datetime.strptime(ts, _TIME_FORMAT)
|
return datetime.strptime(ts, _TIME_FORMAT)
|
||||||
|
|
||||||
|
|
||||||
def validate(dict: dict):
|
|
||||||
"""Initialise and validate the state"""
|
|
||||||
_SCHEMA.validate(dict)
|
|
||||||
|
|
||||||
|
|
||||||
class State():
|
class State():
|
||||||
def __init__(self, data: dict = _DEFAULT_DICT):
|
def __init__(self, data: dict = _EMPTY_DICT):
|
||||||
"""Initialise and validate the state"""
|
"""Initialise and validate the state"""
|
||||||
validate(data)
|
self.validate(data)
|
||||||
self.__dict__ = copy.deepcopy(data)
|
self._dict = copy.deepcopy(data)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def history(self) -> list[dict]:
|
def _history(self) -> dict[str]:
|
||||||
return self.__dict__["history"]
|
return self._dict[_Key.HISTORY]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def matchees(self) -> dict[str, dict]:
|
def _users(self) -> dict[str]:
|
||||||
return self.__dict__["matchees"]
|
return self._dict[_Key.USERS]
|
||||||
|
|
||||||
def save(self) -> None:
|
def validate(self, dict: dict = None):
|
||||||
"""Save out the state"""
|
"""Initialise and validate a state dict"""
|
||||||
files.save(_FILE, self.__dict__)
|
if not dict:
|
||||||
|
dict = self._dict
|
||||||
|
_SCHEMA.validate(dict)
|
||||||
|
|
||||||
def oldest_history(self) -> datetime:
|
def get_oldest_timestamp(self) -> datetime:
|
||||||
"""Grab the oldest timestamp in history"""
|
"""Grab the oldest timestamp in history"""
|
||||||
if not self.history:
|
times = (ts_to_datetime(dt) for dt in self._history.keys())
|
||||||
return None
|
return next(times, None)
|
||||||
times = (ts_to_datetime(dt) for dt in self.history.keys())
|
|
||||||
return sorted(times)[0]
|
def get_user_matches(self, id: int) -> list[int]:
|
||||||
|
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
|
||||||
|
|
||||||
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
|
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
|
||||||
"""Log the groups"""
|
"""Log the groups"""
|
||||||
tmp_state = State(self.__dict__)
|
tmp_state = State(self._dict)
|
||||||
ts = datetime.strftime(ts, _TIME_FORMAT)
|
ts = datetime.strftime(ts, _TIME_FORMAT)
|
||||||
|
|
||||||
# Grab or create the hitory item for this set of groups
|
# Grab or create the hitory item for this set of groups
|
||||||
history_item = {}
|
history_item = {}
|
||||||
tmp_state.history[ts] = history_item
|
tmp_state._history[ts] = history_item
|
||||||
history_item_groups = []
|
history_item_groups = []
|
||||||
history_item["groups"] = history_item_groups
|
history_item[_Key.GROUPS] = history_item_groups
|
||||||
|
|
||||||
for group in groups:
|
for group in groups:
|
||||||
|
|
||||||
# Add the group data
|
# Add the group data
|
||||||
history_item_groups.append({
|
history_item_groups.append({
|
||||||
"members": list(m.id for m in group)
|
_Key.MEMBERS: [m.id for m in group]
|
||||||
})
|
})
|
||||||
|
|
||||||
# Update the matchee data with the matches
|
# Update the matchee data with the matches
|
||||||
for m in group:
|
for m in group:
|
||||||
matchee = tmp_state.matchees.get(str(m.id), {})
|
matchee = tmp_state._users.get(str(m.id), {})
|
||||||
matchee_matches = matchee.get("matches", {})
|
matchee_matches = matchee.get(_Key.MATCHES, {})
|
||||||
|
|
||||||
for o in (o for o in group if o.id != m.id):
|
for o in (o for o in group if o.id != m.id):
|
||||||
matchee_matches[str(o.id)] = ts
|
matchee_matches[str(o.id)] = ts
|
||||||
|
|
||||||
matchee["matches"] = matchee_matches
|
matchee[_Key.MATCHES] = matchee_matches
|
||||||
tmp_state.matchees[str(m.id)] = matchee
|
tmp_state._users[str(m.id)] = matchee
|
||||||
|
|
||||||
# Validate before storing the result
|
# Validate before storing the result
|
||||||
validate(self.__dict__)
|
tmp_state.validate()
|
||||||
self.__dict__ = tmp_state.__dict__
|
self._dict = tmp_state._dict
|
||||||
|
|
||||||
def save_groups(self, groups: list[list[Member]]) -> None:
|
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||||
"""Save out the groups to the state file"""
|
"""Add an auth scope to a user"""
|
||||||
self.log_groups(groups)
|
# Dive in
|
||||||
self.save()
|
user = self._users.get(str(id), {})
|
||||||
|
scopes = user.get(_Key.SCOPES, [])
|
||||||
|
|
||||||
|
# Set the value
|
||||||
|
if value and scope not in scopes:
|
||||||
|
scopes.append(scope)
|
||||||
|
elif not value and scope in scopes:
|
||||||
|
scopes.remove(scope)
|
||||||
|
|
||||||
|
# Roll out
|
||||||
|
user[_Key.SCOPES] = scopes
|
||||||
|
self._users[id] = user
|
||||||
|
|
||||||
|
def get_user_has_scope(self, id: str, scope: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a user has an auth scope
|
||||||
|
"owner" users have all scopes
|
||||||
|
"""
|
||||||
|
user = self._users.get(str(id), {})
|
||||||
|
scopes = user.get(_Key.SCOPES, [])
|
||||||
|
return AuthScope.OWNER in scopes or scope in scopes
|
||||||
|
|
||||||
|
def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True):
|
||||||
|
"""Set a user as active (or not) on a given channel"""
|
||||||
|
# Dive in
|
||||||
|
user = self._users.get(str(id), {})
|
||||||
|
channels = user.get(_Key.CHANNELS, {})
|
||||||
|
channel = channels.get(str(channel_id), {})
|
||||||
|
|
||||||
|
# Set the value
|
||||||
|
channel[_Key.ACTIVE] = active
|
||||||
|
|
||||||
|
# Unroll
|
||||||
|
channels[str(channel_id)] = channel
|
||||||
|
user[_Key.CHANNELS] = channels
|
||||||
|
self._users[str(id)] = user
|
||||||
|
|
||||||
|
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
|
||||||
|
"""Get a users active channels"""
|
||||||
|
user = self._users.get(str(id), {})
|
||||||
|
channels = user.get(_Key.CHANNELS, {})
|
||||||
|
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dict_internal(self) -> dict:
|
||||||
|
"""Only to be used to get the internal dict as a copy"""
|
||||||
|
return copy.deepcopy(self._dict)
|
||||||
|
|
||||||
|
|
||||||
def load() -> State:
|
def _migrate(dict: dict):
|
||||||
"""Load the state"""
|
"""Migrate a dict through versions"""
|
||||||
return State(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT)
|
version = dict.get("version", 0)
|
||||||
|
for i in range(version, _VERSION):
|
||||||
|
logger.info("Migrating from v%s to v%s", version, version+1)
|
||||||
|
_MIGRATIONS[i](dict)
|
||||||
|
dict[_Key.VERSION] = _VERSION
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_file(file: str) -> State:
|
||||||
|
"""
|
||||||
|
Load the state from a file
|
||||||
|
Apply any required migrations
|
||||||
|
"""
|
||||||
|
loaded = _EMPTY_DICT
|
||||||
|
|
||||||
|
# If there's a file load it and try to migrate
|
||||||
|
if os.path.isfile(file):
|
||||||
|
loaded = files.load(file)
|
||||||
|
_migrate(loaded)
|
||||||
|
|
||||||
|
st = State(loaded)
|
||||||
|
|
||||||
|
# Save out the migrated (or new) file
|
||||||
|
files.save(file, st._dict)
|
||||||
|
|
||||||
|
return st
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_file(state: State, file: str):
|
||||||
|
"""Saves the state out to a file"""
|
||||||
|
files.save(file, state.dict_internal)
|
||||||
|
|
65
state_test.py
Normal file
65
state_test.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
"""
|
||||||
|
Test functions for the state module
|
||||||
|
"""
|
||||||
|
import state
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_state():
|
||||||
|
"""Simple validate basic state load"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
path = os.path.join(tmp, 'tmp.json')
|
||||||
|
state.load_from_file(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_load_reload():
|
||||||
|
"""Test a basic load, save, reload"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
path = os.path.join(tmp, 'tmp.json')
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
state.save_to_file(st, path)
|
||||||
|
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
state.save_to_file(st, path)
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_authscope():
|
||||||
|
"""Test setting and getting an auth scope"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
path = os.path.join(tmp, 'tmp.json')
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
state.save_to_file(st, path)
|
||||||
|
|
||||||
|
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||||
|
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
st.set_user_scope(1, state.AuthScope.MATCHER)
|
||||||
|
state.save_to_file(st, path)
|
||||||
|
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||||
|
|
||||||
|
st.set_user_scope(1, state.AuthScope.MATCHER, False)
|
||||||
|
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||||
|
|
||||||
|
|
||||||
|
def test_channeljoin():
|
||||||
|
"""Test setting and getting an active channel"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
path = os.path.join(tmp, 'tmp.json')
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
state.save_to_file(st, path)
|
||||||
|
|
||||||
|
assert not st.get_user_active_in_channel(1, "2")
|
||||||
|
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
st.set_use_active_in_channel(1, "2", True)
|
||||||
|
state.save_to_file(st, path)
|
||||||
|
|
||||||
|
st = state.load_from_file(path)
|
||||||
|
assert st.get_user_active_in_channel(1, "2")
|
||||||
|
|
||||||
|
st.set_use_active_in_channel(1, "2", False)
|
||||||
|
assert not st.get_user_active_in_channel(1, "2")
|
Loading…
Add table
Reference in a new issue