Specify the persistence path using the command line

This commit is contained in:
Marc Di Luzio 2020-06-02 16:10:45 +01:00
parent c085e56954
commit 5033ec4e63
5 changed files with 36 additions and 18 deletions

View file

@ -13,6 +13,7 @@ import (
var ver = flag.Bool("version", false, "Display version number") var ver = flag.Bool("version", false, "Display version number")
var port = flag.Int("port", 8080, "The port to host on") var port = flag.Int("port", 8080, "The port to host on")
var data = flag.String("data", "/tmp/", "Directory to store persistant data")
func main() { func main() {
flag.Parse() flag.Parse()
@ -24,7 +25,7 @@ func main() {
s := server.NewServer( s := server.NewServer(
server.OptionPort(*port), server.OptionPort(*port),
server.OptionPersistentData()) server.OptionPersistentData(*data))
fmt.Println("Initialising...") fmt.Println("Initialising...")
if err := s.Initialise(); err != nil { if err := s.Initialise(); err != nil {

View file

@ -5,11 +5,12 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path"
"github.com/google/uuid" "github.com/google/uuid"
) )
const kDefaultSavePath = "/tmp/accounts.json" const kAccountsFileName = "rove-accounts.json"
// Account represents a registered user // Account represents a registered user
type Account struct { type Account struct {
@ -27,12 +28,15 @@ type accountantData struct {
// Accountant manages a set of accounts // Accountant manages a set of accounts
type Accountant struct { type Accountant struct {
data accountantData data accountantData
dataPath string
} }
// NewAccountant creates a new accountant // NewAccountant creates a new accountant
func NewAccountant() *Accountant { func NewAccountant(dataPath string) *Accountant {
return &Accountant{} return &Accountant{
dataPath: dataPath,
}
} }
// RegisterAccount adds an account to the set of internal accounts // RegisterAccount adds an account to the set of internal accounts
@ -56,16 +60,21 @@ func (a *Accountant) RegisterAccount(acc Account) (Account, error) {
return acc, nil return acc, nil
} }
// path returns the full path to the data file
func (a Accountant) path() string {
return path.Join(a.dataPath, kAccountsFileName)
}
// Load will load the accountant from data // Load will load the accountant from data
func (a *Accountant) Load() error { func (a *Accountant) Load() error {
// Don't load anything if the file doesn't exist // Don't load anything if the file doesn't exist
_, err := os.Stat(kDefaultSavePath) _, err := os.Stat(a.path())
if os.IsNotExist(err) { if os.IsNotExist(err) {
fmt.Printf("File %s didn't exist, loading with fresh accounts data\n", kDefaultSavePath) fmt.Printf("File %s didn't exist, loading with fresh accounts data\n", a.path())
return nil return nil
} }
if b, err := ioutil.ReadFile(kDefaultSavePath); err != nil { if b, err := ioutil.ReadFile(a.path()); err != nil {
return err return err
} else if err := json.Unmarshal(b, &a.data); err != nil { } else if err := json.Unmarshal(b, &a.data); err != nil {
return err return err
@ -78,7 +87,7 @@ func (a *Accountant) Save() error {
if b, err := json.Marshal(a.data); err != nil { if b, err := json.Marshal(a.data); err != nil {
return err return err
} else { } else {
if err := ioutil.WriteFile(kDefaultSavePath, b, os.ModePerm); err != nil { if err := ioutil.WriteFile(a.path(), b, os.ModePerm); err != nil {
return err return err
} }
} }

View file

@ -1,12 +1,13 @@
package accounts package accounts
import ( import (
"os"
"testing" "testing"
) )
func TestNewAccountant(t *testing.T) { func TestNewAccountant(t *testing.T) {
// Very basic verify here for now // Very basic verify here for now
accountant := NewAccountant() accountant := NewAccountant(os.TempDir())
if accountant == nil { if accountant == nil {
t.Error("Failed to create accountant") t.Error("Failed to create accountant")
} }
@ -14,7 +15,7 @@ func TestNewAccountant(t *testing.T) {
func TestAccountant_RegisterAccount(t *testing.T) { func TestAccountant_RegisterAccount(t *testing.T) {
accountant := NewAccountant() accountant := NewAccountant(os.TempDir())
// Start by making two accounts // Start by making two accounts
@ -49,7 +50,7 @@ func TestAccountant_RegisterAccount(t *testing.T) {
} }
func TestAccountant_LoadSave(t *testing.T) { func TestAccountant_LoadSave(t *testing.T) {
accountant := NewAccountant() accountant := NewAccountant(os.TempDir())
if len(accountant.data.Accounts) != 0 { if len(accountant.data.Accounts) != 0 {
t.Error("New accountant created with non-zero account number") t.Error("New accountant created with non-zero account number")
} }
@ -73,7 +74,7 @@ func TestAccountant_LoadSave(t *testing.T) {
} }
// Re-create the accountant // Re-create the accountant
accountant = NewAccountant() accountant = NewAccountant(os.TempDir())
if len(accountant.data.Accounts) != 0 { if len(accountant.data.Accounts) != 0 {
t.Error("New accountant created with non-zero account number") t.Error("New accountant created with non-zero account number")
} }

View file

@ -31,7 +31,8 @@ type Server struct {
server *http.Server server *http.Server
router *mux.Router router *mux.Router
persistence int persistence int
persistenceLocation string
sync sync.WaitGroup sync sync.WaitGroup
} }
@ -47,9 +48,10 @@ func OptionPort(port int) ServerOption {
} }
// OptionPersistentData sets the server data to be persistent // OptionPersistentData sets the server data to be persistent
func OptionPersistentData() ServerOption { func OptionPersistentData(loc string) ServerOption {
return func(s *Server) { return func(s *Server) {
s.persistence = PersistentData s.persistence = PersistentData
s.persistenceLocation = loc
} }
} }
@ -61,7 +63,6 @@ func NewServer(opts ...ServerOption) *Server {
// Set up the default server // Set up the default server
s := &Server{ s := &Server{
port: 8080, port: 8080,
accountant: accounts.NewAccountant(),
world: game.NewWorld(), world: game.NewWorld(),
persistence: EphemeralData, persistence: EphemeralData,
router: router, router: router,
@ -75,6 +76,9 @@ func NewServer(opts ...ServerOption) *Server {
// Set up the server object // Set up the server object
s.server = &http.Server{Addr: fmt.Sprintf(":%d", s.port), Handler: router} s.server = &http.Server{Addr: fmt.Sprintf(":%d", s.port), Handler: router}
// Create the accountant
s.accountant = accounts.NewAccountant(s.persistenceLocation)
return s return s
} }

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"os"
"testing" "testing"
) )
@ -21,11 +22,13 @@ func TestNewServer_OptionPort(t *testing.T) {
} }
func TestNewServer_OptionPersistentData(t *testing.T) { func TestNewServer_OptionPersistentData(t *testing.T) {
server := NewServer(OptionPersistentData()) server := NewServer(OptionPersistentData(os.TempDir()))
if server == nil { if server == nil {
t.Error("Failed to create server") t.Error("Failed to create server")
} else if server.persistence != PersistentData { } else if server.persistence != PersistentData {
t.Error("Failed to set server persistent data") t.Error("Failed to set server persistent data")
} else if server.persistenceLocation != os.TempDir() {
t.Error("Failed to set server persistent data path")
} }
} }
@ -44,7 +47,7 @@ func TestServer_Run(t *testing.T) {
} }
func TestServer_RunPersistentData(t *testing.T) { func TestServer_RunPersistentData(t *testing.T) {
server := NewServer(OptionPersistentData()) server := NewServer(OptionPersistentData(os.TempDir()))
if server == nil { if server == nil {
t.Error("Failed to create server") t.Error("Failed to create server")
} else if err := server.Initialise(); err != nil { } else if err := server.Initialise(); err != nil {