Implement user pausing with /pause

This commit is contained in:
Marc Di Luzio 2024-08-11 19:02:47 +01:00
parent a480549ad3
commit 7efe781e66
4 changed files with 125 additions and 67 deletions

View file

@ -30,7 +30,6 @@ Matchy is configured by a `config.json` file that takes this format:
## TODO ## TODO
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html) * Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
* Implement /pause to pause a user for a little while
* Move more constants to the config * Move more constants to the config
* Add scheduling functionality * Add scheduling functionality
* Fix logging in some sub files (doesn't seem to actually be output?) * Fix logging in some sub files (doesn't seem to actually be output?)

View file

@ -75,7 +75,7 @@ async def close(ctx: commands.Context):
@bot.tree.command(description="Join the matchees for this channel") @bot.tree.command(description="Join the matchees for this channel")
@commands.guild_only() @commands.guild_only()
async def join(interaction: discord.Interaction): async def join(interaction: discord.Interaction):
State.set_use_active_in_channel( State.set_user_active_in_channel(
interaction.user.id, interaction.channel.id) interaction.user.id, interaction.channel.id)
state.save_to_file(State, STATE_FILE) state.save_to_file(State, STATE_FILE)
await interaction.response.send_message( await interaction.response.send_message(
@ -87,13 +87,26 @@ async def join(interaction: discord.Interaction):
@bot.tree.command(description="Leave the matchees for this channel") @bot.tree.command(description="Leave the matchees for this channel")
@commands.guild_only() @commands.guild_only()
async def leave(interaction: discord.Interaction): async def leave(interaction: discord.Interaction):
State.set_use_active_in_channel( State.set_user_active_in_channel(
interaction.user.id, interaction.channel.id, False) interaction.user.id, interaction.channel.id, False)
state.save_to_file(State, STATE_FILE) state.save_to_file(State, STATE_FILE)
await interaction.response.send_message( await interaction.response.send_message(
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True) f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
@bot.tree.command(description="Pause your matching in this channel for a number of days")
@commands.guild_only()
@app_commands.describe(days="Days to pause for (defaults to 7)")
async def pause(interaction: discord.Interaction, days: int = None):
if not days: # Default to a week
days = 7
State.set_user_paused_in_channel(
interaction.user.id, interaction.channel.id, days)
state.save_to_file(State, STATE_FILE)
await interaction.response.send_message(
f"Sure thing {interaction.user.mention}. Paused you for {days} days!", ephemeral=True, silent=True)
@bot.tree.command(description="List the matchees for this channel") @bot.tree.command(description="List the matchees for this channel")
@commands.guild_only() @commands.guild_only()
async def list(interaction: discord.Interaction): async def list(interaction: discord.Interaction):
@ -195,6 +208,9 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
def get_matchees_in_channel(channel: discord.channel): def get_matchees_in_channel(channel: discord.channel):
"""Fetches the matchees in a channel""" """Fetches the matchees in a channel"""
# Reactivate any unpaused users
State.reactivate_users(channel.id)
# Gather up the prospective matchees # Gather up the prospective matchees
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)] return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]

View file

