diff --git a/main.go b/main.go index ccc2f93..d8c0762 100644 --- a/main.go +++ b/main.go @@ -7,13 +7,14 @@ import ( "os/signal" "syscall" + "github.com/mdiluz/rove/pkg/persistence" "github.com/mdiluz/rove/pkg/server" "github.com/mdiluz/rove/pkg/version" ) var ver = flag.Bool("version", false, "Display version number") var port = flag.Int("port", 8080, "The port to host on") -var data = flag.String("data", "/tmp/", "Directory to store persistant data") +var data = flag.String("data", os.TempDir(), "Directory to store persistant data") func main() { flag.Parse() @@ -23,9 +24,12 @@ func main() { os.Exit(0) } + // Set the persistence path + persistence.SetPath(*data) + s := server.NewServer( server.OptionPort(*port), - server.OptionPersistentData(*data)) + server.OptionPersistentData()) fmt.Println("Initialising...") if err := s.Initialise(); err != nil { diff --git a/pkg/accounts/accounts.go b/pkg/accounts/accounts.go index e975a08..37effec 100644 --- a/pkg/accounts/accounts.go +++ b/pkg/accounts/accounts.go @@ -1,11 +1,7 @@ package accounts import ( - "encoding/json" "fmt" - "io/ioutil" - "os" - "path" "github.com/google/uuid" ) @@ -31,13 +27,11 @@ type accountantData struct { // Accountant manages a set of accounts type Accountant struct { Accounts map[uuid.UUID]Account `json:"accounts"` - dataPath string } // NewAccountant creates a new accountant -func NewAccountant(dataPath string) *Accountant { +func NewAccountant() *Accountant { return &Accountant{ - dataPath: dataPath, Accounts: make(map[uuid.UUID]Account), } } @@ -64,40 +58,6 @@ func (a *Accountant) RegisterAccount(name string) (acc Account, err error) { return } -// path returns the full path to the data file -func (a Accountant) path() string { - return path.Join(a.dataPath, kAccountsFileName) -} - -// Load will load the accountant from data -func (a *Accountant) Load() error { - // Don't load anything if the file doesn't exist - _, err := os.Stat(a.path()) - if os.IsNotExist(err) { - fmt.Printf("File %s didn't exist, loading with fresh accounts data\n", a.path()) - return nil - } - - if b, err := ioutil.ReadFile(a.path()); err != nil { - return err - } else if err := json.Unmarshal(b, &a); err != nil { - return err - } - return nil -} - -// Save will save the accountant data out -func (a *Accountant) Save() error { - if b, err := json.MarshalIndent(a, "", "\t"); err != nil { - return err - } else { - if err := ioutil.WriteFile(a.path(), b, os.ModePerm); err != nil { - return err - } - } - return nil -} - // AssignPrimary assigns primary ownership of an instance to an account func (a *Accountant) AssignPrimary(account uuid.UUID, instance uuid.UUID) error { diff --git a/pkg/accounts/accounts_test.go b/pkg/accounts/accounts_test.go index bef693f..bb6c9a7 100644 --- a/pkg/accounts/accounts_test.go +++ b/pkg/accounts/accounts_test.go @@ -1,7 +1,6 @@ package accounts import ( - "os" "testing" "github.com/google/uuid" @@ -9,7 +8,7 @@ import ( func TestNewAccountant(t *testing.T) { // Very basic verify here for now - accountant := NewAccountant(os.TempDir()) + accountant := NewAccountant() if accountant == nil { t.Error("Failed to create accountant") } @@ -17,7 +16,7 @@ func TestNewAccountant(t *testing.T) { func TestAccountant_RegisterAccount(t *testing.T) { - accountant := NewAccountant(os.TempDir()) + accountant := NewAccountant() // Start by making two accounts @@ -49,50 +48,8 @@ func TestAccountant_RegisterAccount(t *testing.T) { } } -func TestAccountant_LoadSave(t *testing.T) { - accountant := NewAccountant(os.TempDir()) - if len(accountant.Accounts) != 0 { - t.Error("New accountant created with non-zero account number") - } - - name := uuid.New().String() - a, err := accountant.RegisterAccount(name) - if err != nil { - t.Error(err) - } - - if len(accountant.Accounts) != 1 { - t.Error("No new account made") - } else if accountant.Accounts[a.Id].Name != name { - t.Error("New account created with wrong name") - } - - // Save out the accountant - if err := accountant.Save(); err != nil { - t.Error(err) - } - - // Re-create the accountant - accountant = NewAccountant(os.TempDir()) - if len(accountant.Accounts) != 0 { - t.Error("New accountant created with non-zero account number") - } - - // Load the old accountant data - if err := accountant.Load(); err != nil { - t.Error(err) - } - - // Verify we have the same account again - if len(accountant.Accounts) != 1 { - t.Error("No account after load") - } else if accountant.Accounts[a.Id].Name != name { - t.Error("New account created with wrong name") - } -} - func TestAccountant_AssignPrimary(t *testing.T) { - accountant := NewAccountant(os.TempDir()) + accountant := NewAccountant() if len(accountant.Accounts) != 0 { t.Error("New accountant created with non-zero account number") } diff --git a/pkg/game/world.go b/pkg/game/world.go index 3138953..14311b0 100644 --- a/pkg/game/world.go +++ b/pkg/game/world.go @@ -1,11 +1,7 @@ package game import ( - "encoding/json" "fmt" - "io/ioutil" - "os" - "path" "github.com/google/uuid" ) @@ -30,47 +26,12 @@ type Instance struct { const kWorldFileName = "rove-world.json" // NewWorld creates a new world object -func NewWorld(data string) *World { +func NewWorld() *World { return &World{ Instances: make(map[uuid.UUID]Instance), - dataPath: data, } } -// path returns the full path to the data file -func (w *World) path() string { - return path.Join(w.dataPath, kWorldFileName) -} - -// Load will load the accountant from data -func (w *World) Load() error { - // Don't load anything if the file doesn't exist - _, err := os.Stat(w.path()) - if os.IsNotExist(err) { - fmt.Printf("File %s didn't exist, loading with fresh world data\n", w.path()) - return nil - } - - if b, err := ioutil.ReadFile(w.path()); err != nil { - return err - } else if err := json.Unmarshal(b, &w); err != nil { - return err - } - return nil -} - -// Save will save the accountant data out -func (w *World) Save() error { - if b, err := json.MarshalIndent(w, "", "\t"); err != nil { - return err - } else { - if err := ioutil.WriteFile(w.path(), b, os.ModePerm); err != nil { - return err - } - } - return nil -} - // Adds an instance to the game func (w *World) CreateInstance() uuid.UUID { id := uuid.New() diff --git a/pkg/game/world_test.go b/pkg/game/world_test.go index 0bfc070..e5fd031 100644 --- a/pkg/game/world_test.go +++ b/pkg/game/world_test.go @@ -1,20 +1,19 @@ package game import ( - "os" "testing" ) func TestNewWorld(t *testing.T) { // Very basic for now, nothing to verify - world := NewWorld(os.TempDir()) + world := NewWorld() if world == nil { t.Error("Failed to create world") } } func TestWorld_CreateInstance(t *testing.T) { - world := NewWorld(os.TempDir()) + world := NewWorld() a := world.CreateInstance() b := world.CreateInstance() diff --git a/pkg/persistence/persistence.go b/pkg/persistence/persistence.go new file mode 100644 index 0000000..6705767 --- /dev/null +++ b/pkg/persistence/persistence.go @@ -0,0 +1,90 @@ +package persistence + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path" +) + +// dataPath global path for persistence +var dataPath = os.TempDir() + +// SetPath sets the persistent path for the data storage +func SetPath(path string) error { + if info, err := os.Stat(path); err != nil { + return err + } else if !info.IsDir() { + return fmt.Errorf("path for persistence is not directory") + } + dataPath = path + return nil +} + +// Converts name to a full path +func jsonPath(name string) string { + return path.Join(dataPath, fmt.Sprintf("rove-%s.json", name)) +} + +// Save will serialise the interface into a json file +func Save(name string, data interface{}) error { + if b, err := json.MarshalIndent(data, "", "\t"); err != nil { + return err + } else { + if err := ioutil.WriteFile(jsonPath(name), b, os.ModePerm); err != nil { + return err + } + } + return nil +} + +// Load will load the interface from the json file +func Load(name string, data interface{}) error { + path := jsonPath(name) + // Don't load anything if the file doesn't exist + _, err := os.Stat(path) + if os.IsNotExist(err) { + fmt.Printf("File %s didn't exist, loading with fresh data\n", path) + return nil + } + + // Read and unmarshal the json + if b, err := ioutil.ReadFile(path); err != nil { + return err + } else if err := json.Unmarshal(b, data); err != nil { + return err + } + return nil +} + +// saveLoadFunc defines a type of function to save or load an interface +type saveLoadFunc func(string, interface{}) error + +func doAll(f saveLoadFunc, args ...interface{}) error { + var name string + for i, a := range args { + if i%2 == 0 { + var ok bool + name, ok = a.(string) + if !ok { + return fmt.Errorf("Incorrect args") + } + } else { + if err := f(name, a); err != nil { + return err + } + } + } + return nil +} + +// SaveAll allows for saving multiple structures in a single call +func SaveAll(args ...interface{}) error { + return doAll(Save, args...) +} + +// LoadAll allows for loading multiple structures in a single call +func LoadAll(args ...interface{}) error { + return doAll(Load, args...) +} diff --git a/pkg/persistence/persistence_test.go b/pkg/persistence/persistence_test.go new file mode 100644 index 0000000..e4117a1 --- /dev/null +++ b/pkg/persistence/persistence_test.go @@ -0,0 +1,52 @@ +package persistence + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +type Dummy struct { + Success bool + Value int +} + +func TestPersistence_LoadSave(t *testing.T) { + tmp, err := ioutil.TempDir(os.TempDir(), "rove_persistence_test") + assert.NoError(t, err, "Failed to get tempdir path") + + assert.NoError(t, SetPath(tmp), "Failed to get set tempdir to persistence path") + + // Try and save out the dummy + var dummy Dummy + dummy.Success = true + assert.NoError(t, Save("test", dummy), "Failed to save out dummy file") + + // Load back the dummy + dummy = Dummy{} + assert.NoError(t, Load("test", &dummy), "Failed to load in dummy file") + assert.Equal(t, true, dummy.Success, "Did not successfully load true value from file") +} + +func TestPersistence_LoadSaveAll(t *testing.T) { + tmp, err := ioutil.TempDir(os.TempDir(), "rove_persistence_test") + assert.NoError(t, err, "Failed to get tempdir path") + + assert.NoError(t, SetPath(tmp), "Failed to get set tempdir to persistence path") + + // Try and save out the dummy + var dummyA Dummy + var dummyB Dummy + dummyA.Value = 1 + dummyB.Value = 2 + assert.NoError(t, SaveAll("a", dummyA, "b", dummyB), "Failed to save out dummy file") + + // Load back the dummy + dummyA = Dummy{} + dummyB = Dummy{} + assert.NoError(t, LoadAll("a", &dummyA, "b", &dummyB), "Failed to load in dummy file") + assert.Equal(t, 1, dummyA.Value, "Did not successfully load int value from file") + assert.Equal(t, 2, dummyB.Value, "Did not successfully load int value from file") +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 3ffaee4..37ff0ae 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/mux" "github.com/mdiluz/rove/pkg/accounts" "github.com/mdiluz/rove/pkg/game" + "github.com/mdiluz/rove/pkg/persistence" ) const ( @@ -48,10 +49,9 @@ func OptionPort(port int) ServerOption { } // OptionPersistentData sets the server data to be persistent -func OptionPersistentData(loc string) ServerOption { +func OptionPersistentData() ServerOption { return func(s *Server) { s.persistence = PersistentData - s.persistenceLocation = loc } } @@ -76,8 +76,8 @@ func NewServer(opts ...ServerOption) *Server { s.server = &http.Server{Addr: fmt.Sprintf(":%d", s.port), Handler: router} // Create the accountant - s.accountant = accounts.NewAccountant(s.persistenceLocation) - s.world = game.NewWorld(s.persistenceLocation) + s.accountant = accounts.NewAccountant() + s.world = game.NewWorld() return s } @@ -87,10 +87,7 @@ func (s *Server) Initialise() error { // Load the accounts if requested if s.persistence == PersistentData { - if err := s.accountant.Load(); err != nil { - return err - } - if err := s.world.Load(); err != nil { + if err := persistence.LoadAll("accounts", &s.accountant, "world", &s.world); err != nil { return err } } @@ -130,10 +127,7 @@ func (s *Server) Close() error { // Save the accounts if requested if s.persistence == PersistentData { - if err := s.accountant.Save(); err != nil { - return err - } - if err := s.world.Save(); err != nil { + if err := persistence.SaveAll("accounts", s.accountant, "world", s.world); err != nil { return err } } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index f4598d7..619ebbe 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -1,7 +1,6 @@ package server import ( - "os" "testing" ) @@ -22,13 +21,11 @@ func TestNewServer_OptionPort(t *testing.T) { } func TestNewServer_OptionPersistentData(t *testing.T) { - server := NewServer(OptionPersistentData(os.TempDir())) + server := NewServer(OptionPersistentData()) if server == nil { t.Error("Failed to create server") } else if server.persistence != PersistentData { t.Error("Failed to set server persistent data") - } else if server.persistenceLocation != os.TempDir() { - t.Error("Failed to set server persistent data path") } } @@ -47,7 +44,7 @@ func TestServer_Run(t *testing.T) { } func TestServer_RunPersistentData(t *testing.T) { - server := NewServer(OptionPersistentData(os.TempDir())) + server := NewServer(OptionPersistentData()) if server == nil { t.Error("Failed to create server") } else if err := server.Initialise(); err != nil {