From d124963eee940088dcb53f696acf37a6719c77ae Mon Sep 17 00:00:00 2001 From: Josh Deprez Date: Sun, 14 May 2023 17:24:40 +1000 Subject: [PATCH] Make MapVariableStorage concurrent-safe #6 --- README.md | 2 +- async_adapter_test.go | 10 ++--- cmd/yarnrunner.go | 2 +- vars.go | 98 +++++++++++++++++++++++++++++++++++++------ vm_test.go | 2 +- 5 files changed, 93 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 1fff2bd..8376b32 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ commands to the handler. vm := &yarn.VirtualMachine{ Program: program, Handler: myHandler, - Vars: make(yarn.MapVariableStorage), // or your own VariableStorage implementation + Vars: yarn.NewMapVariableStorage(), // or your own VariableStorage implementation FuncMap: yarn.FuncMap{ // this is optional "last_value": func(x ...any) any { return x[len(x)-1] diff --git a/async_adapter_test.go b/async_adapter_test.go index 9e7ac07..e1616ee 100644 --- a/async_adapter_test.go +++ b/async_adapter_test.go @@ -126,7 +126,7 @@ func TestAllTestPlansAsync(t *testing.T) { vm := &VirtualMachine{ Program: prog, Handler: sa.aa, - Vars: make(MapVariableStorage), + Vars: NewMapVariableStorage(), FuncMap: FuncMap{ // Used by various "assert": func(x interface{}) error { @@ -232,7 +232,7 @@ func TestAsyncAdapterWithDecoupledHandler(t *testing.T) { vm := &VirtualMachine{ Program: prog, Handler: aa, - Vars: make(MapVariableStorage), + Vars: NewMapVariableStorage(), } if traceOutput { vm.TraceLogf = t.Logf @@ -292,7 +292,7 @@ func TestAsyncAdapterWithImmediateHandler(t *testing.T) { vm := &VirtualMachine{ Program: prog, Handler: aa, - Vars: make(MapVariableStorage), + Vars: NewMapVariableStorage(), } if traceOutput { vm.TraceLogf = t.Logf @@ -370,7 +370,7 @@ func TestAsyncAdapterWithBadHandler(t *testing.T) { vm := &VirtualMachine{ Program: prog, Handler: aa, - Vars: make(MapVariableStorage), + Vars: NewMapVariableStorage(), } if traceOutput { vm.TraceLogf = t.Logf @@ -421,7 +421,7 @@ func TestAsyncAdapterWithAbortHandler(t *testing.T) { vm := &VirtualMachine{ Program: prog, Handler: aa, - Vars: make(MapVariableStorage), + Vars: NewMapVariableStorage(), } if traceOutput { vm.TraceLogf = t.Logf diff --git a/cmd/yarnrunner.go b/cmd/yarnrunner.go index 1980cf5..f6094e7 100644 --- a/cmd/yarnrunner.go +++ b/cmd/yarnrunner.go @@ -52,7 +52,7 @@ func main() { Handler: &dialogueHandler{ stringTable: stringTable, }, - Vars: make(yarn.MapVariableStorage), + Vars: yarn.NewMapVariableStorage(), } if err := vm.Run(*startNode); err != nil { log.Printf("Yarn VM error: %v", err) diff --git a/vars.go b/vars.go index 4bec47c..07db943 100644 --- a/vars.go +++ b/vars.go @@ -14,30 +14,102 @@ package yarn +import "sync" + // VariableStorage stores values of any kind. type VariableStorage interface { - Clear() - GetValue(name string) (value interface{}, ok bool) - SetValue(name string, value interface{}) + GetValue(name string) (value any, ok bool) + SetValue(name string, value any) } // MapVariableStorage implements VariableStorage, in memory, using a map. -type MapVariableStorage map[string]interface{} +// In addition to the core VariableStorage functionality, there are methods for +// accessing the contents as an ordinary map[string]any. +type MapVariableStorage struct { + mu sync.RWMutex + m map[string]any +} + +// NewMapVariableStorage creates a new empty MapVariableStorage. +func NewMapVariableStorage() *MapVariableStorage { + return &MapVariableStorage{ + m: make(map[string]any), + } +} + +// NewMapVariableStorageFromMap creates a new MapVariableStorage with initial +// contents copied from src. It does not keep a reference to src. +func NewMapVariableStorageFromMap(src map[string]any) *MapVariableStorage { + return &MapVariableStorage{ + m: copyMap(src), + } +} // Clear empties the storage of all values. -func (m MapVariableStorage) Clear() { - for name := range m { - delete(m, name) +func (m *MapVariableStorage) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + for name := range m.m { + delete(m.m, name) } } -// GetValue fetches a value from the map, returning (nil, false) if not present. -func (m MapVariableStorage) GetValue(name string) (value interface{}, found bool) { - value, found = m[name] +// GetValue fetches a value from the storage, returning (nil, false) if not present. +func (m *MapVariableStorage) GetValue(name string) (value any, found bool) { + m.mu.RLock() + defer m.mu.RUnlock() + value, found = m.m[name] return value, found } -// SetValue sets a value in the map. -func (m MapVariableStorage) SetValue(name string, value interface{}) { - m[name] = value +// SetValue sets a value in the storage. +func (m *MapVariableStorage) SetValue(name string, value any) { + m.mu.Lock() + defer m.mu.Unlock() + m.m[name] = value +} + +// Delete deletes values from the storage. +func (m *MapVariableStorage) Delete(names ...string) { + m.mu.Lock() + defer m.mu.Unlock() + for _, name := range names { + delete(m.m, name) + } +} + +// Contents returns a copy of the contents of the storage, as a regular map. +// The returned map is a copy, it is not a reference to the map contained within +// the storage (to avoid accidental data races). +func (m *MapVariableStorage) Contents() map[string]any { + m.mu.RLock() + defer m.mu.RUnlock() + return copyMap(m.m) +} + +// Clone returns a new MapVariableStorage that is a clone of the receiver. +// The new storage is a deep copy, and does not contain a reference to the +// original map inside the receiver (to avoid accidental data races). +func (m *MapVariableStorage) Clone() *MapVariableStorage { + m.mu.RLock() + defer m.mu.RUnlock() + return NewMapVariableStorageFromMap(m.m) +} + +// ReplaceContents replaces the contents of the storage with values from a +// regular map. ReplaceContents copies src, it does not keep a reference to src +// (to avoid accidental data races). +func (m *MapVariableStorage) ReplaceContents(src map[string]any) { + m2 := copyMap(src) + m.mu.Lock() + defer m.mu.Unlock() + m.m = m2 +} + +func copyMap[K comparable, V any](src map[K]V) map[K]V { + m := make(map[K]V, len(src)) + for name, val := range src { + m[name] = val + } + return m } diff --git a/vm_test.go b/vm_test.go index 19bb632..54648d2 100644 --- a/vm_test.go +++ b/vm_test.go @@ -47,7 +47,7 @@ func TestAllTestPlans(t *testing.T) { vm := &VirtualMachine{ Program: prog, Handler: testplan, - Vars: make(MapVariableStorage), + Vars: NewMapVariableStorage(), FuncMap: FuncMap{ // Used by various "assert": func(x interface{}) error {