Implement user pausing with /pause
This commit is contained in:
parent
a480549ad3
commit
7efe781e66
4 changed files with 125 additions and 67 deletions
|
@ -30,7 +30,6 @@ Matchy is configured by a `config.json` file that takes this format:
|
|||
|
||||
## TODO
|
||||
* Write bot tests with [dpytest](https://dpytest.readthedocs.io/en/latest/tutorials/getting_started.html)
|
||||
* Implement /pause to pause a user for a little while
|
||||
* Move more constants to the config
|
||||
* Add scheduling functionality
|
||||
* Fix logging in some sub files (doesn't seem to actually be output?)
|
||||
|
|
20
py/matchy.py
20
py/matchy.py
|
@ -75,7 +75,7 @@ async def close(ctx: commands.Context):
|
|||
@bot.tree.command(description="Join the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def join(interaction: discord.Interaction):
|
||||
State.set_use_active_in_channel(
|
||||
State.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
|
@ -87,13 +87,26 @@ async def join(interaction: discord.Interaction):
|
|||
@bot.tree.command(description="Leave the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def leave(interaction: discord.Interaction):
|
||||
State.set_use_active_in_channel(
|
||||
State.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id, False)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="Pause your matching in this channel for a number of days")
|
||||
@commands.guild_only()
|
||||
@app_commands.describe(days="Days to pause for (defaults to 7)")
|
||||
async def pause(interaction: discord.Interaction, days: int = None):
|
||||
if not days: # Default to a week
|
||||
days = 7
|
||||
State.set_user_paused_in_channel(
|
||||
interaction.user.id, interaction.channel.id, days)
|
||||
state.save_to_file(State, STATE_FILE)
|
||||
await interaction.response.send_message(
|
||||
f"Sure thing {interaction.user.mention}. Paused you for {days} days!", ephemeral=True, silent=True)
|
||||
|
||||
|
||||
@bot.tree.command(description="List the matchees for this channel")
|
||||
@commands.guild_only()
|
||||
async def list(interaction: discord.Interaction):
|
||||
|
@ -195,6 +208,9 @@ class DynamicGroupButton(discord.ui.DynamicItem[discord.ui.Button],
|
|||
|
||||
def get_matchees_in_channel(channel: discord.channel):
|
||||
"""Fetches the matchees in a channel"""
|
||||
# Reactivate any unpaused users
|
||||
State.reactivate_users(channel.id)
|
||||
|
||||
# Gather up the prospective matchees
|
||||
return [m for m in channel.members if State.get_user_active_in_channel(m.id, channel.id)]
|
||||
|
||||
|
|
119
py/state.py
119
py/state.py
|
@ -1,11 +1,12 @@
|
|||
"""Store bot state"""
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from schema import Schema, And, Use, Optional
|
||||
from typing import Protocol
|
||||
import files
|
||||
import copy
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger("state")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -83,6 +84,8 @@ _SCHEMA = Schema(
|
|||
Optional(str): {
|
||||
# Whether the user is signed up in this channel
|
||||
_Key.ACTIVE: And(Use(bool)),
|
||||
# A timestamp for when to re-activate the user
|
||||
Optional(_Key.REACTIVATE): And(Use(str)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -106,24 +109,21 @@ class Member(Protocol):
|
|||
|
||||
|
||||
def ts_to_datetime(ts: str) -> datetime:
|
||||
"""Convert a ts to datetime using the internal format"""
|
||||
"""Convert a string ts to datetime using the internal format"""
|
||||
return datetime.strptime(ts, _TIME_FORMAT)
|
||||
|
||||
|
||||
def datetime_to_ts(ts: datetime) -> str:
|
||||
"""Convert a datetime to a string ts using the internal format"""
|
||||
return datetime.strftime(ts, _TIME_FORMAT)
|
||||
|
||||
|
||||
class State():
|
||||
def __init__(self, data: dict = _EMPTY_DICT):
|
||||
"""Initialise and validate the state"""
|
||||
self.validate(data)
|
||||
self._dict = copy.deepcopy(data)
|
||||
|
||||
@property
|
||||
def _history(self) -> dict[str]:
|
||||
return self._dict[_Key.HISTORY]
|
||||
|
||||
@property
|
||||
def _users(self) -> dict[str]:
|
||||
return self._dict[_Key.USERS]
|
||||
|
||||
def validate(self, dict: dict = None):
|
||||
"""Initialise and validate a state dict"""
|
||||
if not dict:
|
||||
|
@ -138,14 +138,13 @@ class State():
|
|||
def get_user_matches(self, id: int) -> list[int]:
|
||||
return self._users.get(str(id), {}).get(_Key.MATCHES, {})
|
||||
|
||||
def log_groups(self, groups: list[list[Member]], ts: datetime = datetime.now()) -> None:
|
||||
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
|
||||
"""Log the groups"""
|
||||
tmp_state = State(self._dict)
|
||||
ts = datetime.strftime(ts, _TIME_FORMAT)
|
||||
|
||||
ts = datetime_to_ts(ts or datetime.now())
|
||||
with self._safe_wrap() as safe_state:
|
||||
# Grab or create the hitory item for this set of groups
|
||||
history_item = {}
|
||||
tmp_state._history[ts] = history_item
|
||||
safe_state._history[ts] = history_item
|
||||
history_item_groups = []
|
||||
history_item[_Key.GROUPS] = history_item_groups
|
||||
|
||||
|
@ -158,23 +157,20 @@ class State():
|
|||
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
matchee = tmp_state._users.get(str(m.id), {})
|
||||
matchee = safe_state._users.get(str(m.id), {})
|
||||
matchee_matches = matchee.get(_Key.MATCHES, {})
|
||||
|
||||
for o in (o for o in group if o.id != m.id):
|
||||
matchee_matches[str(o.id)] = ts
|
||||
|
||||
matchee[_Key.MATCHES] = matchee_matches
|
||||
tmp_state._users[str(m.id)] = matchee
|
||||
|
||||
# Validate before storing the result
|
||||
tmp_state.validate()
|
||||
self._dict = tmp_state._dict
|
||||
safe_state._users[str(m.id)] = matchee
|
||||
|
||||
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||
"""Add an auth scope to a user"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
# Dive in
|
||||
user = self._users.get(str(id), {})
|
||||
user = safe_state._users.get(str(id), {})
|
||||
scopes = user.get(_Key.SCOPES, [])
|
||||
|
||||
# Set the value
|
||||
|
@ -185,7 +181,7 @@ class State():
|
|||
|
||||
# Roll out
|
||||
user[_Key.SCOPES] = scopes
|
||||
self._users[id] = user
|
||||
safe_state._users[str(id)] = user
|
||||
|
||||
def get_user_has_scope(self, id: str, scope: str) -> bool:
|
||||
"""
|
||||
|
@ -196,20 +192,9 @@ class State():
|
|||
scopes = user.get(_Key.SCOPES, [])
|
||||
return AuthScope.OWNER in scopes or scope in scopes
|
||||
|
||||
def set_use_active_in_channel(self, id: str, channel_id: str, active: bool = True):
|
||||
def set_user_active_in_channel(self, id: str, channel_id: str, active: bool = True):
|
||||
"""Set a user as active (or not) on a given channel"""
|
||||
# Dive in
|
||||
user = self._users.get(str(id), {})
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
|
||||
# Set the value
|
||||
channel[_Key.ACTIVE] = active
|
||||
|
||||
# Unroll
|
||||
channels[str(channel_id)] = channel
|
||||
user[_Key.CHANNELS] = channels
|
||||
self._users[str(id)] = user
|
||||
self._set_user_channel_prop(id, channel_id, _Key.ACTIVE, active)
|
||||
|
||||
def get_user_active_in_channel(self, id: str, channel_id: str) -> bool:
|
||||
"""Get a users active channels"""
|
||||
|
@ -217,11 +202,69 @@ class State():
|
|||
channels = user.get(_Key.CHANNELS, {})
|
||||
return str(channel_id) in [channel for (channel, props) in channels.items() if props.get(_Key.ACTIVE, False)]
|
||||
|
||||
def set_user_paused_in_channel(self, id: str, channel_id: str, days: int):
|
||||
"""Sets a user as paused in a channel"""
|
||||
# Deactivate the user in the channel first
|
||||
self.set_user_active_in_channel(id, channel_id, False)
|
||||
|
||||
# Set the reactivate time the number of days in the future
|
||||
ts = datetime.now() + timedelta(days=days)
|
||||
self._set_user_channel_prop(
|
||||
id, channel_id, _Key.REACTIVATE, datetime_to_ts(ts))
|
||||
|
||||
def reactivate_users(self, channel_id: str):
|
||||
"""Reactivate any users who've passed their reactivation time on this channel"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
for user in safe_state._users.values():
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
if channel and not channel[_Key.ACTIVE]:
|
||||
reactivate = channel.get(_Key.REACTIVATE, None)
|
||||
# Check if we've gone past the reactivation time and re-activate
|
||||
if reactivate and datetime.now() > ts_to_datetime(reactivate):
|
||||
channel[_Key.ACTIVE] = True
|
||||
|
||||
@property
|
||||
def dict_internal(self) -> dict:
|
||||
def dict_internal_copy(self) -> dict:
|
||||
"""Only to be used to get the internal dict as a copy"""
|
||||
return copy.deepcopy(self._dict)
|
||||
|
||||
@property
|
||||
def _history(self) -> dict[str]:
|
||||
return self._dict[_Key.HISTORY]
|
||||
|
||||
@property
|
||||
def _users(self) -> dict[str]:
|
||||
return self._dict[_Key.USERS]
|
||||
|
||||
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
|
||||
"""Set a user channel property helper"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.get(str(id), {})
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
|
||||
# Set the value
|
||||
channel[key] = value
|
||||
|
||||
# Unroll
|
||||
channels[str(channel_id)] = channel
|
||||
user[_Key.CHANNELS] = channels
|
||||
safe_state._users[str(id)] = user
|
||||
|
||||
@contextmanager
|
||||
def _safe_wrap(self):
|
||||
"""Safely run any function wrapped in a validate"""
|
||||
# Wrap in a temporary state to validate first to prevent corruption
|
||||
tmp_state = State(self._dict)
|
||||
try:
|
||||
yield tmp_state
|
||||
finally:
|
||||
# Validate and then overwrite our dict with the new one
|
||||
tmp_state.validate()
|
||||
self._dict = tmp_state._dict
|
||||
|
||||
|
||||
def _migrate(dict: dict):
|
||||
"""Migrate a dict through versions"""
|
||||
|
@ -254,4 +297,4 @@ def load_from_file(file: str) -> State:
|
|||
|
||||
def save_to_file(state: State, file: str):
|
||||
"""Saves the state out to a file"""
|
||||
files.save(file, state.dict_internal)
|
||||
files.save(file, state.dict_internal_copy)
|
||||
|
|
|
@ -55,11 +55,11 @@ def test_channeljoin():
|
|||
assert not st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_use_active_in_channel(1, "2", True)
|
||||
st.set_user_active_in_channel(1, "2", True)
|
||||
state.save_to_file(st, path)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st.set_use_active_in_channel(1, "2", False)
|
||||
st.set_user_active_in_channel(1, "2", False)
|
||||
assert not st.get_user_active_in_channel(1, "2")
|
||||
|
|
Loading…
Add table
Reference in a new issue