Extract persistence code into own class

This commit is contained in:
Marc Di Luzio 2020-06-02 19:16:02 +01:00
parent 4c76530832
commit c5ebbc3c40
9 changed files with 163 additions and 149 deletions

View file

@ -7,13 +7,14 @@ import (
"os/signal"
"syscall"
"github.com/mdiluz/rove/pkg/persistence"
"github.com/mdiluz/rove/pkg/server"
"github.com/mdiluz/rove/pkg/version"
)
var ver = flag.Bool("version", false, "Display version number")
var port = flag.Int("port", 8080, "The port to host on")
var data = flag.String("data", "/tmp/", "Directory to store persistant data")
var data = flag.String("data", os.TempDir(), "Directory to store persistant data")
func main() {
flag.Parse()
@ -23,9 +24,12 @@ func main() {
os.Exit(0)
}
// Set the persistence path
persistence.SetPath(*data)
s := server.NewServer(
server.OptionPort(*port),
server.OptionPersistentData(*data))
server.OptionPersistentData())
fmt.Println("Initialising...")
if err := s.Initialise(); err != nil {

View file

@ -1,11 +1,7 @@
package accounts
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path"
"github.com/google/uuid"
)
@ -31,13 +27,11 @@ type accountantData struct {
// Accountant manages a set of accounts
type Accountant struct {
Accounts map[uuid.UUID]Account `json:"accounts"`
dataPath string
}
// NewAccountant creates a new accountant
func NewAccountant(dataPath string) *Accountant {
func NewAccountant() *Accountant {
return &Accountant{
dataPath: dataPath,
Accounts: make(map[uuid.UUID]Account),
}
}
@ -64,40 +58,6 @@ func (a *Accountant) RegisterAccount(name string) (acc Account, err error) {
return
}
// path returns the full path to the data file
func (a Accountant) path() string {
return path.Join(a.dataPath, kAccountsFileName)
}
// Load will load the accountant from data
func (a *Accountant) Load() error {
// Don't load anything if the file doesn't exist
_, err := os.Stat(a.path())
if os.IsNotExist(err) {
fmt.Printf("File %s didn't exist, loading with fresh accounts data\n", a.path())
return nil
}
if b, err := ioutil.ReadFile(a.path()); err != nil {
return err
} else if err := json.Unmarshal(b, &a); err != nil {
return err
}
return nil
}
// Save will save the accountant data out
func (a *Accountant) Save() error {
if b, err := json.MarshalIndent(a, "", "\t"); err != nil {
return err
} else {
if err := ioutil.WriteFile(a.path(), b, os.ModePerm); err != nil {
return err
}
}
return nil
}
// AssignPrimary assigns primary ownership of an instance to an account
func (a *Accountant) AssignPrimary(account uuid.UUID, instance uuid.UUID) error {

View file

@ -1,7 +1,6 @@
package accounts
import (
"os"
"testing"
"github.com/google/uuid"
@ -9,7 +8,7 @@ import (
func TestNewAccountant(t *testing.T) {
// Very basic verify here for now
accountant := NewAccountant(os.TempDir())
accountant := NewAccountant()
if accountant == nil {
t.Error("Failed to create accountant")
}
@ -17,7 +16,7 @@ func TestNewAccountant(t *testing.T) {
func TestAccountant_RegisterAccount(t *testing.T) {
accountant := NewAccountant(os.TempDir())
accountant := NewAccountant()
// Start by making two accounts
@ -49,50 +48,8 @@ func TestAccountant_RegisterAccount(t *testing.T) {
}
}
func TestAccountant_LoadSave(t *testing.T) {
accountant := NewAccountant(os.TempDir())
if len(accountant.Accounts) != 0 {
t.Error("New accountant created with non-zero account number")
}
name := uuid.New().String()
a, err := accountant.RegisterAccount(name)
if err != nil {
t.Error(err)
}
if len(accountant.Accounts) != 1 {
t.Error("No new account made")
} else if accountant.Accounts[a.Id].Name != name {
t.Error("New account created with wrong name")
}
// Save out the accountant
if err := accountant.Save(); err != nil {
t.Error(err)
}
// Re-create the accountant
accountant = NewAccountant(os.TempDir())
if len(accountant.Accounts) != 0 {
t.Error("New accountant created with non-zero account number")
}
// Load the old accountant data
if err := accountant.Load(); err != nil {
t.Error(err)
}
// Verify we have the same account again
if len(accountant.Accounts) != 1 {
t.Error("No account after load")
} else if accountant.Accounts[a.Id].Name != name {
t.Error("New account created with wrong name")
}
}
func TestAccountant_AssignPrimary(t *testing.T) {
accountant := NewAccountant(os.TempDir())
accountant := NewAccountant()
if len(accountant.Accounts) != 0 {
t.Error("New accountant created with non-zero account number")
}

View file

@ -1,11 +1,7 @@
package game
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path"
"github.com/google/uuid"
)
@ -30,47 +26,12 @@ type Instance struct {
const kWorldFileName = "rove-world.json"
// NewWorld creates a new world object
func NewWorld(data string) *World {
func NewWorld() *World {
return &World{
Instances: make(map[uuid.UUID]Instance),
dataPath: data,
}
}
// path returns the full path to the data file
func (w *World) path() string {
return path.Join(w.dataPath, kWorldFileName)
}
// Load will load the accountant from data
func (w *World) Load() error {
// Don't load anything if the file doesn't exist
_, err := os.Stat(w.path())
if os.IsNotExist(err) {
fmt.Printf("File %s didn't exist, loading with fresh world data\n", w.path())
return nil
}
if b, err := ioutil.ReadFile(w.path()); err != nil {
return err
} else if err := json.Unmarshal(b, &w); err != nil {
return err
}
return nil
}
// Save will save the accountant data out
func (w *World) Save() error {
if b, err := json.MarshalIndent(w, "", "\t"); err != nil {
return err
} else {
if err := ioutil.WriteFile(w.path(), b, os.ModePerm); err != nil {
return err
}
}
return nil
}
// Adds an instance to the game
func (w *World) CreateInstance() uuid.UUID {
id := uuid.New()

View file

@ -1,20 +1,19 @@
package game
import (
"os"
"testing"
)
func TestNewWorld(t *testing.T) {
// Very basic for now, nothing to verify
world := NewWorld(os.TempDir())
world := NewWorld()
if world == nil {
t.Error("Failed to create world")
}
}
func TestWorld_CreateInstance(t *testing.T) {
world := NewWorld(os.TempDir())
world := NewWorld()
a := world.CreateInstance()
b := world.CreateInstance()

View file

@ -0,0 +1,90 @@
package persistence
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path"
)
// dataPath global path for persistence
var dataPath = os.TempDir()
// SetPath sets the persistent path for the data storage
func SetPath(path string) error {
if info, err := os.Stat(path); err != nil {
return err
} else if !info.IsDir() {
return fmt.Errorf("path for persistence is not directory")
}
dataPath = path
return nil
}
// Converts name to a full path
func jsonPath(name string) string {
return path.Join(dataPath, fmt.Sprintf("rove-%s.json", name))
}
// Save will serialise the interface into a json file
func Save(name string, data interface{}) error {
if b, err := json.MarshalIndent(data, "", "\t"); err != nil {
return err
} else {
if err := ioutil.WriteFile(jsonPath(name), b, os.ModePerm); err != nil {
return err
}
}
return nil
}
// Load will load the interface from the json file
func Load(name string, data interface{}) error {
path := jsonPath(name)
// Don't load anything if the file doesn't exist
_, err := os.Stat(path)
if os.IsNotExist(err) {
fmt.Printf("File %s didn't exist, loading with fresh data\n", path)
return nil
}
// Read and unmarshal the json
if b, err := ioutil.ReadFile(path); err != nil {
return err
} else if err := json.Unmarshal(b, data); err != nil {
return err
}
return nil
}
// saveLoadFunc defines a type of function to save or load an interface
type saveLoadFunc func(string, interface{}) error
func doAll(f saveLoadFunc, args ...interface{}) error {
var name string
for i, a := range args {
if i%2 == 0 {
var ok bool
name, ok = a.(string)
if !ok {
return fmt.Errorf("Incorrect args")
}
} else {
if err := f(name, a); err != nil {
return err
}
}
}
return nil
}
// SaveAll allows for saving multiple structures in a single call
func SaveAll(args ...interface{}) error {
return doAll(Save, args...)
}
// LoadAll allows for loading multiple structures in a single call
func LoadAll(args ...interface{}) error {
return doAll(Load, args...)
}

View file

@ -0,0 +1,52 @@
package persistence
import (
"io/ioutil"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
type Dummy struct {
Success bool
Value int
}
func TestPersistence_LoadSave(t *testing.T) {
tmp, err := ioutil.TempDir(os.TempDir(), "rove_persistence_test")
assert.NoError(t, err, "Failed to get tempdir path")
assert.NoError(t, SetPath(tmp), "Failed to get set tempdir to persistence path")
// Try and save out the dummy
var dummy Dummy
dummy.Success = true
assert.NoError(t, Save("test", dummy), "Failed to save out dummy file")
// Load back the dummy
dummy = Dummy{}
assert.NoError(t, Load("test", &dummy), "Failed to load in dummy file")
assert.Equal(t, true, dummy.Success, "Did not successfully load true value from file")
}
func TestPersistence_LoadSaveAll(t *testing.T) {
tmp, err := ioutil.TempDir(os.TempDir(), "rove_persistence_test")
assert.NoError(t, err, "Failed to get tempdir path")
assert.NoError(t, SetPath(tmp), "Failed to get set tempdir to persistence path")
// Try and save out the dummy
var dummyA Dummy
var dummyB Dummy
dummyA.Value = 1
dummyB.Value = 2
assert.NoError(t, SaveAll("a", dummyA, "b", dummyB), "Failed to save out dummy file")
// Load back the dummy
dummyA = Dummy{}
dummyB = Dummy{}
assert.NoError(t, LoadAll("a", &dummyA, "b", &dummyB), "Failed to load in dummy file")
assert.Equal(t, 1, dummyA.Value, "Did not successfully load int value from file")
assert.Equal(t, 2, dummyB.Value, "Did not successfully load int value from file")
}

View file

@ -11,6 +11,7 @@ import (
"github.com/gorilla/mux"
"github.com/mdiluz/rove/pkg/accounts"
"github.com/mdiluz/rove/pkg/game"
"github.com/mdiluz/rove/pkg/persistence"
)
const (
@ -48,10 +49,9 @@ func OptionPort(port int) ServerOption {
}
// OptionPersistentData sets the server data to be persistent
func OptionPersistentData(loc string) ServerOption {
func OptionPersistentData() ServerOption {
return func(s *Server) {
s.persistence = PersistentData
s.persistenceLocation = loc
}
}
@ -76,8 +76,8 @@ func NewServer(opts ...ServerOption) *Server {
s.server = &http.Server{Addr: fmt.Sprintf(":%d", s.port), Handler: router}
// Create the accountant
s.accountant = accounts.NewAccountant(s.persistenceLocation)
s.world = game.NewWorld(s.persistenceLocation)
s.accountant = accounts.NewAccountant()
s.world = game.NewWorld()
return s
}
@ -87,10 +87,7 @@ func (s *Server) Initialise() error {
// Load the accounts if requested
if s.persistence == PersistentData {
if err := s.accountant.Load(); err != nil {
return err
}
if err := s.world.Load(); err != nil {
if err := persistence.LoadAll("accounts", &s.accountant, "world", &s.world); err != nil {
return err
}
}
@ -130,10 +127,7 @@ func (s *Server) Close() error {
// Save the accounts if requested
if s.persistence == PersistentData {
if err := s.accountant.Save(); err != nil {
return err
}
if err := s.world.Save(); err != nil {
if err := persistence.SaveAll("accounts", s.accountant, "world", s.world); err != nil {
return err
}
}

View file

@ -1,7 +1,6 @@
package server
import (
"os"
"testing"
)
@ -22,13 +21,11 @@ func TestNewServer_OptionPort(t *testing.T) {
}
func TestNewServer_OptionPersistentData(t *testing.T) {
server := NewServer(OptionPersistentData(os.TempDir()))
server := NewServer(OptionPersistentData())
if server == nil {
t.Error("Failed to create server")
} else if server.persistence != PersistentData {
t.Error("Failed to set server persistent data")
} else if server.persistenceLocation != os.TempDir() {
t.Error("Failed to set server persistent data path")
}
}
@ -47,7 +44,7 @@ func TestServer_Run(t *testing.T) {
}
func TestServer_RunPersistentData(t *testing.T) {
server := NewServer(OptionPersistentData(os.TempDir()))
server := NewServer(OptionPersistentData())
if server == nil {
t.Error("Failed to create server")
} else if err := server.Initialise(); err != nil {