@ -1,11 +1,12 @@
"""Store bot state""" """Store bot state"""
import os import os
from datetime import datetime from datetime import datetime, timedelta
from schema import Schema, And, Use, Optional from schema import Schema, And, Use, Optional
from typing import Protocol from typing import Protocol
import files import files
import copy import copy
import logging import logging
from contextlib import contextmanager
logger = logging.getLogger("state") logger = logging.getLogger("state")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -83,6 +84,8 @@ _SCHEMA = Schema(
Optional(str): { Optional(str): {
# Whether the user is signed up in this channel # Whether the user is signed up in this channel
_Key.ACTIVE: And(Use(bool)), _Key.ACTIVE: And(Use(bool)),
# A timestamp for when to re-activate the user
Optional(_Key.REACTIVATE): And(Use(str)),
} }
} }
} }
@ -106,24 +109,21 @@ class Member(Protocol):
def ts_to_datetime(ts: str) -> datetime: def ts_to_datetime(ts: str) -> datetime:
"""Convert a ts to datetime using the internal format""" """Convert a string ts to datetime using the internal format"""
return datetime.strptime(ts, _TIME_FORMAT) return datetime.strptime(ts, _TIME_FORMAT)
def datetime_to_ts(ts: datetime) -> str:
"""Convert a datetime to a string ts using the internal format"""
return datetime.strftime(ts, _TIME_FORMAT)
class State(): class State():
def __init__(self, data: dict = _EMPTY_DICT): def __init__(self, data: dict = _EMPTY_DICT):
"""Initialise and validate the state""" """Initialise and validate the state"""
self.validate(data) self.validate(data)
self._dict = copy.deepcopy(data) self._dict = copy.deepcopy(data)
@property
def _history(self) -> dict[str]:
return self._dict[_Key.HISTORY]
@property
def _users(self) -> dict[str]:
return self._dict[_Key.USERS]
def validate(self, dict: dict = None): def validate(self, dict: dict = None):
"""Initialise and validate a state dict""" """Initialise and validate a state dict"""
if not dict: if not dict:
@ -138,14 +138,13 @@ class State():
def get_user_matches(self, id: int) -> list[int]: def get_user_matches(self, id: int) -> list[int]:
return self._users.get(str(id), {}).get(_Key.MATCHES, {}) return self._users.get(str(id), {}).get(_Key.MATCHES, {})
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None: def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
"""Log the groups""" """Log the groups"""
tmp_state = State(self._dict) ts = datetime_to_ts(ts or datetime.now())
ts = datetime.strftime(ts, _TIME_FORMAT) with self._safe_wrap() as safe_state:
# Grab or create the hitory item for this set of groups # Grab or create the hitory item for this set of groups
history_item = {} history_item = {}
tmp_state._history[ts] = history_item safe_state._history[ts] = history_item
history_item_groups = [] history_item_groups = []
history_item[_Key.GROUPS] = history_item_groups history_item[_Key.GROUPS] = history_item_groups
@ -158,23 +157,20 @@ class State():
# Update the matchee data with the matches # Update the matchee data with the matches
for m in group: for m in group:
matchee = tmp_state._users.get(str(m.id), {}) matchee = safe_state._users.get(str(m.id), {})
matchee_matches = matchee.get(_Key.MATCHES, {}) matchee_matches = matchee.get(_Key.MATCHES, {})
for o in (o for o in group if o.id != m.id): for o in (o for o in group if o.id != m.id):
matchee_matches[str(o.id)] = ts matchee_matches[str(o.id)] = ts
matchee[_Key.MATCHES] = matchee_matches matchee[_Key.MATCHES] = matchee_matches
tmp_state._users[str(m.id)] = matchee safe_state._users[str(m.id)] = matchee
# Validate before storing the result
tmp_state.validate()
self._dict = tmp_state._dict
def set_user_scope(self, id: str, scope: str, value: bool = True): def set_user_scope(self, id: str, scope: str, value: bool = True):
"""Add an auth scope to a user""" """Add an auth scope to a user"""
with self._safe_wrap() as safe_state:
# Dive in # Dive in
user = self._users.get(str(id), {}) user = safe_state._users.get(str(id), {})
scopes = user.get(_Key.SCOPES, []) scopes = user.get(_Key.SCOPES, [])
# Set the value # Set the value
@ -185,7 +181,7 @@ class State():
# Roll out # Roll out
user[_Key.SCOPES] = scopes user[_Key.SCOPES] = scopes
self._users[id] = user safe_state._users[str(id)] = user
def get_user_has_scope(self, id: str, scope: str) -> bool: def get_user_has_scope(self, id: str, scope: str) -> bool:
""" """
@ -196,20 +192,9 @@ class State():
scopes = user.get(_Key.SCOPES, []) scopes = user.get(_Key.SCOPES, [])
return AuthScope.OWNER in scopes or scope in scopes return AuthScope.OWNER in scopes or scope in scopes
def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True): def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True):
"""Set a user as active (or not) on a given channel""" """Set a user as active (or not) on a given channel"""
# Dive in self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active)
user = self._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
# Set the value
channel[_Key.ACTIVE] = active
# Unroll
channels[str(channel_id)] = channel
user[_Key.CHANNELS] = channels
self._users[str(id)] = user
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool: def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
"""Get a users active channels""" """Get a users active channels"""
@ -217,11 +202,69 @@ class State():
channels = user.get(_Key.CHANNELS, {}) channels = user.get(_Key.CHANNELS, {})
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)] return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
def set_user_paused_in_channel(self, id: str, channel_id: str, days: int):
"""Sets a user as paused in a channel"""
# Deactivate the user in the channel first
self.set_user_active_in_channel(id, channel_id, False)
# Set the reactivate time the number of days in the future
ts = datetime.now() + timedelta(days=days)
self._set_user_channel_prop(
id, channel_id, _Key.REACTIVATE, datetime_to_ts(ts))
def reactivate_users(self, channel_id: str):
"""Reactivate any users who've passed their reactivation time on this channel"""
with self._safe_wrap() as safe_state:
for user in safe_state._users.values():
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
if channel and not channel[_Key.ACTIVE]:
reactivate = channel.get(_Key.REACTIVATE, None)
# Check if we've gone past the reactivation time and re-activate
if reactivate and datetime.now() > ts_to_datetime(reactivate):
channel[_Key.ACTIVE] = True
@property @property
def dict_internal(self) -> dict: def dict_internal_copy(self) -> dict:
"""Only to be used to get the internal dict as a copy""" """Only to be used to get the internal dict as a copy"""
return copy.deepcopy(self._dict) return copy.deepcopy(self._dict)
@property
def _history(self) -> dict[str]:
return self._dict[_Key.HISTORY]
@property
def _users(self) -> dict[str]:
return self._dict[_Key.USERS]
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
"""Set a user channel property helper"""
with self._safe_wrap() as safe_state:
# Dive in
user = safe_state._users.get(str(id), {})
channels = user.get(_Key.CHANNELS, {})
channel = channels.get(str(channel_id), {})
# Set the value
channel[key] = value
# Unroll
channels[str(channel_id)] = channel
user[_Key.CHANNELS] = channels
safe_state._users[str(id)] = user
@contextmanager
def _safe_wrap(self):
"""Safely run any function wrapped in a validate"""
# Wrap in a temporary state to validate first to prevent corruption
tmp_state = State(self._dict)
try:
yield tmp_state
finally:
# Validate and then overwrite our dict with the new one
tmp_state.validate()
self._dict = tmp_state._dict
def _migrate(dict: dict): def _migrate(dict: dict):
"""Migrate a dict through versions""" """Migrate a dict through versions"""
@ -254,4 +297,4 @@ def load_from_file(file: str) -> State:
def save_to_file(state: State, file: str): def save_to_file(state: State, file: str):
"""Saves the state out to a file""" """Saves the state out to a file"""
files.save(file, state.dict_internal) files.save(file, state.dict_internal_copy)

View file

@ -55,11 +55,11 @@ def test_channeljoin():
assert not st.get_user_active_in_channel(1, "2") assert not st.get_user_active_in_channel(1, "2")
st = state.load_from_file(path) st = state.load_from_file(path)
st.set_use_active_in_channel(1, "2", True) st.set_user_active_in_channel(1, "2", True)
state.save_to_file(st, path) state.save_to_file(st, path)
st = state.load_from_file(path) st = state.load_from_file(path)
assert st.get_user_active_in_channel(1, "2") assert st.get_user_active_in_channel(1, "2")
st.set_use_active_in_channel(1, "2", False) st.set_user_active_in_channel(1, "2", False)
assert not st.get_user_active_in_channel(1, "2") assert not st.get_user_active_in_channel(1, "2")