Skip to content

Commit

Permalink
Merge pull request #276 from jcaamano/deepcopy
Browse files Browse the repository at this point in the history
modelgen: extend to include copy and equal methods
  • Loading branch information
dcbw authored Jan 14, 2022
2 parents cbffe8e + dcf4a9c commit 2971089
Show file tree
Hide file tree
Showing 12 changed files with 951 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ coverage: test integration-test
@cat unit.cov integration.cov > profile.cov

.PHONY: bench
bench: install-deps
bench: install-deps prebuild
@echo "+ $@"
@go test -run=XXX -count=3 -bench=. ./... | tee bench.out
@benchstat bench.out
Expand Down
4 changes: 2 additions & 2 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ func (t *TableCache) Populate(tableUpdates ovsdb.TableUpdates) error {
return err
}
if existing := tCache.Row(uuid); existing != nil {
if !reflect.DeepEqual(newModel, existing) {
if !model.Equal(newModel, existing) {
logger.V(5).Info("updating row", "old:", fmt.Sprintf("%+v", existing), "new", fmt.Sprintf("%+v", newModel))
if err := tCache.Update(uuid, newModel, false); err != nil {
return err
Expand Down Expand Up @@ -660,7 +660,7 @@ func (t *TableCache) Populate2(tableUpdates ovsdb.TableUpdates2) error {
if err != nil {
return fmt.Errorf("unable to apply row modifications: %v", err)
}
if !reflect.DeepEqual(modified, existing) {
if !model.Equal(modified, existing) {
logger.V(5).Info("updating row", "old", fmt.Sprintf("%+v", existing), "new", fmt.Sprintf("%+v", modified))
if err := tCache.Update(uuid, modified, false); err != nil {
return err
Expand Down
4 changes: 1 addition & 3 deletions client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package client

import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
Expand Down Expand Up @@ -223,8 +222,7 @@ func (a api) Get(ctx context.Context, m model.Model) error {
return ErrNotFound
}

foundBytes, _ := json.Marshal(found)
_ = json.Unmarshal(foundBytes, m)
model.CloneInto(found, m)

return nil
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/modelgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var (
outDirP = flag.String("o", ".", "Directory where the generated files shall be stored")
pkgNameP = flag.String("p", "ovsmodel", "Package name")
dryRun = flag.Bool("d", false, "Dry run")
extended = flag.Bool("extended", false, "Generates additional code like deep-copy methods, etc.")
)

func main() {
Expand Down Expand Up @@ -76,6 +77,7 @@ func main() {
for name, table := range dbSchema.Tables {
tmpl := modelgen.NewTableTemplate()
args := modelgen.GetTableTemplateData(pkgName, name, &table)
args.WithExtendedGen(*extended)
if err := gen.Generate(filepath.Join(outDir, modelgen.FileName(name)), tmpl, args); err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion example/vswitchd/gen.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package vswitchd

//go:generate ../../bin/modelgen -p vswitchd -o . ovs.ovsschema
//go:generate ../../bin/modelgen --extended -p vswitchd -o . ovs.ovsschema
32 changes: 32 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,47 @@ import (
//}
type Model interface{}

type CloneableModel interface {
CloneModel() Model
CloneModelInto(Model)
}

type ComparableModel interface {
EqualsModel(Model) bool
}

// Clone creates a deep copy of a model
func Clone(a Model) Model {
if cloner, ok := a.(CloneableModel); ok {
return cloner.CloneModel()
}

val := reflect.Indirect(reflect.ValueOf(a))
b := reflect.New(val.Type()).Interface()
aBytes, _ := json.Marshal(a)
_ = json.Unmarshal(aBytes, b)
return b
}

// CloneInto deep copies a model into another one
func CloneInto(src, dst Model) {
if cloner, ok := src.(CloneableModel); ok {
cloner.CloneModelInto(dst)
return
}

aBytes, _ := json.Marshal(src)
_ = json.Unmarshal(aBytes, dst)
}

func Equal(l, r Model) bool {
if comparator, ok := l.(ComparableModel); ok {
return comparator.EqualsModel(r)
}

return reflect.DeepEqual(l, r)
}

func modelSetUUID(model Model, uuid string) error {
modelVal := reflect.ValueOf(model).Elem()
for i := 0; i < modelVal.NumField(); i++ {
Expand Down
94 changes: 93 additions & 1 deletion model/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package model
import (
"encoding/json"
"fmt"
"reflect"
"testing"

"github.com/ovn-org/libovsdb/ovsdb"
Expand Down Expand Up @@ -335,7 +336,28 @@ func TestValidate(t *testing.T) {

}

func TestClone(t *testing.T) {
type modelC struct {
modelB
NoClone string
}

func (a *modelC) CloneModel() Model {
return &modelC{
modelB: a.modelB,
}
}

func (a *modelC) CloneModelInto(b Model) {
c := b.(*modelC)
c.modelB = a.modelB
}

func (a *modelC) EqualsModel(b Model) bool {
c := b.(*modelC)
return reflect.DeepEqual(a.modelB, c.modelB)
}

func TestCloneViaMarshalling(t *testing.T) {
a := &modelB{UID: "foo", Foo: "bar", Bar: "baz"}
b := Clone(a).(*modelB)
assert.Equal(t, a, b)
Expand All @@ -344,3 +366,73 @@ func TestClone(t *testing.T) {
b.UID = "quux"
assert.NotEqual(t, a, b)
}

func TestCloneIntoViaMarshalling(t *testing.T) {
a := &modelB{UID: "foo", Foo: "bar", Bar: "baz"}
b := &modelB{}
CloneInto(a, b)
assert.Equal(t, a, b)
a.UID = "baz"
assert.NotEqual(t, a, b)
b.UID = "quux"
assert.NotEqual(t, a, b)
}

func TestCloneViaCloneable(t *testing.T) {
a := &modelC{modelB: modelB{UID: "foo", Foo: "bar", Bar: "baz"}, NoClone: "noClone"}
func(a interface{}) {
_, ok := a.(CloneableModel)
assert.True(t, ok, "is not cloneable")
}(a)
// test that Clone() uses the cloneable interface, in which
// case modelC.NoClone won't be copied
b := Clone(a).(*modelC)
assert.NotEqual(t, a, b)
b.NoClone = a.NoClone
assert.Equal(t, a, b)
a.UID = "baz"
assert.NotEqual(t, a, b)
b.UID = "quux"
assert.NotEqual(t, a, b)
}

func TestCloneIntoViaCloneable(t *testing.T) {
a := &modelC{modelB: modelB{UID: "foo", Foo: "bar", Bar: "baz"}, NoClone: "noClone"}
func(a interface{}) {
_, ok := a.(CloneableModel)
assert.True(t, ok, "is not cloneable")
}(a)
// test that CloneInto() uses the cloneable interface, in which
// case modelC.NoClone won't be copied
b := &modelC{}
CloneInto(a, b)
assert.NotEqual(t, a, b)
b.NoClone = a.NoClone
assert.Equal(t, a, b)
a.UID = "baz"
assert.NotEqual(t, a, b)
b.UID = "quux"
assert.NotEqual(t, a, b)
}

func TestEqualViaDeepEqual(t *testing.T) {
a := &modelB{UID: "foo", Foo: "bar", Bar: "baz"}
b := &modelB{UID: "foo", Foo: "bar", Bar: "baz"}
assert.True(t, Equal(a, b))
a.UID = "baz"
assert.False(t, Equal(a, b))
}

func TestEqualViaComparable(t *testing.T) {
a := &modelC{modelB: modelB{UID: "foo", Foo: "bar", Bar: "baz"}, NoClone: "noClone"}
func(a interface{}) {
_, ok := a.(ComparableModel)
assert.True(t, ok, "is not comparable")
}(a)
b := a.CloneModel().(*modelC)
// test that Equal() uses the comparable interface, in which
// case the difference on modelC.NoClone won't be noticed
assert.True(t, Equal(a, b))
a.UID = "baz"
assert.False(t, Equal(a, b))
}
27 changes: 16 additions & 11 deletions modelgen/dbmodel.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2971089

Please sign in to comment.