Move python files into py dir

This commit is contained in:
Marc Di Luzio 2024-08-11 18:05:28 +01:00
parent 22ad36fb09
commit 129721eb50
9 changed files with 2 additions and 2 deletions

77
py/config.py Normal file
View file

@ -0,0 +1,77 @@
"""Very simple config loading library"""
from schema import Schema, And, Use
import files
import os
import logging
logger = logging.getLogger("config")
logger.setLevel(logging.INFO)
_FILE = "config.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 1
class _Keys():
TOKEN = "token"
VERSION = "version"
# Removed
OWNERS = "owners"
_SCHEMA = Schema(
{
# The current version
_Keys.VERSION: And(Use(int)),
# Discord bot token
_Keys.TOKEN: And(Use(str)),
}
)
def _migrate_to_v1(d: dict):
# Owners moved to History in v1
# Note: owners will be required to be re-added to the state.json
owners = d.pop(_Keys.OWNERS)
logger.warn(
"Migration removed owners from config, these must be re-added to the state.json")
logger.warn("Owners: %s", owners)
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1
]
class Config():
def __init__(self, data: dict):
"""Initialise and validate the config"""
_SCHEMA.validate(data)
self._dict = data
@property
def token(self) -> str:
return self._dict["token"]
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
_MIGRATIONS[i](dict)
dict["version"] = _VERSION
def load_from_file(file: str = _FILE) -> Config:
"""
Load the state from a file
Apply any required migrations
"""
assert os.path.isfile(file)
loaded = files.load(file)
_migrate(loaded)
return Config(loaded)

20
py/files.py Normal file
View file

@ -0,0 +1,20 @@
"""File operation helpers"""
import json
import shutil
def load(file: str) -> dict:
"""Load a json file directly as a dict"""
with open(file) as f:
return json.load(f)
def save(file: str, content: dict):
"""
Save out a content dictionary to a file
Stores it in an intermediary file first incase the dump fails
"""
intermediate = file + ".nxt"
with open(intermediate, "w") as f:
json.dump(content, f, indent=4)
shutil.move(intermediate, file)

208
py/matching.py Normal file
View file

