Convert State to global

This was just getting too painful to manage, especially passing around these state objects
This commit is contained in:
Marc Di Luzio 2024-08-17 14:58:19 +01:00
parent f926a36069
commit 69005ef498
7 changed files with 69 additions and 71 deletions

View file

@ -5,13 +5,9 @@ import logging
import discord import discord
from discord.ext import commands from discord.ext import commands
import os import os
from matchy.state import load_from_file
import matchy.cogs.matcher import matchy.cogs.matcher
import matchy.cogs.owner import matchy.cogs.owner
_STATE_FILE = ".matchy/state.json"
state = load_from_file(_STATE_FILE)
logger = logging.getLogger("matchy") logger = logging.getLogger("matchy")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -24,8 +20,8 @@ bot = commands.Bot(command_prefix='$',
@bot.event @bot.event
async def setup_hook(): async def setup_hook():
await bot.add_cog(matchy.cogs.matcher.MatcherCog(bot, state)) await bot.add_cog(matchy.cogs.matcher.MatcherCog(bot))
await bot.add_cog(matchy.cogs.owner.OwnerCog(bot, state)) await bot.add_cog(matchy.cogs.owner.OwnerCog(bot))
@bot.event @bot.event

View file

@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time
import re import re
import matchy.matching as matching import matchy.matching as matching
from matchy.state import State, AuthScope from matchy.state import AuthScope
import matchy.util as util import matchy.util as util
import matchy.state as state import matchy.state as state
@ -19,9 +19,8 @@ logger.setLevel(logging.INFO)
class MatcherCog(commands.Cog): class MatcherCog(commands.Cog):
def __init__(self, bot: commands.Bot, state: State): def __init__(self, bot: commands.Bot):
self.bot = bot self.bot = bot
self.state = state
@commands.Cog.listener() @commands.Cog.listener()
async def on_ready(self): async def on_ready(self):
@ -38,7 +37,7 @@ class MatcherCog(commands.Cog):
logger.info("Handling /join in %s %s from %s", logger.info("Handling /join in %s %s from %s",
interaction.guild.name, interaction.channel, interaction.user.name) interaction.guild.name, interaction.channel, interaction.user.name)
self.state.set_user_active_in_channel( state.State.set_user_active_in_channel(
interaction.user.id, interaction.channel.id) interaction.user.id, interaction.channel.id)
await interaction.response.send_message( await interaction.response.send_message(
f"Roger roger {interaction.user.mention}!\n" f"Roger roger {interaction.user.mention}!\n"
@ -51,7 +50,7 @@ class MatcherCog(commands.Cog):
logger.info("Handling /leave in %s %s from %s", logger.info("Handling /leave in %s %s from %s",
interaction.guild.name, interaction.channel, interaction.user.name) interaction.guild.name, interaction.channel, interaction.user.name)
self.state.set_user_active_in_channel( state.State.set_user_active_in_channel(
interaction.user.id, interaction.channel.id, False) interaction.user.id, interaction.channel.id, False)
await interaction.response.send_message( await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@ -66,7 +65,7 @@ class MatcherCog(commands.Cog):
if days is None: # Default to a week if days is None: # Default to a week
days = 7 days = 7
until = datetime.now() + timedelta(days=days) until = datetime.now() + timedelta(days=days)
self.state.set_user_paused_in_channel( state.State.set_user_paused_in_channel(
interaction.user.id, interaction.channel.id, until) interaction.user.id, interaction.channel.id, until)
await interaction.response.send_message( await interaction.response.send_message(
f"Sure thing {interaction.user.mention}!\n" f"Sure thing {interaction.user.mention}!\n"
@ -79,8 +78,7 @@ class MatcherCog(commands.Cog):
logger.info("Handling /list command in %s %s from %s", logger.info("Handling /list command in %s %s from %s",
interaction.guild.name, interaction.channel, interaction.user.name) interaction.guild.name, interaction.channel, interaction.user.name)
(matchees, paused) = matching.get_matchees_in_channel( (matchees, paused) = matching.get_matchees_in_channel(interaction.channel)
self.state, interaction.channel)
msg = "" msg = ""
@ -94,7 +92,7 @@ class MatcherCog(commands.Cog):
msg += f"\nThere are {len(mentions)} paused matchees:\n" msg += f"\nThere are {len(mentions)} paused matchees:\n"
msg += f"{util.format_list([m.mention for m in paused])}\n" msg += f"{util.format_list([m.mention for m in paused])}\n"
tasks = self.state.get_channel_match_tasks(interaction.channel.id) tasks = state.State.get_channel_match_tasks(interaction.channel.id)
for (day, hour, min) in tasks: for (day, hour, min) in tasks:
next_run = util.get_next_datetime(day, hour) next_run = util.get_next_datetime(day, hour)
date_str = util.datetime_as_discord_time(next_run) date_str = util.datetime_as_discord_time(next_run)
@ -128,13 +126,13 @@ class MatcherCog(commands.Cog):
channel_id = str(interaction.channel.id) channel_id = str(interaction.channel.id)
# Bail if not a matcher # Bail if not a matcher
if not self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): if not state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
await interaction.response.send_message("You'll need the 'matcher' scope to schedule a match", await interaction.response.send_message("You'll need the 'matcher' scope to schedule a match",
ephemeral=True, silent=True) ephemeral=True, silent=True)
return return
# Add the scheduled task and save # Add the scheduled task and save
self.state.set_channel_match_task( state.State.set_channel_match_task(
channel_id, members_min, weekday, hour) channel_id, members_min, weekday, hour)
# Let the user know what happened # Let the user know what happened
@ -143,23 +141,26 @@ class MatcherCog(commands.Cog):
next_run = util.get_next_datetime(weekday, hour) next_run = util.get_next_datetime(weekday, hour)
date_str = util.datetime_as_discord_time(next_run) date_str = util.datetime_as_discord_time(next_run)
view = discord.ui.View(timeout=None)
view.add_item(ScheduleButton())
await interaction.response.send_message( await interaction.response.send_message(
f"Done :) Next run will be at {date_str}", f"Done :) Next run will be at {date_str}",
ephemeral=True, silent=True) ephemeral=True, silent=True, view=view)
@app_commands.command(description="Cancel all scheduled matches in this channel") @app_commands.command(description="Cancel all scheduled matches in this channel")
@commands.guild_only() @commands.guild_only()
async def cancel(self, interaction: discord.Interaction): async def cancel(self, interaction: discord.Interaction):
"""Cancel scheduled matches in this channel""" """Cancel scheduled matches in this channel"""
# Bail if not a matcher # Bail if not a matcher
if not self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): if not state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
await interaction.response.send_message("You'll need the 'matcher' scope to remove scheduled matches", await interaction.response.send_message("You'll need the 'matcher' scope to remove scheduled matches",
ephemeral=True, silent=True) ephemeral=True, silent=True)
return return
# Add the scheduled task and save # Add the scheduled task and save
channel_id = str(interaction.channel.id) channel_id = str(interaction.channel.id)
self.state.remove_channel_match_tasks(channel_id) state.State.remove_channel_match_tasks(channel_id)
await interaction.response.send_message( await interaction.response.send_message(
"Done, all scheduled matches cleared in this channel!", "Done, all scheduled matches cleared in this channel!",
@ -181,7 +182,7 @@ class MatcherCog(commands.Cog):
# Grab the groups # Grab the groups
groups = matching.active_members_to_groups( groups = matching.active_members_to_groups(
self.state, interaction.channel, members_min) interaction.channel, members_min)
# Let the user know when there's nobody to match # Let the user know when there's nobody to match
if not groups: if not groups:
@ -194,7 +195,7 @@ class MatcherCog(commands.Cog):
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}" msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
view = discord.utils.MISSING view = discord.utils.MISSING
if self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER): if state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
# Otherwise set up the button # Otherwise set up the button
msg += "\n\nClick the button to match up groups and send them to the channel.\n" msg += "\n\nClick the button to match up groups and send them to the channel.\n"
view = discord.ui.View(timeout=None) view = discord.ui.View(timeout=None)
@ -212,12 +213,12 @@ class MatcherCog(commands.Cog):
async def run_hourly_tasks(self): async def run_hourly_tasks(self):
"""Run any hourly tasks we have""" """Run any hourly tasks we have"""
for (channel, min) in self.state.get_active_match_tasks(): for (channel, min) in state.State.get_active_match_tasks():
logger.info("Scheduled match task triggered in %s", channel) logger.info("Scheduled match task triggered in %s", channel)
msg_channel = self.bot.get_channel(int(channel)) msg_channel = self.bot.get_channel(int(channel))
await matching.match_groups_in_channel(self.state, msg_channel, min) await matching.match_groups_in_channel(state.State, msg_channel, min)
for (channel, _) in self.state.get_active_match_tasks(datetime.now() + timedelta(days=1)): for (channel, _) in state.State.get_active_match_tasks(datetime.now() + timedelta(days=1)):
logger.info("Reminding about scheduled task in %s", channel) logger.info("Reminding about scheduled task in %s", channel)
msg_channel = self.bot.get_channel(int(channel)) msg_channel = self.bot.get_channel(int(channel))
await msg_channel.send("Arf arf! just a reminder I'll be doin a matcherino in here in T-24hrs!" await msg_channel.send("Arf arf! just a reminder I'll be doin a matcherino in here in T-24hrs!"
@ -244,7 +245,6 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
) )
) )
self.min: int = min self.min: int = min
self.state = state.load_from_file()
# This is called when the button is clicked and the custom_id matches the template. # This is called when the button is clicked and the custom_id matches the template.
@classmethod @classmethod

View file

@ -3,16 +3,15 @@ Owner bot cog
""" """
import logging import logging
from discord.ext import commands from discord.ext import commands
from matchy.state import State, AuthScope import matchy.state as state
logger = logging.getLogger("owner") logger = logging.getLogger("owner")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class OwnerCog(commands.Cog): class OwnerCog(commands.Cog):
def __init__(self, bot: commands.Bot, state: State): def __init__(self, bot: commands.Bot):
self._bot = bot self._bot = bot
self._state = state
@commands.command() @commands.command()
@commands.dm_only() @commands.dm_only()
@ -48,7 +47,7 @@ class OwnerCog(commands.Cog):
Grant the matcher scope to a given user Grant the matcher scope to a given user
""" """
if user.isdigit(): if user.isdigit():
self._state.set_user_scope(str(user), AuthScope.MATCHER) state.State.set_user_scope(str(user), state.AuthScope.MATCHER)
logger.info("Granting user %s matcher scope", user) logger.info("Granting user %s matcher scope", user)
await ctx.reply("Done!", ephemeral=True) await ctx.reply("Done!", ephemeral=True)
else: else:

View file

@ -3,8 +3,8 @@ import logging
import discord import discord
from datetime import datetime from datetime import datetime
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable
from matchy.state import State, ts_to_datetime
import matchy.util as util import matchy.util as util
import matchy.state as state
class _ScoreFactors(int): class _ScoreFactors(int):
@ -95,7 +95,6 @@ def get_member_group_eligibility_score(member: Member,
def attempt_create_groups(matchees: list[Member], def attempt_create_groups(matchees: list[Member],
state: State,
oldest_relevant_ts: datetime, oldest_relevant_ts: datetime,
per_group: int) -> tuple[bool, list[list[Member]]]: per_group: int) -> tuple[bool, list[list[Member]]]:
"""History aware group matching""" """History aware group matching"""
@ -110,10 +109,10 @@ def attempt_create_groups(matchees: list[Member],
while matchees_left: while matchees_left:
# Get the next matchee to place # Get the next matchee to place
matchee = matchees_left.pop() matchee = matchees_left.pop()
matchee_matches = state.get_user_matches(matchee.id) matchee_matches = state.State.get_user_matches(matchee.id)
relevant_matches = [int(id) for id, ts relevant_matches = [int(id) for id, ts
in matchee_matches.items() in matchee_matches.items()
if ts_to_datetime(ts) >= oldest_relevant_ts] if state.ts_to_datetime(ts) >= oldest_relevant_ts]
# Try every single group from the current group onwards # Try every single group from the current group onwards
# Progressing through the groups like this ensures we slowly fill them up with compatible people # Progressing through the groups like this ensures we slowly fill them up with compatible people
@ -143,7 +142,6 @@ def attempt_create_groups(matchees: list[Member],
def members_to_groups(matchees: list[Member], def members_to_groups(matchees: list[Member],
state: State,
per_group: int = 3, per_group: int = 3,
allow_fallback: bool = False) -> list[list[Member]]: allow_fallback: bool = False) -> list[list[Member]]:
"""Generate the groups from the set of matchees""" """Generate the groups from the set of matchees"""
@ -155,14 +153,14 @@ def members_to_groups(matchees: list[Member],
return [] return []
# Walk from the start of history until now trying to match up groups # Walk from the start of history until now trying to match up groups
for oldest_relevant_datetime in state.get_history_timestamps(matchees) + [datetime.now()]: for oldest_relevant_datetime in state.State.get_history_timestamps(matchees) + [datetime.now()]:
# Attempt with each starting matchee # Attempt with each starting matchee
for shifted_matchees in util.iterate_all_shifts(matchees): for shifted_matchees in util.iterate_all_shifts(matchees):
attempts += 1 attempts += 1
groups = attempt_create_groups( groups = attempt_create_groups(
shifted_matchees, state, oldest_relevant_datetime, per_group) shifted_matchees, oldest_relevant_datetime, per_group)
# Fail the match if our groups aren't big enough # Fail the match if our groups aren't big enough
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)): if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
@ -179,9 +177,9 @@ def members_to_groups(matchees: list[Member],
assert False assert False
async def match_groups_in_channel(state: State, channel: discord.channel, min: int): async def match_groups_in_channel(channel: discord.channel, min: int):
"""Match up the groups in a given channel""" """Match up the groups in a given channel"""
groups = active_members_to_groups(state, channel, min) groups = active_members_to_groups(channel, min)
# Send the groups # Send the groups
for group in groups: for group in groups:
@ -197,24 +195,26 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
# Close off with a message # Close off with a message
await channel.send("That's all folks, happy matching and remember - DFTBA!") await channel.send("That's all folks, happy matching and remember - DFTBA!")
# Save the groups to the history # Save the groups to the history
state.log_groups(groups) state.State.log_groups(groups)
logger.info("Done! Matched into %s groups.", len(groups)) logger.info("Done! Matched into %s groups.", len(groups))
def get_matchees_in_channel(state: State, channel: discord.channel): def get_matchees_in_channel(channel: discord.channel):
"""Fetches the matchees in a channel""" """Fetches the matchees in a channel"""
# Reactivate any unpaused users # Reactivate any unpaused users
state.reactivate_users(channel.id) state.State.reactivate_users(channel.id)
# Gather up the prospective matchees # Gather up the prospective matchees
active = [m for m in channel.members if state.get_user_active_in_channel(m.id, channel.id)] active = [m for m in channel.members if state.State.get_user_active_in_channel(
paused = [m for m in channel.members if state.get_user_paused_in_channel(m.id, channel.id)] m.id, channel.id)]
paused = [m for m in channel.members if state.State.get_user_paused_in_channel(
m.id, channel.id)]
return (active, paused) return (active, paused)
def active_members_to_groups(state: State, channel: discord.channel, min_members: int): def active_members_to_groups(channel: discord.channel, min_members: int):
"""Helper to create groups from channel members""" """Helper to create groups from channel members"""
# Gather up the prospective matchees # Gather up the prospective matchees
matchees = get_matchees_in_channel(state, channel) matchees = get_matchees_in_channel(channel)
# Create our groups! # Create our groups!
return members_to_groups(matchees, state, min_members, allow_fallback=True) return members_to_groups(matchees, min_members, allow_fallback=True)

View file

@ -15,7 +15,6 @@ import matchy.util as util
logger = logging.getLogger("state") logger = logging.getLogger("state")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility # Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 4 _VERSION = 4
@ -193,7 +192,7 @@ def _save(file: str, content: dict):
shutil.move(intermediate, file) shutil.move(intermediate, file)
class State(): class _State():
def __init__(self, data: dict, file: str | None = None): def __init__(self, data: dict, file: str | None = None):
"""Copy the data, migrate if needed, and validate""" """Copy the data, migrate if needed, and validate"""
self._dict = copy.deepcopy(data) self._dict = copy.deepcopy(data)
@ -216,7 +215,7 @@ class State():
""" """
@wraps(func) @wraps(func)
def inner(self, *args, **kwargs): def inner(self, *args, **kwargs):
tmp = State(self._dict, self._file) tmp = _State(self._dict, self._file)
func(tmp, *args, **kwargs) func(tmp, *args, **kwargs)
_SCHEMA.validate(tmp._dict) _SCHEMA.validate(tmp._dict)
if tmp._file: if tmp._file:
@ -380,11 +379,15 @@ class State():
return self._dict[_Key.TASKS] return self._dict[_Key.TASKS]
def load_from_file(file: str) -> State: def load_from_file(file: str) -> _State:
""" """
Load the state from a files Load the state from a files
""" """
loaded = _load(file) if os.path.isfile(file) else _EMPTY_DICT loaded = _load(file) if os.path.isfile(file) else _EMPTY_DICT
st = State(loaded, file) st = _State(loaded, file)
_save(file, st._dict) _save(file, st._dict)
return st return st
_STATE_FILE = ".matchy/state.json"
State = load_from_file(_STATE_FILE)

View file

@ -11,6 +11,12 @@ import itertools
from datetime import datetime, timedelta from datetime import datetime, timedelta
@pytest.fixture(autouse=True)
def clean_state():
"""Ensure every single one of these tests has a clean state"""
state.State = state._State(state._EMPTY_DICT)
def test_protocols(): def test_protocols():
"""Verify the protocols we're using match the discord ones""" """Verify the protocols we're using match the discord ones"""
assert isinstance(discord.Member, matching.Member) assert isinstance(discord.Member, matching.Member)
@ -59,16 +65,16 @@ class Member():
return self._id return self._id
def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, per_group: int): def members_to_groups_validate(matchees: list[Member], per_group: int):
"""Inner function to validate the main output of the groups function""" """Inner function to validate the main output of the groups function"""
groups = matching.members_to_groups(matchees, tmp_state, per_group) groups = matching.members_to_groups(matchees, per_group)
# We should always have one group # We should always have one group
assert len(groups) assert len(groups)
# Log the groups to history # Log the groups to history
# This will validate the internals # This will validate the internals
tmp_state.log_groups(groups) state.State.log_groups(groups)
# Ensure each group contains within the bounds of expected members # Ensure each group contains within the bounds of expected members
for group in groups: for group in groups:
@ -96,8 +102,7 @@ def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, p
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"]) ], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
def test_members_to_groups_no_history(matchees, per_group): def test_members_to_groups_no_history(matchees, per_group):
"""Test simple group matching works""" """Test simple group matching works"""
tmp_state = state.State(state._EMPTY_DICT) members_to_groups_validate(matchees, per_group)
members_to_groups_validate(matchees, tmp_state, per_group)
def items_found_in_lists(list_of_lists, items): def items_found_in_lists(list_of_lists, items):
@ -328,13 +333,12 @@ def items_found_in_lists(list_of_lists, items):
], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3']) ], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3'])
def test_unique_regressions(history_data, matchees, per_group, checks): def test_unique_regressions(history_data, matchees, per_group, checks):
"""Test a bunch of unqiue failures that happened in the past""" """Test a bunch of unqiue failures that happened in the past"""
tmp_state = state.State(state._EMPTY_DICT)
# Replay the history # Replay the history
for d in history_data: for d in history_data:
tmp_state.log_groups(d["groups"], d["ts"]) state.State.log_groups(d["groups"], d["ts"])
groups = members_to_groups_validate(matchees, tmp_state, per_group) groups = members_to_groups_validate(matchees, per_group)
# Run the custom validate functions # Run the custom validate functions
for check in checks: for check in checks:
@ -380,28 +384,25 @@ def test_stess_random_groups(per_group, num_members, num_history):
member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)] member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)]
# For each history item match up groups and log those # For each history item match up groups and log those
cumulative_state = state.State(state._EMPTY_DICT)
for i in range(num_history+1): for i in range(num_history+1):
# Grab the num of members and replay # Grab the num of members and replay
rand.shuffle(possible_members) rand.shuffle(possible_members)
members = copy.deepcopy(possible_members[:num_members]) members = copy.deepcopy(possible_members[:num_members])
groups = members_to_groups_validate( groups = members_to_groups_validate(members, per_group)
members, cumulative_state, per_group) state.State.log_groups(
cumulative_state.log_groups(
groups, datetime.now() - timedelta(days=num_history-i)) groups, datetime.now() - timedelta(days=num_history-i))
def test_auth_scopes(): def test_auth_scopes():
tmp_state = state.State(state._EMPTY_DICT)
id = "1" id = "1"
assert not tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) assert not state.State.get_user_has_scope(id, state.AuthScope.MATCHER)
id = "2" id = "2"
tmp_state.set_user_scope(id, state.AuthScope.MATCHER) state.State.set_user_scope(id, state.AuthScope.MATCHER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) assert state.State.get_user_has_scope(id, state.AuthScope.MATCHER)
# Validate the state by constucting a new one # Validate the state by constucting a new one
_ = state.State(tmp_state._dict) _ = state._State(state.State._dict)

View file

@ -2,7 +2,6 @@ import discord
import discord.ext.commands as commands import discord.ext.commands as commands
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import matchy.state as state
import discord.ext.test as dpytest import discord.ext.test as dpytest
from matchy.cogs.owner import OwnerCog from matchy.cogs.owner import OwnerCog
@ -20,7 +19,7 @@ async def bot():
b = commands.Bot(command_prefix="$", b = commands.Bot(command_prefix="$",
intents=intents) intents=intents)
await b._async_setup_hook() await b._async_setup_hook()
await b.add_cog(OwnerCog(b, state.State(state._EMPTY_DICT))) await b.add_cog(OwnerCog(b))
dpytest.configure(b) dpytest.configure(b)
yield b yield b
await dpytest.empty_queue() await dpytest.empty_queue()