Merge pull request #12 from mdiluz/late-night-cleanup

Late night cleanup
This commit is contained in:
Marc Di Luzio 2024-08-16 23:48:46 +01:00 committed by GitHub
commit 21d004e94a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 214 additions and 439 deletions

View file

@ -1,4 +1,4 @@
name: Run Tests name: Test, Build and Publish
on: on:
push: push:
@ -9,14 +9,15 @@ on:
pull_request: pull_request:
workflow_dispatch: workflow_dispatch:
# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env: env:
# Use the github container registry
REGISTRY: ghcr.io REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }} IMAGE_NAME: ${{ github.repository }}
jobs: jobs:
# Core test runner # Run the tests scripts
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@ -36,18 +37,17 @@ jobs:
run: | run: |
python tests/test.py python tests/test.py
# Build and push the docker images
build-and-push-images: build-and-push-images:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: test needs: test
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
permissions: permissions:
contents: read contents: read
packages: write packages: write
attestations: write attestations: write
id-token: write id-token: write
#
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
@ -57,7 +57,6 @@ jobs:
with: with:
platforms: arm32v7/armhf # arm64v8/aarch64 - no current need for arm64 platforms: arm32v7/armhf # arm64v8/aarch64 - no current need for arm64
# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
- name: Log in to the Container registry - name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
with: with:
@ -65,12 +64,12 @@ jobs:
username: ${{ github.actor }} username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
# Use docker-container driver for multi-platform builds
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
with: with:
driver: docker-container # Use docker-container driver for multi-platform builds driver: docker-container
# This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
- name: Extract metadata (tags, labels) for Docker - name: Extract metadata (tags, labels) for Docker
id: meta id: meta
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
@ -82,9 +81,6 @@ jobs:
type=ref,event=pr type=ref,event=pr
type=edge,branch=main type=edge,branch=main
# This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
# It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
# It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
- name: Build and push Docker image - name: Build and push Docker image
id: push id: push
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
@ -94,10 +90,10 @@ jobs:
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
# Cache to help the step run a little faster
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max
# This step generates an artifact attestation for the image, which is an unforgeable statement about where and how it was built. It increases supply chain security for people who consume the image. For more information, see "[AUTOTITLE](/actions/security-guides/using-artifact-attestations-to-establish-provenance-for-builds)."
- name: Generate artifact attestation - name: Generate artifact attestation
uses: actions/attest-build-provenance@v1 uses: actions/attest-build-provenance@v1
with: with:

2
.gitignore vendored
View file

@ -1,6 +1,4 @@
__pycache__ __pycache__
config.json
state.json
.venv .venv
.coverage .coverage
.matchy .matchy

View file

