From c5ebbc3c400a47ed5863f23186a2887eeff85f47 Mon Sep 17 00:00:00 2001
From: Marc Di Luzio <marc.diluzio@gmail.com>
Date: Tue, 2 Jun 2020 19:16:02 +0100
Subject: [PATCH] Extract persistence code into own class

---
 main.go                             |  8 ++-
 pkg/accounts/accounts.go            | 42 +-------------
 pkg/accounts/accounts_test.go       | 49 +---------------
 pkg/game/world.go                   | 41 +------------
 pkg/game/world_test.go              |  5 +-
 pkg/persistence/persistence.go      | 90 +++++++++++++++++++++++++++++
 pkg/persistence/persistence_test.go | 52 +++++++++++++++++
 pkg/server/server.go                | 18 ++----
 pkg/server/server_test.go           |  7 +--
 9 files changed, 163 insertions(+), 149 deletions(-)
 create mode 100644 pkg/persistence/persistence.go
 create mode 100644 pkg/persistence/persistence_test.go

diff --git a/main.go b/main.go
index ccc2f93..d8c0762 100644
--- a/main.go
+++ b/main.go
@@ -7,13 +7,14 @@ import (
 	"os/signal"
 	"syscall"
 
+	"github.com/mdiluz/rove/pkg/persistence"
 	"github.com/mdiluz/rove/pkg/server"
 	"github.com/mdiluz/rove/pkg/version"
 )
 
 var ver = flag.Bool("version", false, "Display version number")
 var port = flag.Int("port", 8080, "The port to host on")
