diff --git a/matchy/files/state.py b/matchy/files/state.py index 3b2201b..438cf59 100644 --- a/matchy/files/state.py +++ b/matchy/files/state.py @@ -171,16 +171,17 @@ def datetime_to_ts(ts: datetime) -> str: class State(): def __init__(self, data: dict, file: str | None = None): - """Initialise and validate the state""" - self.validate(data) + """Copy the data, migrate if needed, and validate""" self._dict = copy.deepcopy(data) self._file = file - def validate(self, dict: dict = None): - """Initialise and validate a state dict""" - if not dict: - dict = self._dict - _SCHEMA.validate(dict) + version = self._dict.get("version", 0) + for i in range(version, _VERSION): + logger.info("Migrating from v%s to v%s", version, version+1) + _MIGRATIONS[i](self._dict) + self._dict[_Key.VERSION] = _VERSION + + _SCHEMA.validate(self._dict) @staticmethod def safe_write(func): @@ -193,9 +194,9 @@ class State(): def inner(self: State, *args, **kwargs): tmp = State(self._dict, self._file) func(tmp, *args, **kwargs) - tmp.validate() + _SCHEMA.validate(tmp._dict) if tmp._file: - tmp._save_to_file() + ops.save(tmp._file, tmp._dict) self._dict = tmp._dict return inner @@ -345,11 +346,6 @@ class State(): # We did not manage to remove the schedule (or add it? though that should be impossible) return False - @property - def dict_internal_copy(self) -> dict: - """Only to be used to get the internal dict as a copy""" - return copy.deepcopy(self._dict) - @property def _users(self) -> dict[str]: return self._dict[_Key.USERS] @@ -361,43 +357,17 @@ class State(): @safe_write def _set_user_channel_prop(self, id: str, channel_id: str, key: str, value): """Set a user channel property helper""" - # Dive in user = self._users.setdefault(str(id), {}) channels = user.setdefault(_Key.CHANNELS, {}) channel = channels.setdefault(str(channel_id), {}) - - # Set the value channel[key] = value - def _save_to_file(self): - """Saves the state out to the chosen file""" - ops.save(self._file, self.dict_internal_copy) - - -def _migrate(dict: dict): - """Migrate a dict through versions""" - version = dict.get("version", 0) - for i in range(version, _VERSION): - logger.info("Migrating from v%s to v%s", version, version+1) - _MIGRATIONS[i](dict) - dict[_Key.VERSION] = _VERSION - def load_from_file(file: str) -> State: """ Load the state from a files - Apply any required migrations """ - loaded = _EMPTY_DICT - - # If there's a file load it and try to migrate - if os.path.isfile(file): - loaded = ops.load(file) - _migrate(loaded) - + loaded = ops.load(file) if os.path.isfile(file) else _EMPTY_DICT st = State(loaded, file) - - # Save out the migrated (or new) file ops.save(file, st._dict) - return st