@ -0,0 +1,208 @@
"""Utility functions for matchy"""
import logging
from datetime import datetime, timedelta
from typing import Protocol, runtime_checkable
import state
# Number of days to step forward from the start of history for each match attempt
_ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7)
class _ScoreFactors(int):
"""Various eligability scoring factors for group meetups"""
REPEAT_ROLE = 2**2
REPEAT_MATCH = 2**3
EXTRA_MEMBER = 2**5
# Scores higher than this are fully rejected
UPPER_THRESHOLD = 2**6
logger = logging.getLogger("matching")
logger.setLevel(logging.INFO)
@runtime_checkable
class Role(Protocol):
@property
def id(self) -> int:
pass
@runtime_checkable
class Member(Protocol):
@property
def mention(self) -> str:
pass
@property
def id(self) -> int:
pass
@property
def roles(self) -> list[Role]:
pass
@runtime_checkable
class Role(Protocol):
@property
def name(self) -> str:
pass
@runtime_checkable
class Guild(Protocol):
@property
def roles(self) -> list[Role]:
pass
def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bool, list[list[Member]]]:
"""Super simple group matching, literally no logic"""
num_groups = max(len(matchees)//per_group, 1)
return [matchees[i::num_groups] for i in range(num_groups)]
def get_member_group_eligibility_score(member: Member,
group: list[Member],
prior_matches: list[int],
per_group: int) -> float:
"""Rates a member against a group"""
# An empty group is a "perfect" score atomatically
rating = 0
if not group:
return rating
# Add score based on prior matchups of this user
rating += sum(m.id in prior_matches for m in group) * \
_ScoreFactors.REPEAT_MATCH
# Calculate the number of roles that match
all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr)
member_role_ids = [r.id for r in member.roles]
repeat_roles = sum(id in member_role_ids for id in all_role_ids)
rating += repeat_roles * _ScoreFactors.REPEAT_ROLE
# Add score based on the number of extra members
# Calculate the member offset (+1 for this user)
extra_members = (len(group) - per_group) + 1
if extra_members >= 0:
rating += extra_members * _ScoreFactors.EXTRA_MEMBER
return rating
def attempt_create_groups(matchees: list[Member],
current_state: state.State,
oldest_relevant_ts: datetime,
per_group: int) -> tuple[bool, list[list[Member]]]:
"""History aware group matching"""
num_groups = max(len(matchees)//per_group, 1)
# Set up the groups in place
groups = [[] for _ in range(num_groups)]
matchees_left = matchees.copy()
# Sequentially try and fit each matchee into a group
while matchees_left:
# Get the next matchee to place
matchee = matchees_left.pop()
matchee_matches = current_state.get_user_matches(matchee.id)
relevant_matches = [int(id) for id, ts
in matchee_matches.items()
if state.ts_to_datetime(ts) >= oldest_relevant_ts]
# Try every single group from the current group onwards
# Progressing through the groups like this ensures we slowly fill them up with compatible people
scores: list[tuple[int, float]] = []
for group in groups:
score = get_member_group_eligibility_score(
matchee, group, relevant_matches, per_group)
# If the score isn't too high, consider this group
if score <= _ScoreFactors.UPPER_THRESHOLD:
scores.append((group, score))
# Optimisation:
# A score of 0 means we've got something good enough and can skip
if score == 0:
break
if scores:
(group, _) = sorted(scores, key=lambda pair: pair[1])[0]
group.append(matchee)
else:
# If we failed to add this matchee, bail on the group creation as it could not be done
return None
return groups
def datetime_range(start_time: datetime, increment: timedelta, end: datetime):
"""Yields a datetime range with a given increment"""
current = start_time
while current <= end or end is None:
yield current
current += increment
def iterate_all_shifts(list: list):
"""Yields each shifted variation of the input list"""
yield list
for _ in range(len(list)-1):
list = list[1:] + [list[0]]
yield list
def members_to_groups(matchees: list[Member],
hist: state.State = state.State(),
per_group: int = 3,
allow_fallback: bool = False) -> list[list[Member]]:
"""Generate the groups from the set of matchees"""
attempts = 0 # Tracking for logging purposes
num_groups = len(matchees)//per_group
# Bail early if there's no-one to match
if not matchees:
return []
# Grab the oldest timestamp
history_start = hist.get_oldest_timestamp() or datetime.now()
# Walk from the start of time until now using the timestep increment
for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()):
# Attempt with each starting matchee
for shifted_matchees in iterate_all_shifts(matchees):
attempts += 1
groups = attempt_create_groups(
shifted_matchees, hist, oldest_relevant_datetime, per_group)
# Fail the match if our groups aren't big enough
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
logger.info("Matched groups after %s attempt(s)", attempts)
return groups
# If we've still failed, just use the simple method
if allow_fallback:
logger.info("Fell back to simple groups after %s attempt(s)", attempts)
return members_to_groups_simple(matchees, per_group)
# Simply assert false, this should never happen
# And should be caught by tests
assert False
def group_to_message(group: list[Member]) -> str:
"""Get the message to send for each group"""
mentions = [m.mention for m in group]
if len(group) > 1:
mentions = f"{', '.join(mentions[:-1])} and {mentions[-1]}"
else:
mentions = mentions[0]
return f"Matched up {mentions}!"

296
py/matching_test.py Normal file
View file

