More protection - State does it's own saving
This commit is contained in:
parent
cbea7abca2
commit
57f65b265c
5 changed files with 27 additions and 29 deletions
|
@ -9,7 +9,7 @@ from datetime import datetime, timedelta, time
|
|||
|
||||
import cogs.match_button as match_button
|
||||
import matching
|
||||
from state import State, save_to_file, AuthScope
|
||||
from state import State, AuthScope
|
||||
import util
|
||||
|
||||
logger = logging.getLogger("cog")
|
||||
|
@ -38,7 +38,6 @@ class MatchyCog(commands.Cog):
|
|||
|
||||
self.state.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id)
|
||||
save_to_file(self.state)
|
||||
await interaction.response.send_message(
|
||||
f"Roger roger {interaction.user.mention}!\n"
|
||||
+ f"Added you to {interaction.channel.mention}!",
|
||||
|
@ -52,7 +51,6 @@ class MatchyCog(commands.Cog):
|
|||
|
||||
self.state.set_user_active_in_channel(
|
||||
interaction.user.id, interaction.channel.id, False)
|
||||
save_to_file(self.state)
|
||||
await interaction.response.send_message(
|
||||
f"No worries {interaction.user.mention}. Come back soon :)", ephemeral=True, silent=True)
|
||||
|
||||
|
@ -68,7 +66,6 @@ class MatchyCog(commands.Cog):
|
|||
until = datetime.now() + timedelta(days=days)
|
||||
self.state.set_user_paused_in_channel(
|
||||
interaction.user.id, interaction.channel.id, until)
|
||||
save_to_file(self.state)
|
||||
await interaction.response.send_message(
|
||||
f"Sure thing {interaction.user.mention}!\n"
|
||||
+ f"Paused you until {util.format_day(until)}!",
|
||||
|
@ -127,7 +124,6 @@ class MatchyCog(commands.Cog):
|
|||
# Add the scheduled task and save
|
||||
success = self.state.set_channel_match_task(
|
||||
channel_id, members_min, weekday, hour, not cancel)
|
||||
save_to_file(self.state)
|
||||
|
||||
# Let the user know what happened
|
||||
if not cancel:
|
||||
|
|
|
@ -3,7 +3,7 @@ Owner bot cog
|
|||
"""
|
||||
import logging
|
||||
from discord.ext import commands
|
||||
from state import State, AuthScope, save_to_file
|
||||
from state import State, AuthScope
|
||||
|
||||
logger = logging.getLogger("owner")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -49,7 +49,6 @@ class OwnerCog(commands.Cog):
|
|||
"""
|
||||
if user.isdigit():
|
||||
self._state.set_user_scope(str(user), AuthScope.MATCHER)
|
||||
save_to_file(self._state)
|
||||
logger.info("Granting user %s matcher scope", user)
|
||||
await ctx.reply("Done!", ephemeral=True)
|
||||
else:
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import discord
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Protocol, runtime_checkable
|
||||
from state import State, save_to_file, ts_to_datetime
|
||||
from state import State, ts_to_datetime
|
||||
import util
|
||||
import config
|
||||
|
||||
|
@ -224,7 +224,6 @@ async def match_groups_in_channel(state: State, channel: discord.channel, min: i
|
|||
|
||||
# Save the groups to the history
|
||||
state.log_groups(groups)
|
||||
save_to_file(state)
|
||||
|
||||
logger.info("Done! Matched into %s groups.", len(groups))
|
||||
|
||||
|
|
32
py/state.py
32
py/state.py
|
@ -174,10 +174,11 @@ def datetime_to_ts(ts: datetime) -> str:
|
|||
|
||||
|
||||
class State():
|
||||
def __init__(self, data: dict):
|
||||
def __init__(self, data: dict, file: str | None = None):
|
||||
"""Initialise and validate the state"""
|
||||
self.validate(data)
|
||||
self._dict = copy.deepcopy(data)
|
||||
self._file = file
|
||||
|
||||
def validate(self, dict: dict = None):
|
||||
"""Initialise and validate a state dict"""
|
||||
|
@ -208,7 +209,7 @@ class State():
|
|||
def log_groups(self, groups: list[list[Member]], ts: datetime = None) -> None:
|
||||
"""Log the groups"""
|
||||
ts = datetime_to_ts(ts or datetime.now())
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
for group in groups:
|
||||
# Update the matchee data with the matches
|
||||
for m in group:
|
||||
|
@ -220,7 +221,7 @@ class State():
|
|||
|
||||
def set_user_scope(self, id: str, scope: str, value: bool = True):
|
||||
"""Add an auth scope to a user"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.setdefault(str(id), {})
|
||||
scopes = user.setdefault(_Key.SCOPES, [])
|
||||
|
@ -260,7 +261,7 @@ class State():
|
|||
|
||||
def reactivate_users(self, channel_id: str):
|
||||
"""Reactivate any users who've passed their reactivation time on this channel"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
for user in safe_state._users.values():
|
||||
channels = user.get(_Key.CHANNELS, {})
|
||||
channel = channels.get(str(channel_id), {})
|
||||
|
@ -300,7 +301,7 @@ class State():
|
|||
|
||||
def set_channel_match_task(self, channel_id: str, members_min: int, weekday: int, hour: int, set: bool) -> bool:
|
||||
"""Set up a match task on a channel"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
channel = safe_state._tasks.setdefault(str(channel_id), {})
|
||||
matches = channel.setdefault(_Key.MATCH_TASKS, [])
|
||||
|
||||
|
@ -345,7 +346,7 @@ class State():
|
|||
|
||||
def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value):
|
||||
"""Set a user channel property helper"""
|
||||
with self._safe_wrap() as safe_state:
|
||||
with self._safe_wrap_write() as safe_state:
|
||||
# Dive in
|
||||
user = safe_state._users.setdefault(str(id), {})
|
||||
channels = user.setdefault(_Key.CHANNELS, {})
|
||||
|
@ -355,7 +356,7 @@ class State():
|
|||
channel[key] = value
|
||||
|
||||
@contextmanager
|
||||
def _safe_wrap(self):
|
||||
def _safe_wrap_write(self):
|
||||
"""Safely run any function wrapped in a validate"""
|
||||
# Wrap in a temporary state to validate first to prevent corruption
|
||||
tmp_state = State(self._dict)
|
||||
|
@ -366,6 +367,14 @@ class State():
|
|||
tmp_state.validate()
|
||||
self._dict = tmp_state._dict
|
||||
|
||||
# Write this change out if we have a file
|
||||
if self._file:
|
||||
self._save_to_file()
|
||||
|
||||
def _save_to_file(self):
|
||||
"""Saves the state out to the chosen file"""
|
||||
files.save(self._file, self.dict_internal_copy)
|
||||
|
||||
|
||||
def _migrate(dict: dict):
|
||||
"""Migrate a dict through versions"""
|
||||
|
@ -378,7 +387,7 @@ def _migrate(dict: dict):
|
|||
|
||||
def load_from_file(file: str = _STATE_FILE) -> State:
|
||||
"""
|
||||
Load the state from a file
|
||||
Load the state from a files
|
||||
Apply any required migrations
|
||||
"""
|
||||
loaded = _EMPTY_DICT
|
||||
|
@ -388,14 +397,9 @@ def load_from_file(file: str = _STATE_FILE) -> State:
|
|||
loaded = files.load(file)
|
||||
_migrate(loaded)
|
||||
|
||||
st = State(loaded)
|
||||
st = State(loaded, file)
|
||||
|
||||
# Save out the migrated (or new) file
|
||||
files.save(file, st._dict)
|
||||
|
||||
return st
|
||||
|
||||
|
||||
def save_to_file(state: State, file: str = _STATE_FILE):
|
||||
"""Saves the state out to a file"""
|
||||
files.save(file, state.dict_internal_copy)
|
||||
|
|
|
@ -18,10 +18,10 @@ def test_simple_load_reload():
|
|||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
st = state.load_from_file(path)
|
||||
|
||||
|
||||
|
@ -30,13 +30,13 @@ def test_authscope():
|
|||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
assert not st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_user_scope(1, state.AuthScope.MATCHER)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_has_scope(1, state.AuthScope.MATCHER)
|
||||
|
@ -50,13 +50,13 @@ def test_channeljoin():
|
|||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, 'tmp.json')
|
||||
st = state.load_from_file(path)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
assert not st.get_user_active_in_channel(1, "2")
|
||||
|
||||
st = state.load_from_file(path)
|
||||
st.set_user_active_in_channel(1, "2", True)
|
||||
state.save_to_file(st, path)
|
||||
st._save_to_file()
|
||||
|
||||
st = state.load_from_file(path)
|
||||
assert st.get_user_active_in_channel(1, "2")
|
||||
|
|
Loading…
Add table
Reference in a new issue