diff --git a/README.md b/README.md index 056bacd..d0ec770 100644 --- a/README.md +++ b/README.md @@ -20,18 +20,16 @@ Only usable by `OWNER` users, reloads the config and syncs commands, or closes d Matchy is configured by a `config.json` file that takes this format: ``` { + "version": 1, "token": "<>", - "owners": [ - <> - ] } ``` User IDs can be grabbed by turning on Discord's developer mode and right clicking on a user. ## TODO * Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html) +* Implement /pause to pause a user for a little while +* Move more constants to the config * Add scheduling functionality -* Version the config and history files -* Implement /signup rather than using roles -* Implement authorisation scopes instead of just OWNER values +* Fix logging in some sub files * Improve the weirdo \ No newline at end of file diff --git a/config.py b/config.py index 9882d99..2719d73 100644 --- a/config.py +++ b/config.py @@ -1,38 +1,77 @@ """Very simple config loading library""" from schema import Schema, And, Use import files +import os +import logging + +logger = logging.getLogger("config") +logger.setLevel(logging.INFO) _FILE = "config.json" + +# Warning: Changing any of the below needs proper thought to ensure backwards compatibility +_VERSION = 1 + + +class _Keys(): + TOKEN = "token" + VERSION = "version" + + # Removed + OWNERS = "owners" + + _SCHEMA = Schema( { - # Discord bot token - "token": And(Use(str)), + # The current version + _Keys.VERSION: And(Use(int)), - # ids of owners authorised to use owner-only commands - "owners": And(Use(list[int])), + # Discord bot token + _Keys.TOKEN: And(Use(str)), } ) +def _migrate_to_v1(d: dict): + # Owners moved to History in v1 + # Note: owners will be required to be re-added to the state.json + owners = d.pop(_Keys.OWNERS) + logger.warn( + "Migration removed owners from config, these must be re-added to the state.json") + logger.warn("Owners: %s", owners) + + +# Set of migration functions to apply +_MIGRATIONS = [ + _migrate_to_v1 +] + + class Config(): def __init__(self, data: dict): """Initialise and validate the config""" _SCHEMA.validate(data) - self.__dict__ = data + self._dict = data @property def token(self) -> str: - return self.__dict__["token"] - - @property - def owners(self) -> list[int]: - return self.__dict__["owners"] - - def reload(self) -> None: - """Reload the config back into the dict""" - self.__dict__ = load().__dict__ + return self._dict["token"] -def load() -> Config: - """Load the config""" - return Config(files.load(_FILE)) +def _migrate(dict: dict): + """Migrate a dict through versions""" + version = dict.get("version", 0) + for i in range(version, _VERSION): + _MIGRATIONS[i](dict) + dict["version"] = _VERSION + + +def load_from_file(file: str = _FILE) -> Config: + """ + Load the state from a file + Apply any required migrations + """ + assert os.path.isfile(file) + loaded = files.load(file) + _migrate(loaded) + return Config(loaded) diff --git a/history.py b/history.py deleted file mode 100644 index 06b77e4..0000000 --- a/history.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Store matching history""" -import os -from datetime import datetime -from schema import Schema, And, Use, Optional -from typing import Protocol -import files -import copy - -_FILE = "history.json" - -# Warning: Changing any of the below needs proper thought to ensure backwards compatibility -_DEFAULT_DICT = { - "history": {}, - "matchees": {} -} -_TIME_FORMAT = "%a %b %d %H:%M:%S %Y" -_SCHEMA = Schema( - { - Optional("history"): { - Optional(str): { # a datetime - "groups": [ - { - "members": [ - # The ID of each matchee in the match - And(Use(int)) - ] - } - ] - } - }, - Optional("matchees"): { - Optional(str): { - Optional("matches"): { - # Matchee ID and Datetime pair - Optional(str): And(Use(str)) - } - } - } - } -) - - -class Member(Protocol): - @property - def id(self) -> int: - pass - - -def ts_to_datetime(ts: str) -> datetime: - """Convert a ts to datetime using the history format""" - return datetime.strptime(ts, _TIME_FORMAT) - - -def validate(dict: dict): - """Initialise and validate the history""" - _SCHEMA.validate(dict) - - -class History(): - def __init__(self, data: dict = _DEFAULT_DICT): - """Initialise and validate the history""" - validate(data) - self.__dict__ = copy.deepcopy(data) - - @property - def history(self) -> list[dict]: - return self.__dict__["history"] - - @property - def matchees(self) -> dict[str, dict]: - return self.__dict__["matchees"] - - def save(self) -> None: - """Save out the history""" - files.save(_FILE, self.__dict__) - - def oldest(self) -> datetime: - """Grab the oldest timestamp in history""" - if not self.history: - return None - times = (ts_to_datetime(dt) for dt in self.history.keys()) - return sorted(times)[0] - - def log_groups_to_history(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None: - """Log the groups""" - tmp_history = History(self.__dict__) - ts = datetime.strftime(ts, _TIME_FORMAT) - - # Grab or create the hitory item for this set of groups - history_item = {} - tmp_history.history[ts] = history_item - history_item_groups = [] - history_item["groups"] = history_item_groups - - for group in groups: - - # Add the group data - history_item_groups.append({ - "members": list(m.id for m in group) - }) - - # Update the matchee data with the matches - for m in group: - matchee = tmp_history.matchees.get(str(m.id), {}) - matchee_matches = matchee.get("matches", {}) - - for o in (o for o in group if o.id != m.id): - matchee_matches[str(o.id)] = ts - - matchee["matches"] = matchee_matches - tmp_history.matchees[str(m.id)] = matchee - - # Validate before storing the result - validate(self.__dict__) - self.__dict__ = tmp_history.__dict__ - - def save_groups_to_history(self, groups: list[list[Member]]) -> None: - """Save out the groups to the history file""" - self.log_groups_to_history(groups) - self.save() - - -def load() -> History: - """Load the history""" - return History(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT) diff --git a/matching.py b/matching.py index 3b51a1d..4f9b31d 100644 --- a/matching.py +++ b/matching.py @@ -1,6 +1,5 @@ """Utility functions for matchy""" import logging -import random from datetime import datetime, timedelta from typing import Protocol, runtime_checkable import state @@ -9,17 +8,16 @@ import state # Number of days to step forward from the start of history for each match attempt _ATTEMPT_TIMESTEP_INCREMENT = timedelta(days=7) -# Attempts for each of those time periods -_ATTEMPTS_PER_TIMESTEP = 3 -# Various eligability scoring factors for group meetups -_SCORE_CURRENT_MEMBERS = 2**1 -_SCORE_REPEAT_ROLE = 2**2 -_SCORE_REPEAT_MATCH = 2**3 -_SCORE_EXTRA_MEMBERS = 2**4 +class _ScoreFactors(int): + """Various eligability scoring factors for group meetups""" + REPEAT_ROLE = 2**2 + REPEAT_MATCH = 2**3 + EXTRA_MEMBER = 2**5 + + # Scores higher than this are fully rejected + UPPER_THRESHOLD = 2**6 -# Scores higher than this are fully rejected -_SCORE_UPPER_THRESHOLD = 2**6 logger = logging.getLogger("matching") logger.setLevel(logging.INFO) @@ -69,33 +67,42 @@ def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bo def get_member_group_eligibility_score(member: Member, group: list[Member], - relevant_matches: list[int], - per_group: int) -> int: + prior_matches: list[int], + per_group: int) -> float: """Rates a member against a group""" - rating = len(group) * _SCORE_CURRENT_MEMBERS + # An empty group is a "perfect" score atomatically + rating = 0 + if not group: + return rating - repeat_meetings = sum(m.id in relevant_matches for m in group) - rating += repeat_meetings * _SCORE_REPEAT_MATCH + # Add score based on prior matchups of this user + rating += sum(m.id in prior_matches for m in group) * \ + _ScoreFactors.REPEAT_MATCH - repeat_roles = sum(r in member.roles for r in (m.roles for m in group)) - rating += (repeat_roles * _SCORE_REPEAT_ROLE) + # Calculate the number of roles that match + all_role_ids = set(r.id for mr in [r.roles for r in group] for r in mr) + member_role_ids = [r.id for r in member.roles] + repeat_roles = sum(id in member_role_ids for id in all_role_ids) + rating += repeat_roles * _ScoreFactors.REPEAT_ROLE - extra_members = len(group) - per_group - if extra_members > 0: - rating += extra_members * _SCORE_EXTRA_MEMBERS + # Add score based on the number of extra members + # Calculate the member offset (+1 for this user) + extra_members = (len(group) - per_group) + 1 + if extra_members >= 0: + rating += extra_members * _ScoreFactors.EXTRA_MEMBER return rating def attempt_create_groups(matchees: list[Member], - hist: state.State, + current_state: state.State, oldest_relevant_ts: datetime, per_group: int) -> tuple[bool, list[list[Member]]]: """History aware group matching""" num_groups = max(len(matchees)//per_group, 1) # Set up the groups in place - groups = list([] for _ in range(num_groups)) + groups = [[] for _ in range(num_groups)] matchees_left = matchees.copy() @@ -103,21 +110,21 @@ def attempt_create_groups(matchees: list[Member], while matchees_left: # Get the next matchee to place matchee = matchees_left.pop() - matchee_matches = hist.matchees.get( - str(matchee.id), {}).get("matches", {}) - relevant_matches = list(int(id) for id, ts in matchee_matches.items() - if state.ts_to_datetime(ts) >= oldest_relevant_ts) + matchee_matches = current_state.get_user_matches(matchee.id) + relevant_matches = [int(id) for id, ts + in matchee_matches.items() + if state.ts_to_datetime(ts) >= oldest_relevant_ts] # Try every single group from the current group onwards # Progressing through the groups like this ensures we slowly fill them up with compatible people - scores: list[tuple[int, int]] = [] + scores: list[tuple[int, float]] = [] for group in groups: score = get_member_group_eligibility_score( - matchee, group, relevant_matches, num_groups) + matchee, group, relevant_matches, per_group) # If the score isn't too high, consider this group - if score <= _SCORE_UPPER_THRESHOLD: + if score <= _ScoreFactors.UPPER_THRESHOLD: scores.append((group, score)) # Optimisation: @@ -143,31 +150,41 @@ def datetime_range(start_time: datetime, increment: timedelta, end: datetime): current += increment +def iterate_all_shifts(list: list): + """Yields each shifted variation of the input list""" + yield list + for _ in range(len(list)-1): + list = list[1:] + [list[0]] + yield list + + def members_to_groups(matchees: list[Member], hist: state.State = state.State(), per_group: int = 3, allow_fallback: bool = False) -> list[list[Member]]: """Generate the groups from the set of matchees""" attempts = 0 # Tracking for logging purposes - rand = random.Random(117) # Some stable randomness + num_groups = len(matchees)//per_group + + # Bail early if there's no-one to match + if not matchees: + return [] # Grab the oldest timestamp - history_start = hist.oldest_history() or datetime.now() + history_start = hist.get_oldest_timestamp() or datetime.now() # Walk from the start of time until now using the timestep increment for oldest_relevant_datetime in datetime_range(history_start, _ATTEMPT_TIMESTEP_INCREMENT, datetime.now()): - # Have a few attempts before stepping forward in time - for _ in range(_ATTEMPTS_PER_TIMESTEP): - - rand.shuffle(matchees) # Shuffle the matchees each attempt + # Attempt with each starting matchee + for shifted_matchees in iterate_all_shifts(matchees): attempts += 1 groups = attempt_create_groups( - matchees, hist, oldest_relevant_datetime, per_group) + shifted_matchees, hist, oldest_relevant_datetime, per_group) # Fail the match if our groups aren't big enough - if (len(matchees)//per_group) <= 1 or (groups and all(len(g) >= per_group for g in groups)): + if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)): logger.info("Matched groups after %s attempt(s)", attempts) return groups @@ -176,6 +193,10 @@ def members_to_groups(matchees: list[Member], logger.info("Fell back to simple groups after %s attempt(s)", attempts) return members_to_groups_simple(matchees, per_group) + # Simply assert false, this should never happen + # And should be caught by tests + assert False + def group_to_message(group: list[Member]) -> str: """Get the message to send for each group""" @@ -185,8 +206,3 @@ def group_to_message(group: list[Member]) -> str: else: mentions = mentions[0] return f"Matched up {mentions}!" - - -def get_role_from_guild(guild: Guild, role: str) -> Role: - """Find a role in a guild""" - return next((r for r in guild.roles if r.name == role), None) diff --git a/matching_test.py b/matching_test.py index 8409202..5a40897 100644 --- a/matching_test.py +++ b/matching_test.py @@ -1,5 +1,5 @@ """ - Test functions for Matchy + Test functions for the matching module """ import discord import pytest @@ -30,6 +30,7 @@ class Role(): class Member(): def __init__(self, id: int, roles: list[Role] = []): self._id = id + self._roles = roles @property def mention(self) -> str: @@ -37,7 +38,7 @@ class Member(): @property def roles(self) -> list[Role]: - return [] + return self._roles @property def id(self) -> int: @@ -153,7 +154,59 @@ def items_found_in_lists(list_of_lists, items): # Nothing specific to validate ] ), -], ids=['simple_history', 'fallback']) + # Specific test pulled out of the stress test + ( + [ + { + "ts": datetime.now() - timedelta(days=4), + "groups": [ + [Member(i) for i in [1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15]] + ] + }, + { + "ts": datetime.now() - timedelta(days=5), + "groups": [ + [Member(i) for i in [1, 2, 3, 4, 5, 6, 7, 8]] + ] + } + ], + [Member(i) for i in [1, 2, 11, 4, 12, 3, 7, 5, 8, 10, 9, 6]], + 3, + [ + # Nothing specific to validate + ] + ), + # Silly example that failued due to bad role logic + ( + [ + # No history + ], + [ + # print([(m.id, [r.id for r in m.roles]) for m in matchees]) to get the below + Member(i, [Role(r) for r in roles]) for (i, roles) in + [ + (4, [1, 2, 3, 4, 5, 6, 7, 8]), + (8, [1]), + (9, [1, 2, 3, 4, 5]), + (6, [1, 2, 3]), + (11, [1, 2, 3]), + (7, [1, 2, 3, 4, 5, 6, 7]), + (1, [1, 2, 3, 4]), + (5, [1, 2, 3, 4, 5]), + (12, [1, 2, 3, 4]), + (10, [1]), + (13, [1, 2, 3, 4, 5, 6]), + (2, [1, 2, 3, 4, 5, 6]), + (3, [1, 2, 3, 4, 5, 6, 7]) + ] + ], + 2, + [ + # Nothing else + ] + ) +], ids=['simple_history', 'fallback', 'example_1', 'example_2']) def test_members_to_groups_with_history(history_data, matchees, per_group, checks): """Test more advanced group matching works""" tmp_state = state.State() @@ -180,8 +233,8 @@ def test_members_to_groups_stress_test(): # Slowly ramp a randomized shuffled list of members with randomised roles for num_members in range(1, 5): - matchees = list(Member(i, list(Role(i) for i in range(1, rand.randint(2, num_members*2 + 1)))) - for i in range(1, rand.randint(2, num_members*10 + 1))) + matchees = [Member(i, [Role(i) for i in range(1, rand.randint(2, num_members*2 + 1))]) + for i in range(1, rand.randint(2, num_members*10 + 1))] rand.shuffle(matchees) for num_history in range(8): @@ -190,14 +243,14 @@ def test_members_to_groups_stress_test(): # Start some time from now to the past time = datetime.now() - timedelta(days=rand.randint(0, num_history*5)) history_data = [] - for x in range(0, num_history): + for _ in range(0, num_history): run = { "ts": time } groups = [] for y in range(1, num_history): - groups.append(list(Member(i) - for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1))))) + groups.append([Member(i) + for i in range(1, max(num_members, rand.randint(2, num_members*10 + 1)))]) run["groups"] = groups history_data.append(run) @@ -212,4 +265,32 @@ def test_members_to_groups_stress_test(): for d in history_data: tmp_state.log_groups(d["groups"], d["ts"]) - inner_validate_members_to_groups(matchees, tmp_state, per_group) + inner_validate_members_to_groups( + matchees, tmp_state, per_group) + + +def test_auth_scopes(): + tmp_state = state.State() + + id = "1" + tmp_state.set_user_scope(id, state.AuthScope.OWNER) + assert tmp_state.get_user_has_scope(id, state.AuthScope.OWNER) + assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) + + id = "2" + tmp_state.set_user_scope(id, state.AuthScope.MATCHER) + assert not tmp_state.get_user_has_scope(id, state.AuthScope.OWNER) + assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) + + tmp_state.validate() + + +def test_iterate_all_shifts(): + original = [1, 2, 3, 4] + lists = [val for val in matching.iterate_all_shifts(original)] + assert lists == [ + [1, 2, 3, 4], + [2, 3, 4, 1], + [3, 4, 1, 2], + [4, 1, 2, 3], + ] diff --git a/matchy.py b/matchy.py index a76d78c..7f5fc8e 100755 --- a/matchy.py +++ b/matchy.py @@ -11,8 +11,11 @@ import config import re -Config = config.load() -State = state.load() +STATE_FILE = "state.json" +CONFIG_FILE = "config.json" + +Config = config.load_from_file(CONFIG_FILE) +State = state.load_from_file(STATE_FILE) logger = logging.getLogger("matchy") logger.setLevel(logging.INFO) @@ -39,7 +42,7 @@ async def on_ready(): def owner_only(ctx: commands.Context) -> bool: """Checks the author is an owner""" - return ctx.message.author.id in Config.owners + return State.get_user_has_scope(ctx.message.author.id, state.AuthScope.OWNER) @bot.command() @@ -47,9 +50,10 @@ def owner_only(ctx: commands.Context) -> bool: @commands.check(owner_only) async def sync(ctx: commands.Context): """Handle sync command""" - msg = await ctx.reply("Reloading config...", ephemeral=True) - Config.reload() - logger.info("Reloaded config") + msg = await ctx.reply("Reloading state...", ephemeral=True) + global State + State = state.load_from_file(STATE_FILE) + logger.info("Reloaded state") await msg.edit(content="Syncing commands...") synced = await bot.tree.sync() @@ -68,96 +72,112 @@ async def close(ctx: commands.Context): await bot.close() +@bot.tree.command(description="Join the matchees for this channel") +@commands.guild_only() +async def join(interaction: discord.Interaction): + State.set_use_active_in_channel( + interaction.user.id, interaction.channel.id) + state.save_to_file(State, STATE_FILE) + await interaction.response.send_message( + f"Roger roger {interaction.user.mention}!\n" + + f"Added you to {interaction.channel.mention}!", + ephemeral=True, silent=True) + + +@bot.tree.command(description="Leave the matchees for this channel") +@commands.guild_only() +async def leave(interaction: discord.Interaction): + State.set_use_active_in_channel( + interaction.user.id, interaction.channel.id, False) + state.save_to_file(State, STATE_FILE) + await interaction.response.send_message( + f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) + + +@bot.tree.command(description="List the matchees for this channel") +@commands.guild_only() +async def list(interaction: discord.Interaction): + matchees = get_matchees_in_channel(interaction.channel) + mentions = [m.mention for m in matchees] + msg = "Current matchees in this channel:\n" + \ + f"{', '.join(mentions[:-1])} and {mentions[-1]}" + await interaction.response.send_message(msg, ephemeral=True, silent=True) + + @bot.tree.command(description="Match up matchees") @commands.guild_only() -@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)", - matchee_role="Role for matchees (defaults to @Matchee)") -async def match(interaction: discord.Interaction, members_min: int = None, matchee_role: str = None): +@app_commands.describe(members_min="Minimum matchees per match (defaults to 3)") +async def match(interaction: discord.Interaction, members_min: int = None): """Match groups of channel members""" - logger.info("Handling request '/match group_min=%s matchee_role=%s'", - members_min, matchee_role) + logger.info("Handling request '/match group_min=%s", members_min) logger.info("User %s from %s in #%s", interaction.user, interaction.guild.name, interaction.channel.name) # Sort out the defaults, if not specified they'll come in as None if not members_min: members_min = 3 - if not matchee_role: - matchee_role = "Matchee" - # Grab the roles and verify the given role - matcher = matching.get_role_from_guild(interaction.guild, "Matcher") - matcher = matcher and matcher in interaction.user.roles - matchee = matching.get_role_from_guild(interaction.guild, matchee_role) - if not matchee: - await interaction.response.send_message(f"Server is missing '{matchee_role}' role :(", ephemeral=True) + # Grab the groups + groups = active_members_to_groups(interaction.channel, members_min) + + # Let the user know when there's nobody to match + if not groups: + await interaction.response.send_message("Nobody to match up :(", ephemeral=True, silent=True) return - # Create some example groups to show the user - matchees = list( - m for m in interaction.channel.members if matchee in m.roles) - groups = matching.members_to_groups( - matchees, State, members_min, allow_fallback=True) - # Post about all the groups with a button to send to the channel groups_list = '\n'.join(matching.group_to_message(g) for g in groups) msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}" view = discord.utils.MISSING - if not matcher: + if State.get_user_has_scope(interaction.user.id, state.AuthScope.MATCHER): # Let a non-matcher know why they don't have the button - msg += "\n\nYou'll need the 'Matcher' role to post this to the channel, sorry!" + msg += f"\n\nYou'll need the {state.AuthScope.MATCHER} scope to post this to the channel, sorry!" else: # Otherwise set up the button msg += "\n\nClick the button to match up groups and send them to the channel.\n" view = discord.ui.View(timeout=None) - view.add_item(DynamicGroupButton(members_min, matchee_role)) + view.add_item(DynamicGroupButton(members_min)) await interaction.response.send_message(msg, ephemeral=True, silent=True, view=view) logger.info("Done.") +# Increment when adjusting the custom_id so we don't confuse old users +_BUTTON_CUSTOM_ID_VERSION = 1 + + class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], - template=r'match:min:(?P[0-9]+):role:(?P[@\w\s]+)'): - def __init__(self, min: int, role: str) -> None: + template=f'match:v{_BUTTON_CUSTOM_ID_VERSION}:' + r'min:(?P[0-9]+)'): + def __init__(self, min: int) -> None: super().__init__( discord.ui.Button( label='Match Groups!', style=discord.ButtonStyle.blurple, - custom_id=f'match:min:{min}:role:{role}', + custom_id=f'match:min:{min}', ) ) self.min: int = min - self.role: int = role # This is called when the button is clicked and the custom_id matches the template. @classmethod async def from_custom_id(cls, interaction: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /): min = int(match['min']) - role = str(match['role']) - return cls(min, role) + return cls(min) async def callback(self, interaction: discord.Interaction) -> None: """Match up people when the button is pressed""" - logger.info("Handling button press min=%s role=%s'", - self.min, self.role) + logger.info("Handling button press min=%s", self.min) logger.info("User %s from %s in #%s", interaction.user, interaction.guild.name, interaction.channel.name) # Let the user know we've recieved the message await interaction.response.send_message(content="Matchy is matching matchees...", ephemeral=True) - # Grab the role - matchee = matching.get_role_from_guild(interaction.guild, self.role) - - # Create our groups! - matchees = list( - m for m in interaction.channel.members if matchee in m.roles) - groups = matching.members_to_groups( - matchees, State, self.min, allow_fallback=True) + groups = active_members_to_groups(interaction.channel, self.min) # Send the groups for msg in (matching.group_to_message(g) for g in groups): @@ -167,10 +187,26 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button], await interaction.channel.send("That's all folks, happy matching and remember - DFTBA!") # Save the groups to the history - State.save_groups(groups) + State.log_groups(groups) + state.save_to_file(State, STATE_FILE) - logger.info("Done. Matched %s matchees into %s groups.", - len(matchees), len(groups)) + logger.info("Done! Matched into %s groups.", len(groups)) + + +def get_matchees_in_channel(channel: discord.channel): + """Fetches the matchees in a channel""" + # Gather up the prospective matchees + return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)] + + +def active_members_to_groups(channel: discord.channel, min_members: int): + """Helper to create groups from channel members""" + + # Gather up the prospective matchees + matchees = get_matchees_in_channel(channel) + + # Create our groups! + return matching.members_to_groups(matchees, State, min_members, allow_fallback=True) if __name__ == "__main__": diff --git a/state.py b/state.py index 524943a..344a00f 100644 --- a/state.py +++ b/state.py @@ -5,22 +5,65 @@ from schema import Schema, And, Use, Optional from typing import Protocol import files import copy +import logging + +logger = logging.getLogger("state") +logger.setLevel(logging.INFO) -_FILE = "state.json" # Warning: Changing any of the below needs proper thought to ensure backwards compatibility -_DEFAULT_DICT = { - "history": {}, - "matchees": {} -} +_VERSION = 1 + + +def _migrate_to_v1(d: dict): + logger.info("Renaming %s to %s", _Key.MATCHEES, _Key.USERS) + d[_Key.USERS] = d[_Key.MATCHEES] + del d[_Key.MATCHEES] + + +# Set of migration functions to apply +_MIGRATIONS = [ + _migrate_to_v1 +] + + +class AuthScope(str): + """Various auth scopes""" + OWNER = "owner" + MATCHER = "matcher" + + +class _Key(str): + """Various keys used in the schema""" + HISTORY = "history" + GROUPS = "groups" + MEMBERS = "members" + USERS = "users" + SCOPES = "scopes" + MATCHES = "matches" + ACTIVE = "active" + CHANNELS = "channels" + REACTIVATE = "reactivate" + VERSION = "version" + + # Unused + MATCHEES = "matchees" + + _TIME_FORMAT = "%a %b %d %H:%M:%S %Y" + + _SCHEMA = Schema( { - Optional("history"): { - Optional(str): { # a datetime - "groups": [ + # The current version + _Key.VERSION: And(Use(int)), + + Optional(_Key.HISTORY): { + # A datetime + Optional(str): { + _Key.GROUPS: [ { - "members": [ + _Key.MEMBERS: [ # The ID of each matchee in the match And(Use(int)) ] @@ -28,17 +71,33 @@ _SCHEMA = Schema( ] } }, - Optional("matchees"): { + Optional(_Key.USERS): { Optional(str): { - Optional("matches"): { + Optional(_Key.SCOPES): And(Use(list[str])), + Optional(_Key.MATCHES): { # Matchee ID and Datetime pair Optional(str): And(Use(str)) + }, + Optional(_Key.CHANNELS): { + # The channel ID + Optional(str): { + # Whether the user is signed up in this channel + _Key.ACTIVE: And(Use(bool)), + } } } - } + }, } ) +# Empty but schema-valid internal dict +_EMPTY_DICT = { + _Key.HISTORY: {}, + _Key.USERS: {}, + _Key.VERSION: _VERSION +} +assert _SCHEMA.validate(_EMPTY_DICT) + class Member(Protocol): @property @@ -51,75 +110,148 @@ def ts_to_datetime(ts: str) -> datetime: return datetime.strptime(ts, _TIME_FORMAT) -def validate(dict: dict): - """Initialise and validate the state""" - _SCHEMA.validate(dict) - - class State(): - def __init__(self, data: dict = _DEFAULT_DICT): + def __init__(self, data: dict = _EMPTY_DICT): """Initialise and validate the state""" - validate(data) - self.__dict__ = copy.deepcopy(data) + self.validate(data) + self._dict = copy.deepcopy(data) @property - def history(self) -> list[dict]: - return self.__dict__["history"] + def _history(self) -> dict[str]: + return self._dict[_Key.HISTORY] @property - def matchees(self) -> dict[str, dict]: - return self.__dict__["matchees"] + def _users(self) -> dict[str]: + return self._dict[_Key.USERS] - def save(self) -> None: - """Save out the state""" - files.save(_FILE, self.__dict__) + def validate(self, dict: dict = None): + """Initialise and validate a state dict""" + if not dict: + dict = self._dict + _SCHEMA.validate(dict) - def oldest_history(self) -> datetime: + def get_oldest_timestamp(self) -> datetime: """Grab the oldest timestamp in history""" - if not self.history: - return None - times = (ts_to_datetime(dt) for dt in self.history.keys()) - return sorted(times)[0] + times = (ts_to_datetime(dt) for dt in self._history.keys()) + return next(times, None) + + def get_user_matches(self, id: int) -> list[int]: + return self._users.get(str(id), {}).get(_Key.MATCHES, {}) def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None: """Log the groups""" - tmp_state = State(self.__dict__) + tmp_state = State(self._dict) ts = datetime.strftime(ts, _TIME_FORMAT) # Grab or create the hitory item for this set of groups history_item = {} - tmp_state.history[ts] = history_item + tmp_state._history[ts] = history_item history_item_groups = [] - history_item["groups"] = history_item_groups + history_item[_Key.GROUPS] = history_item_groups for group in groups: # Add the group data history_item_groups.append({ - "members": list(m.id for m in group) + _Key.MEMBERS: [m.id for m in group] }) # Update the matchee data with the matches for m in group: - matchee = tmp_state.matchees.get(str(m.id), {}) - matchee_matches = matchee.get("matches", {}) + matchee = tmp_state._users.get(str(m.id), {}) + matchee_matches = matchee.get(_Key.MATCHES, {}) for o in (o for o in group if o.id != m.id): matchee_matches[str(o.id)] = ts - matchee["matches"] = matchee_matches - tmp_state.matchees[str(m.id)] = matchee + matchee[_Key.MATCHES] = matchee_matches + tmp_state._users[str(m.id)] = matchee # Validate before storing the result - validate(self.__dict__) - self.__dict__ = tmp_state.__dict__ + tmp_state.validate() + self._dict = tmp_state._dict - def save_groups(self, groups: list[list[Member]]) -> None: - """Save out the groups to the state file""" - self.log_groups(groups) - self.save() + def set_user_scope(self, id: str, scope: str, value: bool = True): + """Add an auth scope to a user""" + # Dive in + user = self._users.get(str(id), {}) + scopes = user.get(_Key.SCOPES, []) + + # Set the value + if value and scope not in scopes: + scopes.append(scope) + elif not value and scope in scopes: + scopes.remove(scope) + + # Roll out + user[_Key.SCOPES] = scopes + self._users[id] = user + + def get_user_has_scope(self, id: str, scope: str) -> bool: + """ + Check if a user has an auth scope + "owner" users have all scopes + """ + user = self._users.get(str(id), {}) + scopes = user.get(_Key.SCOPES, []) + return AuthScope.OWNER in scopes or scope in scopes + + def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True): + """Set a user as active (or not) on a given channel""" + # Dive in + user = self._users.get(str(id), {}) + channels = user.get(_Key.CHANNELS, {}) + channel = channels.get(str(channel_id), {}) + + # Set the value + channel[_Key.ACTIVE] = active + + # Unroll + channels[str(channel_id)] = channel + user[_Key.CHANNELS] = channels + self._users[str(id)] = user + + def get_user_active_in_channel(self, id: str, channel_id: str) -> bool: + """Get a users active channels""" + user = self._users.get(str(id), {}) + channels = user.get(_Key.CHANNELS, {}) + return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)] + + @property + def dict_internal(self) -> dict: + """Only to be used to get the internal dict as a copy""" + return copy.deepcopy(self._dict) -def load() -> State: - """Load the state""" - return State(files.load(_FILE) if os.path.isfile(_FILE) else _DEFAULT_DICT) +def _migrate(dict: dict): + """Migrate a dict through versions""" + version = dict.get("version", 0) + for i in range(version, _VERSION): + logger.info("Migrating from v%s to v%s", version, version+1) + _MIGRATIONS[i](dict) + dict[_Key.VERSION] = _VERSION + + +def load_from_file(file: str) -> State: + """ + Load the state from a file + Apply any required migrations + """ + loaded = _EMPTY_DICT + + # If there's a file load it and try to migrate + if os.path.isfile(file): + loaded = files.load(file) + _migrate(loaded) + + st = State(loaded) + + # Save out the migrated (or new) file + files.save(file, st._dict) + + return st + + +def save_to_file(state: State, file: str): + """Saves the state out to a file""" + files.save(file, state.dict_internal) diff --git a/state_test.py b/state_test.py new file mode 100644 index 0000000..bd97648 --- /dev/null +++ b/state_test.py @@ -0,0 +1,65 @@ +""" + Test functions for the state module +""" +import state +import tempfile +import os + + +def test_basic_state(): + """Simple validate basic state load""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, 'tmp.json') + state.load_from_file(path) + + +def test_simple_load_reload(): + """Test a basic load, save, reload""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, 'tmp.json') + st = state.load_from_file(path) + state.save_to_file(st, path) + + st = state.load_from_file(path) + state.save_to_file(st, path) + st = state.load_from_file(path) + + +def test_authscope(): + """Test setting and getting an auth scope""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, 'tmp.json') + st = state.load_from_file(path) + state.save_to_file(st, path) + + assert not st.get_user_has_scope(1, state.AuthScope.MATCHER) + + st = state.load_from_file(path) + st.set_user_scope(1, state.AuthScope.MATCHER) + state.save_to_file(st, path) + + st = state.load_from_file(path) + assert st.get_user_has_scope(1, state.AuthScope.MATCHER) + + st.set_user_scope(1, state.AuthScope.MATCHER, False) + assert not st.get_user_has_scope(1, state.AuthScope.MATCHER) + + +def test_channeljoin(): + """Test setting and getting an active channel""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, 'tmp.json') + st = state.load_from_file(path) + state.save_to_file(st, path) + + assert not st.get_user_active_in_channel(1, "2") + + st = state.load_from_file(path) + st.set_use_active_in_channel(1, "2", True) + state.save_to_file(st, path) + + st = state.load_from_file(path) + assert st.get_user_active_in_channel(1, "2") + + st.set_use_active_in_channel(1, "2", False) + assert not st.get_user_active_in_channel(1, "2")