Convert State to global
This was just getting too painful to manage, especially passing around these state objects
This commit is contained in:
parent
f926a36069
commit
69005ef498
7 changed files with 69 additions and 71 deletions
|
@ -5,13 +5,9 @@ import logging
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
import os
|
import os
|
||||||
from matchy.state import load_from_file
|
|
||||||
import matchy.cogs.matcher
|
import matchy.cogs.matcher
|
||||||
import matchy.cogs.owner
|
import matchy.cogs.owner
|
||||||
|
|
||||||
_STATE_FILE = ".matchy/state.json"
|
|
||||||
state = load_from_file(_STATE_FILE)
|
|
||||||
|
|
||||||
logger = logging.getLogger("matchy")
|
logger = logging.getLogger("matchy")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
@ -24,8 +20,8 @@ bot = commands.Bot(command_prefix='$',
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def setup_hook():
|
async def setup_hook():
|
||||||
await bot.add_cog(matchy.cogs.matcher.MatcherCog(bot, state))
|
await bot.add_cog(matchy.cogs.matcher.MatcherCog(bot))
|
||||||
await bot.add_cog(matchy.cogs.owner.OwnerCog(bot, state))
|
await bot.add_cog(matchy.cogs.owner.OwnerCog(bot))
|
||||||
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
|
|
|
@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import matchy.matching as matching
|
import matchy.matching as matching
|
||||||
from matchy.state import State, AuthScope
|
from matchy.state import AuthScope
|
||||||
import matchy.util as util
|
import matchy.util as util
|
||||||
import matchy.state as state
|
import matchy.state as state
|
||||||
|
|
||||||
|
@ -19,9 +19,8 @@ logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
class MatcherCog(commands.Cog):
|
class MatcherCog(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot, state: State):
|
def __init__(self, bot: commands.Bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.state = state
|
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
|
@ -38,7 +37,7 @@ class MatcherCog(commands.Cog):
|
||||||
logger.info("Handling /join in %s %s from %s",
|
logger.info("Handling /join in %s %s from %s",
|
||||||
interaction.guild.name, interaction.channel, interaction.user.name)
|
interaction.guild.name, interaction.channel, interaction.user.name)
|
||||||
|
|
||||||
self.state.set_user_active_in_channel(
|
state.State.set_user_active_in_channel(
|
||||||
interaction.user.id, interaction.channel.id)
|
interaction.user.id, interaction.channel.id)
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"Roger roger {interaction.user.mention}!\n"
|
f"Roger roger {interaction.user.mention}!\n"
|
||||||
|
@ -51,7 +50,7 @@ class MatcherCog(commands.Cog):
|
||||||
logger.info("Handling /leave in %s %s from %s",
|
logger.info("Handling /leave in %s %s from %s",
|
||||||
interaction.guild.name, interaction.channel, interaction.user.name)
|
interaction.guild.name, interaction.channel, interaction.user.name)
|
||||||
|
|
||||||
self.state.set_user_active_in_channel(
|
state.State.set_user_active_in_channel(
|
||||||
interaction.user.id, interaction.channel.id, False)
|
interaction.user.id, interaction.channel.id, False)
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
||||||
|
@ -66,7 +65,7 @@ class MatcherCog(commands.Cog):
|
||||||
if days is None: # Default to a week
|
if days is None: # Default to a week
|
||||||
days = 7
|
days = 7
|
||||||
until = datetime.now() + timedelta(days=days)
|
until = datetime.now() + timedelta(days=days)
|
||||||
self.state.set_user_paused_in_channel(
|
state.State.set_user_paused_in_channel(
|
||||||
interaction.user.id, interaction.channel.id, until)
|
interaction.user.id, interaction.channel.id, until)
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"Sure thing {interaction.user.mention}!\n"
|
f"Sure thing {interaction.user.mention}!\n"
|
||||||
|
@ -79,8 +78,7 @@ class MatcherCog(commands.Cog):
|
||||||
logger.info("Handling /list command in %s %s from %s",
|
logger.info("Handling /list command in %s %s from %s",
|
||||||
interaction.guild.name, interaction.channel, interaction.user.name)
|
interaction.guild.name, interaction.channel, interaction.user.name)
|
||||||
|
|
||||||
(matchees, paused) = matching.get_matchees_in_channel(
|
(matchees, paused) = matching.get_matchees_in_channel(interaction.channel)
|
||||||
self.state, interaction.channel)
|
|
||||||
|
|
||||||
msg = ""
|
msg = ""
|
||||||
|
|
||||||
|
@ -94,7 +92,7 @@ class MatcherCog(commands.Cog):
|
||||||
msg += f"\nThere are {len(mentions)} paused matchees:\n"
|
msg += f"\nThere are {len(mentions)} paused matchees:\n"
|
||||||
msg += f"{util.format_list([m.mention for m in paused])}\n"
|
msg += f"{util.format_list([m.mention for m in paused])}\n"
|
||||||
|
|
||||||
tasks = self.state.get_channel_match_tasks(interaction.channel.id)
|
tasks = state.State.get_channel_match_tasks(interaction.channel.id)
|
||||||
for (day, hour, min) in tasks:
|
for (day, hour, min) in tasks:
|
||||||
next_run = util.get_next_datetime(day, hour)
|
next_run = util.get_next_datetime(day, hour)
|
||||||
date_str = util.datetime_as_discord_time(next_run)
|
date_str = util.datetime_as_discord_time(next_run)
|
||||||
|
@ -128,13 +126,13 @@ class MatcherCog(commands.Cog):
|
||||||
channel_id = str(interaction.channel.id)
|
channel_id = str(interaction.channel.id)
|
||||||
|
|
||||||
# Bail if not a matcher
|
# Bail if not a matcher
|
||||||
if not self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
|
if not state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
|
||||||
await interaction.response.send_message("You'll need the 'matcher' scope to schedule a match",
|
await interaction.response.send_message("You'll need the 'matcher' scope to schedule a match",
|
||||||
ephemeral=True, silent=True)
|
ephemeral=True, silent=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Add the scheduled task and save
|
# Add the scheduled task and save
|
||||||
self.state.set_channel_match_task(
|
state.State.set_channel_match_task(
|
||||||
channel_id, members_min, weekday, hour)
|
channel_id, members_min, weekday, hour)
|
||||||
|
|
||||||
# Let the user know what happened
|
# Let the user know what happened
|
||||||
|
@ -143,23 +141,26 @@ class MatcherCog(commands.Cog):
|
||||||
next_run = util.get_next_datetime(weekday, hour)
|
next_run = util.get_next_datetime(weekday, hour)
|
||||||
date_str = util.datetime_as_discord_time(next_run)
|
date_str = util.datetime_as_discord_time(next_run)
|
||||||
|
|
||||||
|
view = discord.ui.View(timeout=None)
|
||||||
|
view.add_item(ScheduleButton())
|
||||||
|
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"Done :) Next run will be at {date_str}",
|
f"Done :) Next run will be at {date_str}",
|
||||||
ephemeral=True, silent=True)
|
ephemeral=True, silent=True, view=view)
|
||||||
|
|
||||||
@app_commands.command(description="Cancel all scheduled matches in this channel")
|
@app_commands.command(description="Cancel all scheduled matches in this channel")
|
||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def cancel(self, interaction: discord.Interaction):
|
async def cancel(self, interaction: discord.Interaction):
|
||||||
"""Cancel scheduled matches in this channel"""
|
"""Cancel scheduled matches in this channel"""
|
||||||
# Bail if not a matcher
|
# Bail if not a matcher
|
||||||
if not self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
|
if not state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
|
||||||
await interaction.response.send_message("You'll need the 'matcher' scope to remove scheduled matches",
|
await interaction.response.send_message("You'll need the 'matcher' scope to remove scheduled matches",
|
||||||
ephemeral=True, silent=True)
|
ephemeral=True, silent=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Add the scheduled task and save
|
# Add the scheduled task and save
|
||||||
channel_id = str(interaction.channel.id)
|
channel_id = str(interaction.channel.id)
|
||||||
self.state.remove_channel_match_tasks(channel_id)
|
state.State.remove_channel_match_tasks(channel_id)
|
||||||
|
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
"Done, all scheduled matches cleared in this channel!",
|
"Done, all scheduled matches cleared in this channel!",
|
||||||
|
@ -181,7 +182,7 @@ class MatcherCog(commands.Cog):
|
||||||
|
|
||||||
# Grab the groups
|
# Grab the groups
|
||||||
groups = matching.active_members_to_groups(
|
groups = matching.active_members_to_groups(
|
||||||
self.state, interaction.channel, members_min)
|
interaction.channel, members_min)
|
||||||
|
|
||||||
# Let the user know when there's nobody to match
|
# Let the user know when there's nobody to match
|
||||||
if not groups:
|
if not groups:
|
||||||
|
@ -194,7 +195,7 @@ class MatcherCog(commands.Cog):
|
||||||
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
|
msg = f"Roger! I've generated example groups for ya:\n\n{groups_list}"
|
||||||
view = discord.utils.MISSING
|
view = discord.utils.MISSING
|
||||||
|
|
||||||
if self.state.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
|
if state.State.get_user_has_scope(interaction.user.id, AuthScope.MATCHER):
|
||||||
# Otherwise set up the button
|
# Otherwise set up the button
|
||||||
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
|
msg += "\n\nClick the button to match up groups and send them to the channel.\n"
|
||||||
view = discord.ui.View(timeout=None)
|
view = discord.ui.View(timeout=None)
|
||||||
|
@ -212,12 +213,12 @@ class MatcherCog(commands.Cog):
|
||||||
async def run_hourly_tasks(self):
|
async def run_hourly_tasks(self):
|
||||||
"""Run any hourly tasks we have"""
|
"""Run any hourly tasks we have"""
|
||||||
|
|
||||||
for (channel, min) in self.state.get_active_match_tasks():
|
for (channel, min) in state.State.get_active_match_tasks():
|
||||||
logger.info("Scheduled match task triggered in %s", channel)
|
logger.info("Scheduled match task triggered in %s", channel)
|
||||||
msg_channel = self.bot.get_channel(int(channel))
|
msg_channel = self.bot.get_channel(int(channel))
|
||||||
await matching.match_groups_in_channel(self.state, msg_channel, min)
|
await matching.match_groups_in_channel(state.State, msg_channel, min)
|
||||||
|
|
||||||
for (channel, _) in self.state.get_active_match_tasks(datetime.now() + timedelta(days=1)):
|
for (channel, _) in state.State.get_active_match_tasks(datetime.now() + timedelta(days=1)):
|
||||||
logger.info("Reminding about scheduled task in %s", channel)
|
logger.info("Reminding about scheduled task in %s", channel)
|
||||||
msg_channel = self.bot.get_channel(int(channel))
|
msg_channel = self.bot.get_channel(int(channel))
|
||||||
await msg_channel.send("Arf arf! just a reminder I'll be doin a matcherino in here in T-24hrs!"
|
await msg_channel.send("Arf arf! just a reminder I'll be doin a matcherino in here in T-24hrs!"
|
||||||
|
@ -244,7 +245,6 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.min: int = min
|
self.min: int = min
|
||||||
self.state = state.load_from_file()
|
|
||||||
|
|
||||||
# This is called when the button is clicked and the custom_id matches the template.
|
# This is called when the button is clicked and the custom_id matches the template.
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -3,16 +3,15 @@ Owner bot cog
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from matchy.state import State, AuthScope
|
import matchy.state as state
|
||||||
|
|
||||||
logger = logging.getLogger("owner")
|
logger = logging.getLogger("owner")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
class OwnerCog(commands.Cog):
|
class OwnerCog(commands.Cog):
|
||||||
def __init__(self, bot: commands.Bot, state: State):
|
def __init__(self, bot: commands.Bot):
|
||||||
self._bot = bot
|
self._bot = bot
|
||||||
self._state = state
|
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
@commands.dm_only()
|
@commands.dm_only()
|
||||||
|
@ -48,7 +47,7 @@ class OwnerCog(commands.Cog):
|
||||||
Grant the matcher scope to a given user
|
Grant the matcher scope to a given user
|
||||||
"""
|
"""
|
||||||
if user.isdigit():
|
if user.isdigit():
|
||||||
self._state.set_user_scope(str(user), AuthScope.MATCHER)
|
state.State.set_user_scope(str(user), state.AuthScope.MATCHER)
|
||||||
logger.info("Granting user %s matcher scope", user)
|
logger.info("Granting user %s matcher scope", user)
|
||||||
await ctx.reply("Done!", ephemeral=True)
|
await ctx.reply("Done!", ephemeral=True)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -3,8 +3,8 @@ import logging
|
||||||
import discord
|
import discord
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Protocol, runtime_checkable
|
from typing import Protocol, runtime_checkable
|
||||||
from matchy.state import State, ts_to_datetime
|
|
||||||
import matchy.util as util
|
import matchy.util as util
|
||||||
|
import matchy.state as state
|
||||||
|
|
||||||
|
|
||||||
class _ScoreFactors(int):
|
class _ScoreFactors(int):
|
||||||
|
@ -95,7 +95,6 @@ def get_member_group_eligibility_score(member: Member,
|
||||||
|
|
||||||
|
|
||||||
def attempt_create_groups(matchees: list[Member],
|
def attempt_create_groups(matchees: list[Member],
|
||||||
state: State,
|
|
||||||
oldest_relevant_ts: datetime,
|
oldest_relevant_ts: datetime,
|
||||||
per_group: int) -> tuple[bool, list[list[Member]]]:
|
per_group: int) -> tuple[bool, list[list[Member]]]:
|
||||||
"""History aware group matching"""
|
"""History aware group matching"""
|
||||||
|
@ -110,10 +109,10 @@ def attempt_create_groups(matchees: list[Member],
|
||||||
while matchees_left:
|
while matchees_left:
|
||||||
# Get the next matchee to place
|
# Get the next matchee to place
|
||||||
matchee = matchees_left.pop()
|
matchee = matchees_left.pop()
|
||||||
matchee_matches = state.get_user_matches(matchee.id)
|
matchee_matches = state.State.get_user_matches(matchee.id)
|
||||||
relevant_matches = [int(id) for id, ts
|
relevant_matches = [int(id) for id, ts
|
||||||
in matchee_matches.items()
|
in matchee_matches.items()
|
||||||
if ts_to_datetime(ts) >= oldest_relevant_ts]
|
if state.ts_to_datetime(ts) >= oldest_relevant_ts]
|
||||||
|
|
||||||
# Try every single group from the current group onwards
|
# Try every single group from the current group onwards
|
||||||
# Progressing through the groups like this ensures we slowly fill them up with compatible people
|
# Progressing through the groups like this ensures we slowly fill them up with compatible people
|
||||||
|
@ -143,7 +142,6 @@ def attempt_create_groups(matchees: list[Member],
|
||||||
|
|
||||||
|
|
||||||
def members_to_groups(matchees: list[Member],
|
def members_to_groups(matchees: list[Member],
|
||||||
state: State,
|
|
||||||
per_group: int = 3,
|
per_group: int = 3,
|
||||||
allow_fallback: bool = False) -> list[list[Member]]:
|
allow_fallback: bool = False) -> list[list[Member]]:
|
||||||
"""Generate the groups from the set of matchees"""
|
"""Generate the groups from the set of matchees"""
|
||||||
|
@ -155,14 +153,14 @@ def members_to_groups(matchees: list[Member],
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Walk from the start of history until now trying to match up groups
|
# Walk from the start of history until now trying to match up groups
|
||||||
for oldest_relevant_datetime in state.get_history_timestamps(matchees) + [datetime.now()]:
|
for oldest_relevant_datetime in state.State.get_history_timestamps(matchees) + [datetime.now()]:
|
||||||
|
|
||||||
# Attempt with each starting matchee
|
# Attempt with each starting matchee
|
||||||
for shifted_matchees in util.iterate_all_shifts(matchees):
|
for shifted_matchees in util.iterate_all_shifts(matchees):
|
||||||
|
|
||||||
attempts += 1
|
attempts += 1
|
||||||
groups = attempt_create_groups(
|
groups = attempt_create_groups(
|
||||||
shifted_matchees, state, oldest_relevant_datetime, per_group)
|
shifted_matchees, oldest_relevant_datetime, per_group)
|
||||||
|
|
||||||
# Fail the match if our groups aren't big enough
|
# Fail the match if our groups aren't big enough
|
||||||
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
if num_groups <= 1 or (groups and all(len(g) >= per_group for g in groups)):
|
||||||
|
@ -179,9 +177,9 @@ def members_to_groups(matchees: list[Member],
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
|
||||||
async def match_groups_in_channel(state: State, channel: discord.channel, min: int):
|
async def match_groups_in_channel(channel: discord.channel, min: int):
|
||||||
"""Match up the groups in a given channel"""
|
"""Match up the groups in a given channel"""
|
||||||
groups = active_members_to_groups(state, channel, min)
|
groups = active_members_to_groups(channel, min)
|
||||||
|
|
||||||
# Send the groups
|
# Send the groups
|
||||||
for group in groups:
|
for group in groups:
|
||||||
|
@ -197,24 +195,26 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
|
||||||
# Close off with a message
|
# Close off with a message
|
||||||
await channel.send("That's all folks, happy matching and remember - DFTBA!")
|
await channel.send("That's all folks, happy matching and remember - DFTBA!")
|
||||||
# Save the groups to the history
|
# Save the groups to the history
|
||||||
state.log_groups(groups)
|
state.State.log_groups(groups)
|
||||||
|
|
||||||
logger.info("Done! Matched into %s groups.", len(groups))
|
logger.info("Done! Matched into %s groups.", len(groups))
|
||||||
|
|
||||||
|
|
||||||
def get_matchees_in_channel(state: State, channel: discord.channel):
|
def get_matchees_in_channel(channel: discord.channel):
|
||||||
"""Fetches the matchees in a channel"""
|
"""Fetches the matchees in a channel"""
|
||||||
# Reactivate any unpaused users
|
# Reactivate any unpaused users
|
||||||
state.reactivate_users(channel.id)
|
state.State.reactivate_users(channel.id)
|
||||||
# Gather up the prospective matchees
|
# Gather up the prospective matchees
|
||||||
active = [m for m in channel.members if state.get_user_active_in_channel(m.id, channel.id)]
|
active = [m for m in channel.members if state.State.get_user_active_in_channel(
|
||||||
paused = [m for m in channel.members if state.get_user_paused_in_channel(m.id, channel.id)]
|
m.id, channel.id)]
|
||||||
|
paused = [m for m in channel.members if state.State.get_user_paused_in_channel(
|
||||||
|
m.id, channel.id)]
|
||||||
return (active, paused)
|
return (active, paused)
|
||||||
|
|
||||||
|
|
||||||
def active_members_to_groups(state: State, channel: discord.channel, min_members: int):
|
def active_members_to_groups(channel: discord.channel, min_members: int):
|
||||||
"""Helper to create groups from channel members"""
|
"""Helper to create groups from channel members"""
|
||||||
# Gather up the prospective matchees
|
# Gather up the prospective matchees
|
||||||
matchees = get_matchees_in_channel(state, channel)
|
matchees = get_matchees_in_channel(channel)
|
||||||
# Create our groups!
|
# Create our groups!
|
||||||
return members_to_groups(matchees, state, min_members, allow_fallback=True)
|
return members_to_groups(matchees, min_members, allow_fallback=True)
|
||||||
|
|
|
@ -15,7 +15,6 @@ import matchy.util as util
|
||||||
logger = logging.getLogger("state")
|
logger = logging.getLogger("state")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
|
||||||
_VERSION = 4
|
_VERSION = 4
|
||||||
|
|
||||||
|
@ -193,7 +192,7 @@ def _save(file: str, content: dict):
|
||||||
shutil.move(intermediate, file)
|
shutil.move(intermediate, file)
|
||||||
|
|
||||||
|
|
||||||
class State():
|
class _State():
|
||||||
def __init__(self, data: dict, file: str | None = None):
|
def __init__(self, data: dict, file: str | None = None):
|
||||||
"""Copy the data, migrate if needed, and validate"""
|
"""Copy the data, migrate if needed, and validate"""
|
||||||
self._dict = copy.deepcopy(data)
|
self._dict = copy.deepcopy(data)
|
||||||
|
@ -216,7 +215,7 @@ class State():
|
||||||
"""
|
"""
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def inner(self, *args, **kwargs):
|
def inner(self, *args, **kwargs):
|
||||||
tmp = State(self._dict, self._file)
|
tmp = _State(self._dict, self._file)
|
||||||
func(tmp, *args, **kwargs)
|
func(tmp, *args, **kwargs)
|
||||||
_SCHEMA.validate(tmp._dict)
|
_SCHEMA.validate(tmp._dict)
|
||||||
if tmp._file:
|
if tmp._file:
|
||||||
|
@ -380,11 +379,15 @@ class State():
|
||||||
return self._dict[_Key.TASKS]
|
return self._dict[_Key.TASKS]
|
||||||
|
|
||||||
|
|
||||||
def load_from_file(file: str) -> State:
|
def load_from_file(file: str) -> _State:
|
||||||
"""
|
"""
|
||||||
Load the state from a files
|
Load the state from a files
|
||||||
"""
|
"""
|
||||||
loaded = _load(file) if os.path.isfile(file) else _EMPTY_DICT
|
loaded = _load(file) if os.path.isfile(file) else _EMPTY_DICT
|
||||||
st = State(loaded, file)
|
st = _State(loaded, file)
|
||||||
_save(file, st._dict)
|
_save(file, st._dict)
|
||||||
return st
|
return st
|
||||||
|
|
||||||
|
|
||||||
|
_STATE_FILE = ".matchy/state.json"
|
||||||
|
State = load_from_file(_STATE_FILE)
|
||||||
|
|
|
@ -11,6 +11,12 @@ import itertools
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clean_state():
|
||||||
|
"""Ensure every single one of these tests has a clean state"""
|
||||||
|
state.State = state._State(state._EMPTY_DICT)
|
||||||
|
|
||||||
|
|
||||||
def test_protocols():
|
def test_protocols():
|
||||||
"""Verify the protocols we're using match the discord ones"""
|
"""Verify the protocols we're using match the discord ones"""
|
||||||
assert isinstance(discord.Member, matching.Member)
|
assert isinstance(discord.Member, matching.Member)
|
||||||
|
@ -59,16 +65,16 @@ class Member():
|
||||||
return self._id
|
return self._id
|
||||||
|
|
||||||
|
|
||||||
def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, per_group: int):
|
def members_to_groups_validate(matchees: list[Member], per_group: int):
|
||||||
"""Inner function to validate the main output of the groups function"""
|
"""Inner function to validate the main output of the groups function"""
|
||||||
groups = matching.members_to_groups(matchees, tmp_state, per_group)
|
groups = matching.members_to_groups(matchees, per_group)
|
||||||
|
|
||||||
# We should always have one group
|
# We should always have one group
|
||||||
assert len(groups)
|
assert len(groups)
|
||||||
|
|
||||||
# Log the groups to history
|
# Log the groups to history
|
||||||
# This will validate the internals
|
# This will validate the internals
|
||||||
tmp_state.log_groups(groups)
|
state.State.log_groups(groups)
|
||||||
|
|
||||||
# Ensure each group contains within the bounds of expected members
|
# Ensure each group contains within the bounds of expected members
|
||||||
for group in groups:
|
for group in groups:
|
||||||
|
@ -96,8 +102,7 @@ def members_to_groups_validate(matchees: list[Member], tmp_state: state.State, p
|
||||||
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
|
], ids=['single', "larger_groups", "100_members", "5_group", "pairs", "356_big_groups"])
|
||||||
def test_members_to_groups_no_history(matchees, per_group):
|
def test_members_to_groups_no_history(matchees, per_group):
|
||||||
"""Test simple group matching works"""
|
"""Test simple group matching works"""
|
||||||
tmp_state = state.State(state._EMPTY_DICT)
|
members_to_groups_validate(matchees, per_group)
|
||||||
members_to_groups_validate(matchees, tmp_state, per_group)
|
|
||||||
|
|
||||||
|
|
||||||
def items_found_in_lists(list_of_lists, items):
|
def items_found_in_lists(list_of_lists, items):
|
||||||
|
@ -328,13 +333,12 @@ def items_found_in_lists(list_of_lists, items):
|
||||||
], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3'])
|
], ids=['simple_history', 'fallback', 'example_1', 'example_2', 'example_3'])
|
||||||
def test_unique_regressions(history_data, matchees, per_group, checks):
|
def test_unique_regressions(history_data, matchees, per_group, checks):
|
||||||
"""Test a bunch of unqiue failures that happened in the past"""
|
"""Test a bunch of unqiue failures that happened in the past"""
|
||||||
tmp_state = state.State(state._EMPTY_DICT)
|
|
||||||
|
|
||||||
# Replay the history
|
# Replay the history
|
||||||
for d in history_data:
|
for d in history_data:
|
||||||
tmp_state.log_groups(d["groups"], d["ts"])
|
state.State.log_groups(d["groups"], d["ts"])
|
||||||
|
|
||||||
groups = members_to_groups_validate(matchees, tmp_state, per_group)
|
groups = members_to_groups_validate(matchees, per_group)
|
||||||
|
|
||||||
# Run the custom validate functions
|
# Run the custom validate functions
|
||||||
for check in checks:
|
for check in checks:
|
||||||
|
@ -380,28 +384,25 @@ def test_stess_random_groups(per_group, num_members, num_history):
|
||||||
member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)]
|
member.roles = [Role(i) for i in rand.sample(range(1, 8), 3)]
|
||||||
|
|
||||||
# For each history item match up groups and log those
|
# For each history item match up groups and log those
|
||||||
cumulative_state = state.State(state._EMPTY_DICT)
|
|
||||||
for i in range(num_history+1):
|
for i in range(num_history+1):
|
||||||
|
|
||||||
# Grab the num of members and replay
|
# Grab the num of members and replay
|
||||||
rand.shuffle(possible_members)
|
rand.shuffle(possible_members)
|
||||||
members = copy.deepcopy(possible_members[:num_members])
|
members = copy.deepcopy(possible_members[:num_members])
|
||||||
|
|
||||||
groups = members_to_groups_validate(
|
groups = members_to_groups_validate(members, per_group)
|
||||||
members, cumulative_state, per_group)
|
state.State.log_groups(
|
||||||
cumulative_state.log_groups(
|
|
||||||
groups, datetime.now() - timedelta(days=num_history-i))
|
groups, datetime.now() - timedelta(days=num_history-i))
|
||||||
|
|
||||||
|
|
||||||
def test_auth_scopes():
|
def test_auth_scopes():
|
||||||
tmp_state = state.State(state._EMPTY_DICT)
|
|
||||||
|
|
||||||
id = "1"
|
id = "1"
|
||||||
assert not tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
assert not state.State.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||||
|
|
||||||
id = "2"
|
id = "2"
|
||||||
tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
|
state.State.set_user_scope(id, state.AuthScope.MATCHER)
|
||||||
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
|
assert state.State.get_user_has_scope(id, state.AuthScope.MATCHER)
|
||||||
|
|
||||||
# Validate the state by constucting a new one
|
# Validate the state by constucting a new one
|
||||||
_ = state.State(tmp_state._dict)
|
_ = state._State(state.State._dict)
|
||||||
|
|
|
@ -2,7 +2,6 @@ import discord
|
||||||
import discord.ext.commands as commands
|
import discord.ext.commands as commands
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
import matchy.state as state
|
|
||||||
import discord.ext.test as dpytest
|
import discord.ext.test as dpytest
|
||||||
|
|
||||||
from matchy.cogs.owner import OwnerCog
|
from matchy.cogs.owner import OwnerCog
|
||||||
|
@ -20,7 +19,7 @@ async def bot():
|
||||||
b = commands.Bot(command_prefix="$",
|
b = commands.Bot(command_prefix="$",
|
||||||
intents=intents)
|
intents=intents)
|
||||||
await b._async_setup_hook()
|
await b._async_setup_hook()
|
||||||
await b.add_cog(OwnerCog(b, state.State(state._EMPTY_DICT)))
|
await b.add_cog(OwnerCog(b))
|
||||||
dpytest.configure(b)
|
dpytest.configure(b)
|
||||||
yield b
|
yield b
|
||||||
await dpytest.empty_queue()
|
await dpytest.empty_queue()
|
||||||
|
|
Loading…
Add table
Reference in a new issue