-var data = flag.String("data", "/tmp/", "Directory to store persistant data")
+var data = flag.String("data", os.TempDir(), "Directory to store persistant data")
 
 func main() {
 	flag.Parse()
@@ -23,9 +24,12 @@ func main() {
 		os.Exit(0)
 	}
 
+	// Set the persistence path
+	persistence.SetPath(*data)
+
 	s := server.NewServer(
 		server.OptionPort(*port),
-		server.OptionPersistentData(*data))
+		server.OptionPersistentData())
 
 	fmt.Println("Initialising...")
 	if err := s.Initialise(); err != nil {
diff --git a/pkg/accounts/accounts.go b/pkg/accounts/accounts.go
index e975a08..37effec 100644
--- a/pkg/accounts/accounts.go
+++ b/pkg/accounts/accounts.go
@@ -1,11 +1,7 @@
 package accounts
 
 import (
-	"encoding/json"
 	"fmt"
-	"io/ioutil"
-	"os"
-	"path"
 
 	"github.com/google/uuid"
 )
@@ -31,13 +27,11 @@ type accountantData struct {
 // Accountant manages a set of accounts
 type Accountant struct {
 	Accounts map[uuid.UUID]Account `json:"accounts"`
-	dataPath string
 }
 
 // NewAccountant creates a new accountant
-func NewAccountant(dataPath string) *Accountant {
+func NewAccountant() *Accountant {
 	return &Accountant{
-		dataPath: dataPath,
 		Accounts: make(map[uuid.UUID]Account),
 	}
 }
@@ -64,40 +58,6 @@ func (a *Accountant) RegisterAccount(name string) (acc Account, err error) {
 	return
 }
 
-// path returns the full path to the data file
-func (a Accountant) path() string {
-	return path.Join(a.dataPath, kAccountsFileName)
-}
-
-// Load will load the accountant from data
-func (a *Accountant) Load() error {
-	// Don't load anything if the file doesn't exist
-	_, err := os.Stat(a.path())
-	if os.IsNotExist(err) {
-		fmt.Printf("File %s didn't exist, loading with fresh accounts data\n", a.path())
-		return nil
-	}
-
-	if b, err := ioutil.ReadFile(a.path()); err != nil {
-		return err
-	} else if err := json.Unmarshal(b, &a); err != nil {
-		return err
-	}
-	return nil
-}
-
-// Save will save the accountant data out
-func (a *Accountant) Save() error {
-	if b, err := json.MarshalIndent(a, "", "\t"); err != nil {
-		return err
-	} else {
-		if err := ioutil.WriteFile(a.path(), b, os.ModePerm); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
 // AssignPrimary assigns primary ownership of an instance to an account
 func (a *Accountant) AssignPrimary(account uuid.UUID, instance uuid.UUID) error {
 
diff --git a/pkg/accounts/accounts_test.go b/pkg/accounts/accounts_test.go
index bef693f..bb6c9a7 100644
--- a/pkg/accounts/accounts_test.go
+++ b/pkg/accounts/accounts_test.go
@@ -1,7 +1,6 @@
 package accounts
 
 import (
-	"os"
 	"testing"
 
 	"github.com/google/uuid"
@@ -9,7 +8,7 @@ import (
 
 func TestNewAccountant(t *testing.T) {
 	// Very basic verify here for now
-	accountant := NewAccountant(os.TempDir())
+	accountant := NewAccountant()
 	if accountant == nil {
 		t.Error("Failed to create accountant")
 	}
@@ -17,7 +16,7 @@ func TestNewAccountant(t *testing.T) {
 
 func TestAccountant_RegisterAccount(t *testing.T) {
 
-	accountant := NewAccountant(os.TempDir())
+	accountant := NewAccountant()
 
 	// Start by making two accounts
 
@@ -49,50 +48,8 @@ func TestAccountant_RegisterAccount(t *testing.T) {
 	}
 }
 
-func TestAccountant_LoadSave(t *testing.T) {
-	accountant := NewAccountant(os.TempDir())
-	if len(accountant.Accounts) != 0 {
-		t.Error("New accountant created with non-zero account number")
-	}
-
-	name := uuid.New().String()
-	a, err := accountant.RegisterAccount(name)
-	if err != nil {
-		t.Error(err)
-	}
-
-	if len(accountant.Accounts) != 1 {
-		t.Error("No new account made")
-	} else if accountant.Accounts[a.Id].Name != name {
-		t.Error("New account created with wrong name")
-	}
-
-	// Save out the accountant
-	if err := accountant.Save(); err != nil {
-		t.Error(err)
-	}
-
-	// Re-create the accountant
-	accountant = NewAccountant(os.TempDir())
-	if len(accountant.Accounts) != 0 {
-		t.Error("New accountant created with non-zero account number")
-	}
-
-	// Load the old accountant data
-	if err := accountant.Load(); err != nil {
-		t.Error(err)
-	}
-
-	// Verify we have the same account again
-	if len(accountant.Accounts) != 1 {
-		t.Error("No account after load")
-	} else if accountant.Accounts[a.Id].Name != name {
-		t.Error("New account created with wrong name")
-	}
-}
-
 func TestAccountant_AssignPrimary(t *testing.T) {
-	accountant := NewAccountant(os.TempDir())
+	accountant := NewAccountant()
 	if len(accountant.Accounts) != 0 {
 		t.Error("New accountant created with non-zero account number")
 	}
diff --git a/pkg/game/world.go b/pkg/game/world.go
index 3138953..14311b0 100644
--- a/pkg/game/world.go
+++ b/pkg/game/world.go
@@ -1,11 +1,7 @@
 package game
 
 import (
-	"encoding/json"
 	"fmt"
-	"io/ioutil"
-	"os"
-	"path"
 
 	"github.com/google/uuid"
 )
@@ -30,47 +26,12 @@ type Instance struct {
 const kWorldFileName = "rove-world.json"
 
 // NewWorld creates a new world object
-func NewWorld(data string) *World {
+func NewWorld() *World {
 	return &World{
 		Instances: make(map[uuid.UUID]Instance),
-		dataPath:  data,
 	}
 }
 
-// path returns the full path to the data file
-func (w *World) path() string {
-	return path.Join(w.dataPath, kWorldFileName)
-}
-
-// Load will load the accountant from data
-func (w *World) Load() error {
-	// Don't load anything if the file doesn't exist
-	_, err := os.Stat(w.path())
-	if os.IsNotExist(err) {
-		fmt.Printf("File %s didn't exist, loading with fresh world data\n", w.path())
-		return nil
-	}
-
-	if b, err := ioutil.ReadFile(w.path()); err != nil {
-		return err
-	} else if err := json.Unmarshal(b, &w); err != nil {
-		return err
-	}
-	return nil
-}
-
-// Save will save the accountant data out
-func (w *World) Save() error {
-	if b, err := json.MarshalIndent(w, "", "\t"); err != nil {
-		return err
-	} else {
-		if err := ioutil.WriteFile(w.path(), b, os.ModePerm); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
 // Adds an instance to the game
 func (w *World) CreateInstance() uuid.UUID {
 	id := uuid.New()
diff --git a/pkg/game/world_test.go b/pkg/game/world_test.go
index 0bfc070..e5fd031 100644
--- a/pkg/game/world_test.go
+++ b/pkg/game/world_test.go
@@ -1,20 +1,19 @@
 package game
 
 import (
-	"os"
 	"testing"
 )
 
 func TestNewWorld(t *testing.T) {
 	// Very basic for now, nothing to verify
-	world := NewWorld(os.TempDir())
+	world := NewWorld()
 	if world == nil {
 		t.Error("Failed to create world")
 	}
 }
 
 func TestWorld_CreateInstance(t *testing.T) {
-	world := NewWorld(os.TempDir())
+	world := NewWorld()
 	a := world.CreateInstance()
 	b := world.CreateInstance()
 
diff --git a/pkg/persistence/persistence.go b/pkg/persistence/persistence.go
new file mode 100644
index 0000000..6705767
--- /dev/null
+++ b/pkg/persistence/persistence.go
@@ -0,0 +1,90 @@
+package persistence
+
+import (
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"path"
+)
+
+// dataPath global path for persistence
+var dataPath = os.TempDir()
+
+// SetPath sets the persistent path for the data storage
+func SetPath(path string) error {
+	if info, err := os.Stat(path); err != nil {
+		return err
+	} else if !info.IsDir() {
+		return fmt.Errorf("path for persistence is not directory")
+	}
+	dataPath = path
+	return nil
+}
+
+// Converts name to a full path
+func jsonPath(name string) string {
+	return path.Join(dataPath, fmt.Sprintf("rove-%s.json", name))
+}
+
+// Save will serialise the interface into a json file
+func Save(name string, data interface{}) error {
+	if b, err := json.MarshalIndent(data, "", "\t"); err != nil {
+		return err
+	} else {
+		if err := ioutil.WriteFile(jsonPath(name), b, os.ModePerm); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Load will load the interface from the json file
+func Load(name string, data interface{}) error {
+	path := jsonPath(name)
+	// Don't load anything if the file doesn't exist
+	_, err := os.Stat(path)
+	if os.IsNotExist(err) {
+		fmt.Printf("File %s didn't exist, loading with fresh data\n", path)
+		return nil
+	}
+
+	// Read and unmarshal the json
+	if b, err := ioutil.ReadFile(path); err != nil {
+		return err
+	} else if err := json.Unmarshal(b, data); err != nil {
+		return err
+	}
+	return nil
+}
+
+// saveLoadFunc defines a type of function to save or load an interface
+type saveLoadFunc func(string, interface{}) error
+
+func doAll(f saveLoadFunc, args ...interface{}) error {
+	var name string
+	for i, a := range args {
+		if i%2 == 0 {
+			var ok bool
+			name, ok = a.(string)
+			if !ok {
+				return fmt.Errorf("Incorrect args")
+			}
+		} else {
+			if err := f(name, a); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// SaveAll allows for saving multiple structures in a single call
+func SaveAll(args ...interface{}) error {
+	return doAll(Save, args...)
+}
+
+// LoadAll allows for loading multiple structures in a single call
+func LoadAll(args ...interface{}) error {
+	return doAll(Load, args...)
+}
diff --git a/pkg/persistence/persistence_test.go b/pkg/persistence/persistence_test.go
new file mode 100644
index 0000000..e4117a1
--- /dev/null
+++ b/pkg/persistence/persistence_test.go
@@ -0,0 +1,52 @@
+package persistence
+
+import (
+	"io/ioutil"
+	"os"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+type Dummy struct {
+	Success bool
+	Value   int
+}
+
+func TestPersistence_LoadSave(t *testing.T) {
+	tmp, err := ioutil.TempDir(os.TempDir(), "rove_persistence_test")
+	assert.NoError(t, err, "Failed to get tempdir path")
+
+	assert.NoError(t, SetPath(tmp), "Failed to get set tempdir to persistence path")
+
+	// Try and save out the dummy
+	var dummy Dummy
+	dummy.Success = true
+	assert.NoError(t, Save("test", dummy), "Failed to save out dummy file")
+
+	// Load back the dummy
+	dummy = Dummy{}
+	assert.NoError(t, Load("test", &dummy), "Failed to load in dummy file")
+	assert.Equal(t, true, dummy.Success, "Did not successfully load true value from file")
+}
+
+func TestPersistence_LoadSaveAll(t *testing.T) {
+	tmp, err := ioutil.TempDir(os.TempDir(), "rove_persistence_test")
+	assert.NoError(t, err, "Failed to get tempdir path")
+
+	assert.NoError(t, SetPath(tmp), "Failed to get set tempdir to persistence path")
+
+	// Try and save out the dummy
+	var dummyA Dummy
+	var dummyB Dummy
+	dummyA.Value = 1
+	dummyB.Value = 2
+	assert.NoError(t, SaveAll("a", dummyA, "b", dummyB), "Failed to save out dummy file")
+
+	// Load back the dummy
+	dummyA = Dummy{}
+	dummyB = Dummy{}
+	assert.NoError(t, LoadAll("a", &dummyA, "b", &dummyB), "Failed to load in dummy file")
+	assert.Equal(t, 1, dummyA.Value, "Did not successfully load int value from file")
+	assert.Equal(t, 2, dummyB.Value, "Did not successfully load int value from file")
+}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 3ffaee4..37ff0ae 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -11,6 +11,7 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/mdiluz/rove/pkg/accounts"
 	"github.com/mdiluz/rove/pkg/game"
+	"github.com/mdiluz/rove/pkg/persistence"
 )
 
 const (
@@ -48,10 +49,9 @@ func OptionPort(port int) ServerOption {
 }
 
 // OptionPersistentData sets the server data to be persistent
-func OptionPersistentData(loc string) ServerOption {
+func OptionPersistentData() ServerOption {
 	return func(s *Server) {
 		s.persistence = PersistentData
-		s.persistenceLocation = loc
 	}
 }
 
@@ -76,8 +76,8 @@ func NewServer(opts ...ServerOption) *Server {
 	s.server = &http.Server{Addr: fmt.Sprintf(":%d", s.port), Handler: router}
 
 	// Create the accountant
-	s.accountant = accounts.NewAccountant(s.persistenceLocation)
-	s.world = game.NewWorld(s.persistenceLocation)
+	s.accountant = accounts.NewAccountant()
+	s.world = game.NewWorld()
 
 	return s
 }
@@ -87,10 +87,7 @@ func (s *Server) Initialise() error {
 
 	// Load the accounts if requested
 	if s.persistence == PersistentData {
-		if err := s.accountant.Load(); err != nil {
-			return err
-		}
-		if err := s.world.Load(); err != nil {
+		if err := persistence.LoadAll("accounts", &s.accountant, "world", &s.world); err != nil {
 			return err
 		}
 	}
@@ -130,10 +127,7 @@ func (s *Server) Close() error {
 
 	// Save the accounts if requested
 	if s.persistence == PersistentData {
-		if err := s.accountant.Save(); err != nil {
-			return err
-		}
-		if err := s.world.Save(); err != nil {
+		if err := persistence.SaveAll("accounts", s.accountant, "world", s.world); err != nil {
 			return err
 		}
 	}
diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go
index f4598d7..619ebbe 100644
--- a/pkg/server/server_test.go
+++ b/pkg/server/server_test.go
@@ -1,7 +1,6 @@
 package server
 
 import (
-	"os"
 	"testing"
 )
 
@@ -22,13 +21,11 @@ func TestNewServer_OptionPort(t *testing.T) {
 }
 
 func TestNewServer_OptionPersistentData(t *testing.T) {
-	server := NewServer(OptionPersistentData(os.TempDir()))
+	server := NewServer(OptionPersistentData())
 	if server == nil {
 		t.Error("Failed to create server")
 	} else if server.persistence != PersistentData {
 		t.Error("Failed to set server persistent data")
-	} else if server.persistenceLocation != os.TempDir() {
-		t.Error("Failed to set server persistent data path")
 	}
 }
 
@@ -47,7 +44,7 @@ func TestServer_Run(t *testing.T) {
 }
 
 func TestServer_RunPersistentData(t *testing.T) {
-	server := NewServer(OptionPersistentData(os.TempDir()))
+	server := NewServer(OptionPersistentData())
 	if server == nil {
 		t.Error("Failed to create server")
 	} else if err := server.Initialise(); err != nil {