@ -0,0 +1,296 @@
"""
Test functions for the matching module
"""
import discord
import pytest
import random
import matching
import state
from datetime import datetime, timedelta
def test_protocols():
"""Verify the protocols we're using match the discord ones"""
assert isinstance(discord.Member, matching.Member)
assert isinstance(discord.Guild, matching.Guild)
assert isinstance(discord.Role, matching.Role)
assert isinstance(Member, matching.Member)
# assert isinstance(Role, matching.Role)
class Role():
def __init__(self, id: int):
self._id = id
@property
def id(self) -> int:
return self._id
class Member():
def __init__(self, id: int, roles: list[Role] = []):
self._id = id
self._roles = roles
@property
def mention(self) -> str:
return f"<@{self._id}>"
@property
def roles(self) -> list[Role]:
return self._roles
@property
def id(self) -> int:
return self._id
def inner_validate_members_to_groups(matchees: list[Member], tmp_state: state.State, per_group: int):
"""Inner function to validate the main output of the groups function"""
groups = matching.members_to_groups(matchees, tmp_state, per_group)
# We should always have one group
assert len(groups)
# Log the groups to history
# This will validate the internals
tmp_state.log_groups(groups)
# Ensure each group contains within the bounds of expected members
for group in groups:
if len(matchees) >= per_group:
assert len(group) >= per_group
else:
assert len(group) == len(matchees)
assert len(group) < per_group*2 # TODO: We could be more strict here
return groups
@pytest.mark.parametrize("matchees, per_group", [
# Simplest test possible
([Member(1)], 1),
# More requested than we have
([Member(1)], 2),
# A selection of hyper-simple checks to validate core functionality
([Member(1)] * 100, 3),
([Member(1)] * 12, 5),
([Member(1)] * 11, 2),
([Member(1)] * 356, 8),
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
def test_members_to_groups_no_history(matchees, per_group):
"""Test simple group matching works"""
tmp_state = state.State()
inner_validate_members_to_groups(matchees, tmp_state, per_group)
def items_found_in_lists(list_of_lists, items):
"""validates if any sets of items are found in individual lists"""
for sublist in list_of_lists:
if all(item in sublist for item in items):
return True
return False
@pytest.mark.parametrize("history_data, matchees, per_group, checks", [
# Slightly more difficult test
(
# Describe a history where we previously matched up some people and ensure they don't get rematched
[
{
"ts": datetime.now() - timedelta(days=1),
"groups": [
[Member(1), Member(2)],
[Member(3), Member(4)],
]
}
],
[
Member(1),
Member(2),
Member(3),
Member(4),
],
2,
[
lambda groups: not items_found_in_lists(
groups, [Member(1), Member(2)]),
lambda groups: not items_found_in_lists(
groups, [Member(3), Member(4)])
]
),
# Feed the system an "impossible" test
# The function should fall back to ignoring history and still give us something
(
[
{
"ts": datetime.now() - timedelta(days=1),
"groups": [
[
Member(1),
Member(2),
Member(3)
],
[
Member(4),
Member(5),
Member(6)
],
]
}
],
[
Member(1, [Role(1), Role(2), Role(3), Role(4)]),
Member(2, [Role(1), Role(2), Role(3), Role(4)]),
Member(3, [Role(1), Role(2), Role(3), Role(4)]),
Member(4, [Role(1), Role(2), Role(3), Role(4)]),
Member(5, [Role(1), Role(2), Role(3), Role(4)]),
Member(6, [Role(1), Role(2), Role(3), Role(4)]),
],
3,
[
# Nothing specific to validate
]
),
# Specific test pulled out of the stress test
(
[
{
"ts": datetime.now() - timedelta(days=4),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12, 13, 14, 15]]
]
},
{
"ts": datetime.now() - timedelta(days=5),
"groups": [
[Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
]
}
],
[Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]],
3,
[
# Nothing specific to validate
]
),
# Silly example that failued due to bad role logic
(
[
# No history
],
[
# print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below
Member(i, [Role(r) for r in roles]) for (i, roles) in
[
(4, [1, 2, 3, 4, 5, 6, 7, 8]),
(8, [1]),
(9, [1, 2, 3, 4, 5]),
(6, [1, 2, 3]),
(11, [1, 2, 3]),
(7, [1, 2, 3, 4, 5, 6, 7]),
(1, [1, 2, 3, 4]),
(5, [1, 2, 3, 4, 5]),
(12, [1, 2, 3, 4]),
(10, [1]),
(13, [1, 2, 3, 4, 5, 6]),
(2, [1, 2, 3, 4, 5, 6]),
(3, [1, 2, 3, 4, 5, 6, 7])
]
],
2,
[
# Nothing else
]
)
], ids=['simple_history', 'fallback', 'example_1', 'example_2'])
def test_members_to_groups_with_history(history_data, matchees, per_group, checks):
"""Test more advanced group matching works"""
tmp_state = state.State()
# Replay the history
for d in history_data:
tmp_state.log_groups(d["groups"], d["ts"])
groups = inner_validate_members_to_groups(matchees, tmp_state, per_group)
# Run the custom validate functions
for check in checks:
assert check(groups)
def test_members_to_groups_stress_test():
"""stress test firing significant random data at the code"""
# Use a stable rand, feel free to adjust this if needed but this lets the test be stable
rand = random.Random(123)
# Slowly ramp up the group size
for per_group in range(2, 6):
# Slowly ramp a randomized shuffled list of members with randomised roles
for num_members in range(1, 5):
matchees = [Member(i, [Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))])
for i in range(1, rand.randint(2, num_members*10 + 1))]
rand.shuffle(matchees)
for num_history in range(8):
# Generate some super random history
# Start some time from now to the past
time = datetime.now() - timedelta(days=rand.randint(0, num_history*5))
history_data = []
for _ in range(0, num_history):
run = {
"ts": time
}
groups = []
for y in range(1, num_history):
groups.append([Member(i)
for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))])
run["groups"] = groups
history_data.append(run)
# Step some time backwards in time
time -= timedelta(days=rand.randint(1, num_history))
# No guarantees on history data order so make it a little harder for matchy
rand.shuffle(history_data)
# Replay the history
tmp_state = state.State()
for d in history_data:
tmp_state.log_groups(d["groups"], d["ts"])
inner_validate_members_to_groups(
matchees, tmp_state, per_group)
def test_auth_scopes():
tmp_state = state.State()
id = "1"
tmp_state.set_user_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
id = "2"
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
tmp_state.validate()
def test_iterate_all_shifts():
original = [1, 2, 3, 4]
lists = [val for val in matching.iterate_all_shifts(original)]
assert lists == [
[1, 2, 3, 4],
[2, 3, 4, 1],
[3, 4, 1, 2],
[4, 1, 2, 3],
]