@ -3,7 +3,9 @@
Matchy matches matchees. Matchy matches matchees.
Matchy is a Discord bot that groups up users for fun and vibes. Matchy can be installed on your server by clicking [here](https://discord.com/oauth2/authorize?client_id=1270849346987884696&permissions=0&integration_type=0&scope=bot). Matchy only allows authorised users to trigger posts in channels. Matchy is a Discord bot that groups up users for fun and vibes. Matchy can be installed on your server [here](https://discord.com/oauth2/authorize?client_id=1270849346987884696&permissions=0&integration_type=0&scope=bot).
Note: Matchy currently only allows owner-authorised users to trigger posts in channels.
![Tests](https://github.com/mdiluz/matchy/actions/workflows/test.yml/badge.svg) ![Tests](https://github.com/mdiluz/matchy/actions/workflows/test.yml/badge.svg)
@ -46,26 +48,8 @@ Python tests are written to use `pytest` and cover most internal functionality.
## Hosting ## Hosting
### Config and State ### State
Matchy is configured by an optional `$MATCHY_CONFIG` envar or a `.matchy/config.json` file that takes this format: State is stored locally in a `.matchy/state.json` file. This will be created by the bot. This stores historical information on users, maching schedules, user auth scopes and more. See [`state.py`](matchy/files/state.py) for schema information if you need to inspect it.
```json
{
"version" : 2,
"match" : {
"score_factors": {
"repeat_role" : 4,
"repeat_match" : 8,
"extra_member" : 32,
"upper_threshold" : 64
}
}
}
```
Only the version is required.
See [`config.py`](matchy/files/config.py) for explanations for any extra settings here.
_State_ is stored locally in a `.matchy/state.json` file. This will be created by the bot. This stores historical information on users, maching schedules, user auth scopes and more. See [`state.py`](matchy/files/state.py) for schema information if you need to inspect it.
### Secrets ### Secrets
The `TOKEN` envar is required run the bot. It's recommended this is placed in a local `.env` file. To generate bot token for development see [this discord.py guide](https://discordpy.readthedocs.io/en/stable/discord.html). The `TOKEN` envar is required run the bot. It's recommended this is placed in a local `.env` file. To generate bot token for development see [this discord.py guide](https://discordpy.readthedocs.io/en/stable/discord.html).

View file

@ -5,8 +5,8 @@ import logging
import discord import discord
from discord.ext import commands from discord.ext import commands
import os import os
from matchy.files.state import load_from_file from matchy.state import load_from_file
import matchy.cogs.matchy import matchy.cogs.matcher
import matchy.cogs.owner import matchy.cogs.owner
_STATE_FILE = ".matchy/state.json" _STATE_FILE = ".matchy/state.json"
@ -24,7 +24,7 @@ bot = commands.Bot(command_prefix='$',
@bot.event @bot.event
async def setup_hook(): async def setup_hook():
await bot.add_cog(matchy.cogs.matchy.MatchyCog(bot, state)) await bot.add_cog(matchy.cogs.matcher.MatcherCog(bot, state))
await bot.add_cog(matchy.cogs.owner.OwnerCog(bot, state)) await bot.add_cog(matchy.cogs.owner.OwnerCog(bot, state))

View file

@ -6,17 +6,19 @@ import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands, tasks from discord.ext import commands, tasks
from datetime import datetime, timedelta, time from datetime import datetime, timedelta, time
import re
import matchy.views.match as match
import matchy.matching as matching import matchy.matching as matching
from matchy.files.state import State, AuthScope from matchy.state import State, AuthScope
import matchy.util as util import matchy.util as util
import matchy.state as state
logger = logging.getLogger("cog") logger = logging.getLogger("cog")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class MatchyCog(commands.Cog): class MatcherCog(commands.Cog):
def __init__(self, bot: commands.Bot, state: State): def __init__(self, bot: commands.Bot, state: State):
self.bot = bot self.bot = bot
self.state = state self.state = state
@ -25,7 +27,7 @@ class MatchyCog(commands.Cog):
async def on_ready(self): async def on_ready(self):
"""Bot is ready and connected""" """Bot is ready and connected"""
self.run_hourly_tasks.start() self.run_hourly_tasks.start()
self.bot.add_dynamic_items(match.DynamicGroupButton) self.bot.add_dynamic_items(DynamicGroupButton)
activity = discord.Game("/join") activity = discord.Game("/join")
await self.bot.change_presence(status=discord.Status.online, activity=activity) await self.bot.change_presence(status=discord.Status.online, activity=activity)
logger.info("Bot is up and ready!") logger.info("Bot is up and ready!")
@ -180,7 +182,7 @@ class MatchyCog(commands.Cog):
# Otherwise set up the button # Otherwise set up the button
msg += "\n\nClick the button to match up groups and send them to the channel.\n" msg += "\n\nClick the button to match up groups and send them to the channel.\n"
view = discord.ui.View(timeout=None) view = discord.ui.View(timeout=None)
view.add_item(match.DynamicGroupButton(members_min)) view.add_item(DynamicGroupButton(members_min))
else: else:
# Let a non-matcher know why they don't have the button # Let a non-matcher know why they don't have the button
msg += f"\n\nYou'll need the {AuthScope.MATCHER}" msg += f"\n\nYou'll need the {AuthScope.MATCHER}"
@ -204,3 +206,45 @@ class MatchyCog(commands.Cog):
msg_channel = self.bot.get_channel(int(channel)) msg_channel = self.bot.get_channel(int(channel))
await msg_channel.send("Arf arf! just a reminder I'll be doin a matcherino in here in T-24hrs!" await msg_channel.send("Arf arf! just a reminder I'll be doin a matcherino in here in T-24hrs!"
+ "\nUse /join if you haven't already, or /pause if you want to skip a week :)") + "\nUse /join if you haven't already, or /pause if you want to skip a week :)")
# Increment when adjusting the custom_id so we don't confuse old users
_MATCH_BUTTON_CUSTOM_ID_VERSION = 1
_MATCH_BUTTON_CUSTOM_ID_PREFIX = f'match:v{_MATCH_BUTTON_CUSTOM_ID_VERSION}:'
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
template=_MATCH_BUTTON_CUSTOM_ID_PREFIX + r'min:(?P<min>[0-9]+)'):
"""
Describes a simple button that lets the user trigger a match
"""
def __init__(self, min: int) -> None:
super().__init__(
discord.ui.Button(
label='Match Groups!',
style=discord.ButtonStyle.blurple,
custom_id=_MATCH_BUTTON_CUSTOM_ID_PREFIX + f'min:{min}',
)
)
self.min: int = min
self.state = state.load_from_file()
# This is called when the button is clicked and the custom_id matches the template.
@classmethod
async def from_custom_id(cls, intrctn: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
min = int(match['min'])
return cls(min)
async def callback(self, intrctn: discord.Interaction) -> None:
"""Match up people when the button is pressed"""
logger.info("Handling button press min=%s", self.min)
logger.info("User %s from %s in #%s", intrctn.user,
intrctn.guild.name, intrctn.channel.name)
# Let the user know we've recieved the message
await intrctn.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
# Perform the match
await matching.match_groups_in_channel(self.state, intrctn.channel, self.min)

View file

@ -3,7 +3,7 @@ Owner bot cog
""" """
import logging import logging
from discord.ext import commands from discord.ext import commands
from matchy.files.state import State, AuthScope from matchy.state import State, AuthScope
logger = logging.getLogger("owner") logger = logging.getLogger("owner")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)

View file

@ -1,152 +0,0 @@
"""Very simple config loading library"""
from schema import Schema, Use, Optional
import matchy.files.ops as ops
import os
import logging
import json
logger = logging.getLogger("config")
logger.setLevel(logging.INFO)
# Envar takes precedent
_ENVAR = "MATCHY_CONFIG"
_FILE = ".matchy/config.json"
# Warning: Changing any of the below needs proper thought to ensure backwards compatibility
_VERSION = 2
class _Key():
VERSION = "version"
MATCH = "match"
SCORE_FACTORS = "score_factors"
REPEAT_ROLE = "repeat_role"
REPEAT_MATCH = "repeat_match"
EXTRA_MEMBER = "extra_member"
UPPER_THRESHOLD = "upper_threshold"
# Removed
_OWNERS = "owners"
_TOKEN = "token"
_SCHEMA = Schema(
{
# The current version
_Key.VERSION: Use(int),
# Settings for the match algorithmn, see matching.py for explanations on usage
Optional(_Key.MATCH): {
Optional(_Key.SCORE_FACTORS): {
Optional(_Key.REPEAT_ROLE): Use(int),
Optional(_Key.REPEAT_MATCH): Use(int),
Optional(_Key.EXTRA_MEMBER): Use(int),
Optional(_Key.UPPER_THRESHOLD): Use(int),
}
}
}
)
_EMPTY_DICT = {
_Key.VERSION: _VERSION
}
def _migrate_to_v1(d: dict):
# Owners moved to History in v1
# Note: owners will be required to be re-added to the state.json
if _Key._OWNERS in d:
owners = d.pop(_Key._OWNERS)
logger.warning(
"Migration removed owners from config, these must be re-added to the state.json")
logger.warning("Owners: %s", owners)
def _migrate_to_v2(d: dict):
# Token moved to the environment
if _Key._TOKEN in d:
del d[_Key._TOKEN]
# Set of migration functions to apply
_MIGRATIONS = [
_migrate_to_v1,
_migrate_to_v2
]
class _ScoreFactors():
def __init__(self, data: dict):
"""Initialise and validate the config"""
self._dict = data
@property
def repeat_role(self) -> int:
return self._dict.get(_Key.REPEAT_ROLE, None)
@property
def repeat_match(self) -> int:
return self._dict.get(_Key.REPEAT_MATCH, None)
@property
def extra_member(self) -> int:
return self._dict.get(_Key.EXTRA_MEMBER, None)
@property
def upper_threshold(self) -> int:
return self._dict.get(_Key.UPPER_THRESHOLD, None)
class _Config():
def __init__(self, data: dict):
"""Initialise and validate the config"""
_SCHEMA.validate(data)
self._dict = data
@property
def token(self) -> str:
return self._dict["token"]
@property
def score_factors(self) -> _ScoreFactors:
return _ScoreFactors(self._dict.get(_Key.SCORE_FACTORS, {}))
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
_MIGRATIONS[i](dict)
dict["version"] = _VERSION
def _load() -> _Config:
"""
Load the state from an envar or file
Apply any required migrations
"""
# Try the envar first
envar = os.environ.get(_ENVAR)
if envar:
loaded = json.loads(envar)
logger.info("Config loaded from $%s", _ENVAR)
else:
# Otherwise try the file
if os.path.isfile(_FILE):
loaded = ops.load(_FILE)
logger.info("Config loaded from %s", _FILE)
else:
loaded = _EMPTY_DICT
logger.warning("No %s file found, using defaults", _FILE)
_migrate(loaded)
return _Config(loaded)
# Core config for users to use
# Singleton as there should only be one, it's static, and global
Config = _load()

View file

@ -1,26 +0,0 @@
"""File operation helpers"""
import json
import shutil
import pathlib
import os
def load(file: str) -> dict:
"""Load a json file directly as a dict"""
with open(file) as f:
return json.load(f)
def save(file: str, content: dict):
"""
Save out a content dictionary to a file
"""
# Ensure the save directory exists first
dir = pathlib.Path(os.path.dirname(file))
dir.mkdir(parents=True, exist_ok=True)
# Store in an intermediary directory first
intermediate = file + ".nxt"
with open(intermediate, "w") as f:
json.dump(content, f, indent=4)
shutil.move(intermediate, file)

View file

@ -3,9 +3,8 @@ import logging
import discord import discord
from datetime import datetime from datetime import datetime
from typing import Protocol, runtime_checkable from typing import Protocol, runtime_checkable
from matchy.files.state import State, ts_to_datetime from matchy.state import State, ts_to_datetime
import matchy.util as util import matchy.util as util
import matchy.files.config as config
class _ScoreFactors(int): class _ScoreFactors(int):
@ -15,14 +14,14 @@ class _ScoreFactors(int):
""" """
# Added for each role the matchee has that another group member has # Added for each role the matchee has that another group member has
REPEAT_ROLE = config.Config.score_factors.repeat_role or 2**2 REPEAT_ROLE = 2**2
# Added for each member in the group that the matchee has already matched with # Added for each member in the group that the matchee has already matched with
REPEAT_MATCH = config.Config.score_factors.repeat_match or 2**3 REPEAT_MATCH = 2**3
# Added for each additional member over the set "per group" value # Added for each additional member over the set "per group" value
EXTRA_MEMBER = config.Config.score_factors.extra_member or 2**5 EXTRA_MEMBER = 2**5
# Upper threshold, if the user scores higher than this they will not be placed in that group # Upper threshold, if the user scores higher than this they will not be placed in that group
UPPER_THRESHOLD = config.Config.score_factors.upper_threshold or 2**6 UPPER_THRESHOLD = 2**6
logger = logging.getLogger("matching") logger = logging.getLogger("matching")
@ -66,12 +65,6 @@ class Guild(Protocol):
pass pass
def members_to_groups_simple(matchees: list[Member], per_group: int) -> tuple[bool, list[list[Member]]]:
"""Super simple group matching, literally no logic"""
num_groups = max(len(matchees)//per_group, 1)
return [matchees[i::num_groups] for i in range(num_groups)]
def get_member_group_eligibility_score(member: Member, def get_member_group_eligibility_score(member: Member,
group: list[Member], group: list[Member],
prior_matches: list[int], prior_matches: list[int],
@ -149,14 +142,6 @@ def attempt_create_groups(matchees: list[Member],
return groups return groups
def iterate_all_shifts(list: list):
"""Yields each shifted variation of the input list"""
yield list
for _ in range(len(list)-1):
list = list[1:] + [list[0]]
yield list
def members_to_groups(matchees: list[Member], def members_to_groups(matchees: list[Member],
state: State, state: State,
per_group: int = 3, per_group: int = 3,
@ -173,7 +158,7 @@ def members_to_groups(matchees: list[Member],
for oldest_relevant_datetime in state.get_history_timestamps(matchees) + [datetime.now()]: for oldest_relevant_datetime in state.get_history_timestamps(matchees) + [datetime.now()]:
# Attempt with each starting matchee # Attempt with each starting matchee
for shifted_matchees in iterate_all_shifts(matchees): for shifted_matchees in util.iterate_all_shifts(matchees):
attempts += 1 attempts += 1
groups = attempt_create_groups( groups = attempt_create_groups(
@ -187,7 +172,7 @@ def members_to_groups(matchees: list[Member],
# If we've still failed, just use the simple method # If we've still failed, just use the simple method
if allow_fallback: if allow_fallback:
logger.info("Fell back to simple groups after %s attempt(s)", attempts) logger.info("Fell back to simple groups after %s attempt(s)", attempts)
return members_to_groups_simple(matchees, per_group) return [matchees[i::num_groups] for i in range(num_groups)]
# Simply assert false, this should never happen # Simply assert false, this should never happen
# And should be caught by tests # And should be caught by tests
@ -200,10 +185,8 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
# Send the groups # Send the groups
for group in groups: for group in groups:
message = await channel.send( message = await channel.send(
f"Matched up {util.format_list([m.mention for m in group])}!") f"Matched up {util.format_list([m.mention for m in group])}!")
# Set up a thread for this match if the bot has permissions to do so # Set up a thread for this match if the bot has permissions to do so
if channel.permissions_for(channel.guild.me).create_public_threads: if channel.permissions_for(channel.guild.me).create_public_threads:
await channel.create_thread( await channel.create_thread(
@ -213,7 +196,6 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
# Close off with a message # Close off with a message
await channel.send("That's all folks, happy matching and remember - DFTBA!") await channel.send("That's all folks, happy matching and remember - DFTBA!")
# Save the groups to the history # Save the groups to the history
state.log_groups(groups) state.log_groups(groups)
@ -224,16 +206,13 @@ def get_matchees_in_channel(state: State, channel: discord.channel):
"""Fetches the matchees in a channel""" """Fetches the matchees in a channel"""
# Reactivate any unpaused users # Reactivate any unpaused users
state.reactivate_users(channel.id) state.reactivate_users(channel.id)
# Gather up the prospective matchees # Gather up the prospective matchees
return [m for m in channel.members if state.get_user_active_in_channel(m.id, channel.id)] return [m for m in channel.members if state.get_user_active_in_channel(m.id, channel.id)]
def active_members_to_groups(state: State, channel: discord.channel, min_members: int): def active_members_to_groups(state: State, channel: discord.channel, min_members: int):
"""Helper to create groups from channel members""" """Helper to create groups from channel members"""
# Gather up the prospective matchees # Gather up the prospective matchees
matchees = get_matchees_in_channel(state, channel) matchees = get_matchees_in_channel(state, channel)
# Create our groups! # Create our groups!
return members_to_groups(matchees, state, min_members, allow_fallback=True) return members_to_groups(matchees, state, min_members, allow_fallback=True)

View file

@ -4,10 +4,12 @@ from datetime import datetime
from schema import Schema, Use, Optional from schema import Schema, Use, Optional
from collections.abc import Generator from collections.abc import Generator
from typing import Protocol from typing import Protocol
import matchy.files.ops as ops import json
import shutil
import pathlib
import copy import copy
import logging import logging
from contextlib import contextmanager from functools import wraps
logger = logging.getLogger("state") logger = logging.getLogger("state")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -169,18 +171,58 @@ def datetime_to_ts(ts: datetime) -> str:
return datetime.strftime(ts, _TIME_FORMAT) return datetime.strftime(ts, _TIME_FORMAT)
def _load(file: str) -> dict:
"""Load a json file directly as a dict"""
with open(file) as f:
return json.load(f)
def _save(file: str, content: dict):
"""
Save out a content dictionary to a file
"""
# Ensure the save directory exists first
dir = pathlib.Path(os.path.dirname(file))
dir.mkdir(parents=True, exist_ok=True)
# Store in an intermediary directory first
intermediate = file + ".nxt"
with open(intermediate, "w") as f:
json.dump(content, f, indent=4)
shutil.move(intermediate, file)
class State(): class State():
def __init__(self, data: dict, file: str | None = None): def __init__(self, data: dict, file: str | None = None):
"""Initialise and validate the state""" """Copy the data, migrate if needed, and validate"""
self.validate(data)
self._dict = copy.deepcopy(data) self._dict = copy.deepcopy(data)
self._file = file self._file = file
def validate(self, dict: dict = None): version = self._dict.get("version", 0)
"""Initialise and validate a state dict""" for i in range(version, _VERSION):
if not dict: logger.info("Migrating from v%s to v%s", version, version+1)
dict = self._dict _MIGRATIONS[i](self._dict)
_SCHEMA.validate(dict) self._dict[_Key.VERSION] = _VERSION
_SCHEMA.validate(self._dict)
@staticmethod
def safe_write(func):
"""
Wraps any function running it first on some temporary state
Validates the resulting state and only then attempts to save it out
before storing the dict back in the State
"""
@wraps(func)
def inner(self, *args, **kwargs):
tmp = State(self._dict, self._file)
func(tmp, *args, **kwargs)
_SCHEMA.validate(tmp._dict)
if tmp._file:
_save(tmp._file, tmp._dict)
self._dict = tmp._dict
return inner
def get_history_timestamps(self, users: list[Member]) -> list[datetime]: def get_history_timestamps(self, users: list[Member]) -> list[datetime]:
"""Grab all timestamps in the history""" """Grab all timestamps in the history"""
@ -202,24 +244,24 @@ class State():
def get_user_matches(self, id: int) -> list[int]: def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {}) return self._users.get(str(id), {}).get(_Key.MATCHES, {})
@safe_write
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None: def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups""" """Log the groups"""
ts = datetime_to_ts(ts or datetime.now()) ts = datetime_to_ts(ts or datetime.now())
with self._safe_wrap_write() as safe_state:
for group in groups: for group in groups:
# Update the matchee data with the matches # Update the matchee data with the matches
for m in group: for m in group:
matchee = safe_state._users.setdefault(str(m.id), {}) matchee = self._users.setdefault(str(m.id), {})
matchee_matches = matchee.setdefault(_Key.MATCHES, {}) matchee_matches = matchee.setdefault(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id): for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts matchee_matches[str(o.id)] = ts
@safe_write
def set_user_scope(self, id: str, scope: str, value: bool = True): def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user""" """Add an auth scope to a user"""
with self._safe_wrap_write() as safe_state:
# Dive in # Dive in
user = safe_state._users.setdefault(str(id), {}) user = self._users.setdefault(str(id), {})
scopes = user.setdefault(_Key.SCOPES, []) scopes = user.setdefault(_Key.SCOPES, [])
# Set the value # Set the value
@ -255,10 +297,10 @@ class State():
self._set_user_channel_prop( self._set_user_channel_prop(
id, channel_id, _Key.REACTIVATE, datetime_to_ts(until)) id, channel_id, _Key.REACTIVATE, datetime_to_ts(until))
@safe_write
def reactivate_users(self, channel_id: str): def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel""" """Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap_write() as safe_state: for user in self._users.values():
for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {}) channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {}) channel = channels.get(str(channel_id), {})
if channel and not channel[_Key.ACTIVE]: if channel and not channel[_Key.ACTIVE]:
@ -295,10 +337,10 @@ class State():
for task in tasks: for task in tasks:
yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN]) yield (task[_Key.WEEKDAY], task[_Key.HOUR], task[_Key.MEMBERS_MIN])
@safe_write
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool: def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
"""Set up a match task on a channel""" """Set up a match task on a channel"""
with self._safe_wrap_write() as safe_state: channel = self._tasks.setdefault(str(channel_id), {})
channel = safe_state._tasks.setdefault(str(channel_id), {})
matches = channel.setdefault(_Key.MATCH_TASKS, []) matches = channel.setdefault(_Key.MATCH_TASKS, [])
found = False found = False
@ -327,11 +369,6 @@ class State():
# We did not manage to remove the schedule (or add it? though that should be impossible) # We did not manage to remove the schedule (or add it? though that should be impossible)
return False return False
@property
def dict_internal_copy(self) -> dict:
"""Only to be used to get the internal dict as a copy"""
return copy.deepcopy(self._dict)
@property @property
def _users(self) -> dict[str]: def _users(self) -> dict[str]:
return self._dict[_Key.USERS] return self._dict[_Key.USERS]
@ -340,63 +377,20 @@ class State():
def _tasks(self) -> dict[str]: def _tasks(self) -> dict[str]:
return self._dict[_Key.TASKS] return self._dict[_Key.TASKS]
@safe_write
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper""" """Set a user channel property helper"""
with self._safe_wrap_write() as safe_state: user = self._users.setdefault(str(id), {})
# Dive in
user = safe_state._users.setdefault(str(id), {})
channels = user.setdefault(_Key.CHANNELS, {}) channels = user.setdefault(_Key.CHANNELS, {})
channel = channels.setdefault(str(channel_id), {}) channel = channels.setdefault(str(channel_id), {})
# Set the value
channel[key] = value channel[key] = value
# TODO: Make this a decorator?
@contextmanager
def _safe_wrap_write(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
try:
yield tmp_state
finally:
# Validate and then overwrite our dict with the new one
tmp_state.validate()
self._dict = tmp_state._dict
# Write this change out if we have a file
if self._file:
self._save_to_file()
def _save_to_file(self):
"""Saves the state out to the chosen file"""
ops.save(self._file, self.dict_internal_copy)
def _migrate(dict: dict):
"""Migrate a dict through versions"""
version = dict.get("version", 0)
for i in range(version, _VERSION):
logger.info("Migrating from v%s to v%s", version, version+1)
_MIGRATIONS[i](dict)
dict[_Key.VERSION] = _VERSION
def load_from_file(file: str) -> State: def load_from_file(file: str) -> State:
""" """
Load the state from a files Load the state from a files
Apply any required migrations
""" """
loaded = _EMPTY_DICT loaded = _load(file) if os.path.isfile(file) else _EMPTY_DICT
# If there's a file load it and try to migrate
if os.path.isfile(file):
loaded = ops.load(file)
_migrate(loaded)
st = State(loaded, file) st = State(loaded, file)
_save(file, st._dict)
# Save out the migrated (or new) file
ops.save(file, st._dict)
return st return st

View file

@ -37,3 +37,11 @@ def get_next_datetime(weekday, hour) -> datetime:
next_date.replace(hour=hour) next_date.replace(hour=hour)
return next_date return next_date
def iterate_all_shifts(list: list):
"""Yields each shifted variation of the input list"""
yield list
for _ in range(len(list)-1):
list = list[1:] + [list[0]]
yield list

View file

@ -1,53 +0,0 @@
"""
Class for a button that matches groups in a channel
"""
import logging
import discord
import re
import matchy.files.state as state
import matchy.matching as matching
logger = logging.getLogger("match_button")
logger.setLevel(logging.INFO)
# Increment when adjusting the custom_id so we don't confuse old users
_MATCH_BUTTON_CUSTOM_ID_VERSION = 1
_MATCH_BUTTON_CUSTOM_ID_PREFIX = f'match:v{_MATCH_BUTTON_CUSTOM_ID_VERSION}:'
class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
template=_MATCH_BUTTON_CUSTOM_ID_PREFIX + r'min:(?P<min>[0-9]+)'):
"""
Describes a simple button that lets the user trigger a match
"""
def __init__(self, min: int) -> None:
super().__init__(
discord.ui.Button(
label='Match Groups!',
style=discord.ButtonStyle.blurple,
custom_id=_MATCH_BUTTON_CUSTOM_ID_PREFIX + f'min:{min}',
)
)
self.min: int = min
self.state = state.load_from_file()
# This is called when the button is clicked and the custom_id matches the template.
@classmethod
async def from_custom_id(cls, intrctn: discord.Interaction, item: discord.ui.Button, match: re.Match[str], /):
min = int(match['min'])
return cls(min)
async def callback(self, intrctn: discord.Interaction) -> None:
"""Match up people when the button is pressed"""
logger.info("Handling button press min=%s", self.min)
logger.info("User %s from %s in #%s", intrctn.user,
intrctn.guild.name, intrctn.channel.name)
# Let the user know we've recieved the message
await intrctn.response.send_message(content="Matchy is matching matchees...", ephemeral=True)
# Perform the match
await matching.match_groups_in_channel(self.state, intrctn.channel, self.min)

View file

@ -5,7 +5,7 @@ import discord
import pytest import pytest
import random import random
import matchy.matching as matching import matchy.matching as matching
import matchy.files.state as state import matchy.state as state
import copy import copy
import itertools import itertools
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -403,15 +403,5 @@ def test_auth_scopes():
tmp_state.set_user_scope(id, state.AuthScope.MATCHER) tmp_state.set_user_scope(id, state.AuthScope.MATCHER)
assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER) assert tmp_state.get_user_has_scope(id, state.AuthScope.MATCHER)
tmp_state.validate() # Validate the state by constucting a new one
_ = state.State(tmp_state._dict)
def test_iterate_all_shifts():
original = [1, 2, 3, 4]
lists = [val for val in matching.iterate_all_shifts(original)]
assert lists == [
[1, 2, 3, 4],
[2, 3, 4, 1],
[3, 4, 1, 2],
[4, 1, 2, 3],
]

View file

@ -2,7 +2,7 @@ import discord
import discord.ext.commands as commands import discord.ext.commands as commands
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import matchy.files.state as state import matchy.state as state
import discord.ext.test as dpytest import discord.ext.test as dpytest
from matchy.cogs.owner import OwnerCog from matchy.cogs.owner import OwnerCog

View file

@ -1,7 +1,7 @@
""" """
Test functions for the state module Test functions for the state module
""" """
import matchy.files.state as state import matchy.state as state
import tempfile import tempfile
import os import os
@ -18,10 +18,11 @@ def test_simple_load_reload():
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json') path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path) st = state.load_from_file(path)
st._save_to_file() state._save(st._file, st._dict)
state._save(st._file, st._dict)
st = state.load_from_file(path) st = state.load_from_file(path)
st._save_to_file() state._save(st._file, st._dict)
st = state.load_from_file(path) st = state.load_from_file(path)
@ -30,13 +31,13 @@ def test_authscope():
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json') path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path) st = state.load_from_file(path)
st._save_to_file() state._save(st._file, st._dict)
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER) assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
st = state.load_from_file(path) st = state.load_from_file(path)
st.set_user_scope(1, state.AuthScope.MATCHER) st.set_user_scope(1, state.AuthScope.MATCHER)
st._save_to_file() state._save(st._file, st._dict)
st = state.load_from_file(path) st = state.load_from_file(path)
assert st.get_user_has_scope(1, state.AuthScope.MATCHER) assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
@ -50,13 +51,13 @@ def test_channeljoin():
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'tmp.json') path = os.path.join(tmp, 'tmp.json')
st = state.load_from_file(path) st = state.load_from_file(path)
st._save_to_file() state._save(st._file, st._dict)
assert not st.get_user_active_in_channel(1, "2") assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path) st = state.load_from_file(path)
st.set_user_active_in_channel(1, "2", True) st.set_user_active_in_channel(1, "2", True)
st._save_to_file() state._save(st._file, st._dict)
st = state.load_from_file(path) st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2") assert st.get_user_active_in_channel(1, "2")

12
tests/util_test.py Normal file
View file

@ -0,0 +1,12 @@
import matchy.util as util
def test_iterate_all_shifts():
original = [1, 2, 3, 4]
lists = [val for val in util.iterate_all_shifts(original)]
assert lists == [
[1, 2, 3, 4],
[2, 3, 4, 1],
[3, 4, 1, 2],
[4, 1, 2, 3],
]