214
py/matchy.py Executable file
View file

@ -0,0 +1,214 @@
"""
matchy.py - Discord bot that matches people into groups
"""
import logging
import discord
from discord import app_commands
from discord.ext import commands
import matching
import state
import config
import re
STATE_FILE = "state.json"
CONFIG_FILE = "config.json"
Config = config.load_from_file(CONFIG_FILE)
State = state.load_from_file(STATE_FILE)
logger = logging.getLogger("matchy")
logger.setLevel(logging.INFO)
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
bot = commands.Bot(command_prefix='$',
description="Matchy matches matchees", intents=intents)
@bot.event
async def setup_hook():
bot.add_dynamic_items(DynamicGroupButton)
@bot.event
async def on_ready():
"""Bot is ready and connected"""
logger.info("Bot is up and ready!")
activity = discord.Game("/match")
await bot.change_presence(status=discord.Status.online, activity=activity)
def owner_only(ctx: commands.Context) -> bool:
"""Checks the author is an owner"""
return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER)
@bot.command()
@commands.dm_only()
@commands.check(owner_only)
async def sync(ctx: commands.Context):
"""Handle sync command"""
msg = await ctx.reply("Reloading state...", ephemeral=True)
global State
State = state.load_from_file(STATE_FILE)
logger.info("Reloaded state")
await msg.edit(content="Syncing commands...")
synced = await bot.tree.sync()
logger.info("Synced %s command(s)", len(synced))
await msg.edit(content="Done!")
@bot.command()
@commands.dm_only()
@commands.check(owner_only)
async def close(ctx: commands.Context):
"""Handle restart command"""
await ctx.reply("Closing bot...", ephemeral=True)
logger.info("Closing down the bot")
await bot.close()
@bot.tree.command(description="Join the matchees for this channel")
@commands.guild_only()
async def join(interaction: discord.Interaction):
State.set_use_active_in_channel(
interaction.user.id, interaction.channel.id)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"Roger roger {interaction.user.mention}!\n"
+ f"Added you to {interaction.channel.mention}!",
ephemeral=True, silent=True)
@bot.tree.command(description="Leave the matchees for this channel")
@commands.guild_only()
async def leave(interaction: discord.Interaction):
State.set_use_active_in_channel(
interaction.user.id, interaction.channel.id, False)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@bot.tree.command(description="List the matchees for this channel")
@commands.guild_only()
async def list(interaction: discord.Interaction):
matchees = get_matchees_in_channel(interaction.channel)
mentions = [m.mention for m in matchees]
msg = "Current matchees in this channel:\n" + \
f"{', '.join(mentions[:-1])} and {mentions[-1]}"
await interaction.response.send_message(msg, ephemeral=True, silent=True)
@bot.tree.command(description="Match up matchees")
@commands.guild_only()
@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)")
async def match(interaction: discord.Interaction, members_min: int = None):
"""Match groups of channel members"""
logger.info("Handling request '/match group_min=%s", members_min)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Sort out the defaults, if not specified they'll come in as None
if not members_min:
members_min = 3
# Grab the groups
groups = active_members_to_groups(interaction.channel, members_min)
# Let the user know when there's nobody to match
if not groups:
await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True)
return
# Post about all the groups with a button to send to the channel
groups_list = '\n'.join(matching.group_to_message(g) for g in groups)
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
view = discord.utils.MISSING
if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER):
# Let a non-matcher know why they don't have the button
msg += f"\n\nYou'll need the {state.AuthScope.MATCHER} scope to post this to the channel, sorry!"
else:
# Otherwise set up the button
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
view = discord.ui.View(timeout=None)
view.add_item(DynamicGroupButton(members_min))
await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view)
logger.info("Done.")
# Increment when adjusting the custom_id so we don't confuse old users
_BUTTON_CUSTOM_ID_VERSION = 1
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
template=f'match:v{_BUTTON_CUSTOM_ID_VERSION}:' + r'min:(?P<min>[0-9]+)'):
def __init__(self, min: int) -> None:
super().__init__(
discord.ui.Button(
label='Match Groups!',
style=discord.ButtonStyle.blurple,
custom_id=f'match:min:{min}',
)
)
self.min: int = min
# This is called when the button is clicked and the custom_id matches the template.
@classmethod
async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
min = int(match['min'])
return cls(min)
async def callback(self, interaction: discord.Interaction) -> None:
"""Match up people when the button is pressed"""
logger.info("Handling button press min=%s", self.min)
logger.info("User %s from %s in #%s", interaction.user,
interaction.guild.name, interaction.channel.name)
# Let the user know we've recieved the message
await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
groups = active_members_to_groups(interaction.channel, self.min)
# Send the groups
for msg in (matching.group_to_message(g) for g in groups):
await interaction.channel.send(msg)
# Close off with a message
await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!")
# Save the groups to the history
State.log_groups(groups)
state.save_to_file(State, STATE_FILE)
logger.info("Done! Matched into %s groups.", len(groups))
def get_matchees_in_channel(channel: discord.channel):
"""Fetches the matchees in a channel"""
# Gather up the prospective matchees
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
def active_members_to_groups(channel: discord.channel, min_members: int):
"""Helper to create groups from channel members"""
# Gather up the prospective matchees
matchees = get_matchees_in_channel(channel)
# Create our groups!
return matching.members_to_groups(matchees, State, min_members, allow_fallback=True)
if __name__ == "__main__":
handler = logging.StreamHandler()
bot.run(Config.token, log_handler=handler, root_logger=True)

257
py/state.py Normal file
View file

@ -0,0 +1,257 @@
"""Store bot state"""
import os
from datetime import datetime
from schema import Schema, And, Use, Optional
from typing import Protocol
import files
import copy
import logging
logger = logging.getLogger("state")
logger.setLevel(logging.INFO)
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 1
def _migrate_to_v1(d: dict):
logger.info("Renaming %s to %s", _Key.MATCHEES, _Key.USERS)
d[_Key.USERS] = d[_Key.MATCHEES]
del d[_Key.MATCHEES]
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1
]
class AuthScope(str):
"""Various auth scopes"""
OWNER = "owner"
MATCHER = "matcher"
class _Key(str):
"""Various keys used in the schema"""
HISTORY = "history"
GROUPS = "groups"
MEMBERS = "members"
USERS = "users"
SCOPES = "scopes"
MATCHES = "matches"
ACTIVE = "active"
CHANNELS = "channels"
REACTIVATE = "reactivate"
VERSION = "version"
# Unused
MATCHEES = "matchees"
_TIME_FORMAT = "%a %b %d %H:%M:%S %Y"
_SCHEMA = Schema(
{
# The current version
_Key.VERSION: And(Use(int)),
Optional(_Key.HISTORY): {
# A datetime
Optional(str): {
_Key.GROUPS: [
{
_Key.MEMBERS: [
# The ID of each matchee in the match
And(Use(int))
]
}
]
}
},
Optional(_Key.USERS): {
Optional(str): {
Optional(_Key.SCOPES): And(Use(list[str])),
Optional(_Key.MATCHES): {
# Matchee ID and Datetime pair
Optional(str): And(Use(str))
},
Optional(_Key.CHANNELS): {
# The channel ID
Optional(str): {
# Whether the user is signed up in this channel
_Key.ACTIVE: And(Use(bool)),
}
}
}
},
}
)
# Empty but schema-valid internal dict
_EMPTY_DICT = {
_Key.HISTORY: {},
_Key.USERS: {},
_Key.VERSION: _VERSION
}
assert _SCHEMA.validate(_EMPTY_DICT)
class Member(Protocol):
@property
def id(self) -> int:
pass
def ts_to_datetime(ts: str) -> datetime:
"""Convert a ts to datetime using the internal format"""
return datetime.strptime(ts, _TIME_FORMAT)
class State():
def __init__(self, data: dict = _EMPTY_DICT):
"""Initialise and validate the state"""
self.validate(data)
self._dict = copy.deepcopy(data)
@property
def _history(self) -> dict[str]:
return self._dict[_Key.HISTORY]
@property
def _users(self) -> dict[str]:
return self._dict[_Key.USERS]
def validate(self, dict: dict = None):
"""Initialise and validate a state dict"""
if not dict:
dict = self._dict
_SCHEMA.validate(dict)
def get_oldest_timestamp(self) -> datetime:
"""Grab the oldest timestamp in history"""
times = (ts_to_datetime(dt) for dt in self._history.keys())
return next(times, None)
def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
"""Log the groups"""
tmp_state = State(self._dict)
ts = datetime.strftime(ts, _TIME_FORMAT)
# Grab or create the hitory item for this set of groups
history_item = {}
tmp_state._history[ts] = history_item
history_item_groups = []
history_item[_Key.GROUPS] = history_item_groups
for group in groups:
# Add the group data
history_item_groups.append({
_Key.MEMBERS: [m.id for m in group]
})
# Update the matchee data with the matches
for m in group:
matchee = tmp_state._users.get(str(m.id), {})
matchee_matches = matchee.get(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts
matchee[_Key.MATCHES] = matchee_matches
tmp_state._users[str(m.id)] = matchee
# Validate before storing the result
tmp_state.validate()
self._dict = tmp_state._dict
def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user"""
# Dive in
user = self._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
# Set the value
if value and scope not in scopes:
scopes.append(scope)
elif not value and scope in scopes:
scopes.remove(scope)
# Roll out
user[_Key.SCOPES] = scopes
self._users[id] = user
def get_user_has_scope(self, id: str, scope: str) -> bool:
"""
Check if a user has an auth scope
"owner" users have all scopes
"""
user = self._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, [])
return AuthScope.OWNER in scopes or scope in scopes
def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True):
"""Set a user as active (or not) on a given channel"""
# Dive in
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
# Set the value
channel[_Key.ACTIVE] = active
# Unroll
channels[str(channel_id)] = channel
user[_Key.CHANNELS] = channels
self._users[str(id)] = user
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
"""Get a users active channels"""
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
@property
def dict_internal(self) -> dict:
"""Only to be used to get the internal dict as a copy"""
return copy.deepcopy(self._dict)
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
logger.info("Migrating from v%s to v%s", version, version+1)
_MIGRATIONS[i](dict)
dict[_Key.VERSION] = _VERSION
def load_from_file(file: str) -> State:
"""
Load the state from a file
Apply any required migrations
"""
loaded = _EMPTY_DICT
# If there's a file load it and try to migrate
if os.path.isfile(file):
loaded = files.load(file)
_migrate(loaded)
st = State(loaded)
# Save out the migrated (or new) file
files.save(file, st._dict)
return st
def save_to_file(state: State, file: str):
"""Saves the state out to a file"""
files.save(file, state.dict_internal)

65
py/state_test.py Normal file
View file

@ -0,0 +1,65 @@
"""
Test functions for the state module
"""
import state
import tempfile
import os
def test_basic_state():
"""Simple validate basic state load"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
state.load_from_file(path)
def test_simple_load_reload():
"""Test a basic load, save, reload"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
state.save_to_file(st, path)
st = state.load_from_file(path)
def test_authscope():
"""Test setting and getting an auth scope"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
st = state.load_from_file(path)
st.set_user_scope(1, state.AuthScope.MATCHER)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
st.set_user_scope(1, state.AuthScope.MATCHER, False)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
def test_channeljoin():
"""Test setting and getting an active channel"""
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path)
state.save_to_file(st, path)
assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path)
st.set_use_active_in_channel(1, "2", True)
state.save_to_file(st, path)
st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2")
st.set_use_active_in_channel(1, "2", False)
assert not st.get_user_active_in_channel(1, "2")