diff --git a/cache/cache.go b/cache/cache.go index 01a5aabd..80eb589f 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -590,15 +590,15 @@ func (t *TableCache) Populate2(tableUpdates ovsdb.TableUpdates2) { } t.eventProcessor.AddEvent(addEvent, table, nil, m) case row.Modify != nil: - modified := tCache.Row(uuid) - if modified == nil { + existing := tCache.Row(uuid) + if existing == nil { panic(fmt.Errorf("row with uuid %s does not exist", uuid)) } + modified := tCache.Row(uuid) err := t.ApplyModifications(table, modified, *row.Modify) if err != nil { panic(err) } - existing := tCache.Row(uuid) if !reflect.DeepEqual(modified, existing) { if err := tCache.Update(uuid, modified, false); err != nil { panic(err) @@ -720,7 +720,7 @@ func (e *eventProcessor) AddEvent(eventType string, table string, old model.Mode // noop return default: - log.Print("dropping event because event buffer is full") + log.Print("libovsdb: dropping event because event buffer is full") } } @@ -871,6 +871,9 @@ func (t *TableCache) ApplyModifications(tableName string, base model.Model, upda bv.SetMapIndex(mk, mv) } } + if len(bv.MapKeys()) == 0 { + bv = reflect.Zero(nv.Type()) + } err = info.SetField(k, bv.Interface()) if err != nil { return err diff --git a/cache/cache_test.go b/cache/cache_test.go index 44aba58c..14c109a0 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -1200,7 +1200,12 @@ func TestTableCacheApplyModifications(t *testing.T) { &testDBModel{Value: "foo"}, &testDBModel{Value: "bar"}, }, - + { + "noop", + ovsdb.Row{"value": "bar"}, + &testDBModel{Value: "bar"}, + &testDBModel{Value: "bar"}, + }, { "add to set", ovsdb.Row{"set": aFooSet}, @@ -1236,7 +1241,7 @@ func TestTableCacheApplyModifications(t *testing.T) { "delete map key", ovsdb.Row{"map": aFooMap}, &testDBModel{Map: map[string]string{"foo": "bar"}}, - &testDBModel{Map: map[string]string{}}, + &testDBModel{Map: nil}, }, { "multiple map operations", @@ -1276,8 +1281,8 @@ func TestTableCacheApplyModifications(t *testing.T) { err = tc.ApplyModifications("Open_vSwitch", original, tt.update) require.NoError(t, err) require.Equal(t, tt.expected, original) - if reflect.DeepEqual(original, tt.base) { - t.Error("original and base are equal") + if !reflect.DeepEqual(tt.expected, tt.base) { + require.NotEqual(t, tt.base, original) } }) } diff --git a/client/client.go b/client/client.go index 0e175239..066d7f43 100644 --- a/client/client.go +++ b/client/client.go @@ -245,16 +245,10 @@ func (o *ovsdbClient) tryEndpoint(ctx context.Context, u *url.URL) error { o.createRPC2Client(c) - // from now on, if err is nil, always tear down the RPC session - defer func() { - if err != nil { - o.rpcClient.Close() - o.rpcClient = nil - } - }() - serverDBNames, err := o.listDbs(ctx) if err != nil { + o.rpcClient.Close() + o.rpcClient = nil return err } @@ -271,12 +265,16 @@ func (o *ovsdbClient) tryEndpoint(ctx context.Context, u *url.URL) error { } if !found { err = fmt.Errorf("target database %s not found", dbName) + o.rpcClient.Close() + o.rpcClient = nil return err } // load and validate the schema schema, err := o.getSchema(ctx, dbName) if err != nil { + o.rpcClient.Close() + o.rpcClient = nil return err } @@ -288,6 +286,8 @@ func (o *ovsdbClient) tryEndpoint(ctx context.Context, u *url.URL) error { } err = fmt.Errorf("database %s validation error (%d): %s", dbName, len(errors), strings.Join(combined, ". ")) + o.rpcClient.Close() + o.rpcClient = nil return err } @@ -300,6 +300,8 @@ func (o *ovsdbClient) tryEndpoint(ctx context.Context, u *url.URL) error { db.cache, err = cache.NewTableCache(schema, db.model, nil) if err != nil { db.cacheMutex.Unlock() + o.rpcClient.Close() + o.rpcClient = nil return err } db.api = newAPI(db.cache) @@ -314,10 +316,14 @@ func (o *ovsdbClient) tryEndpoint(ctx context.Context, u *url.URL) error { var leader bool leader, err = o.isEndpointLeader(ctx) if err != nil { + o.rpcClient.Close() + o.rpcClient = nil return err } if !leader { err = fmt.Errorf("endpoint is not leader") + o.rpcClient.Close() + o.rpcClient = nil return err } } @@ -610,7 +616,6 @@ func (o *ovsdbClient) transact(ctx context.Context, dbName string, operation ... } args := ovsdb.NewTransactArgs(dbName, operation...) - if o.rpcClient == nil { return nil, ErrNotConnected } @@ -761,7 +766,6 @@ func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconne } if !reconnecting { - db := o.databases[dbName] db.monitorsMutex.Lock() db.monitors[cookie.ID] = monitor db.monitorsMutex.Unlock() @@ -769,12 +773,15 @@ func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconne if monitor.Method == ovsdb.MonitorRPC { u := tableUpdates.(ovsdb.TableUpdates) - o.databases[dbName].cache.Populate(u) + db.cacheMutex.Lock() + defer db.cacheMutex.Unlock() + db.cache.Update(nil, u) } else { u := tableUpdates.(ovsdb.TableUpdates2) - o.databases[dbName].cache.Populate2(u) + db.cacheMutex.Lock() + defer db.cacheMutex.Unlock() + db.cache.Update2(nil, u) } - return nil } diff --git a/ovsdb/bindings.go b/ovsdb/bindings.go index f80b751c..4c675858 100644 --- a/ovsdb/bindings.go +++ b/ovsdb/bindings.go @@ -245,7 +245,6 @@ func NativeToOvsAtomic(basicType string, nativeElem interface{}) (interface{}, e // NativeToOvs transforms an native type to a ovs type based on the column type information func NativeToOvs(column *ColumnSchema, rawElem interface{}) (interface{}, error) { naType := NativeType(column) - if t := reflect.TypeOf(rawElem); t != naType { return nil, NewErrWrongType("NativeToOvs", naType.String(), rawElem) } diff --git a/ovsdb/updates2.go b/ovsdb/updates2.go index 7b269ec9..64bd2e80 100644 --- a/ovsdb/updates2.go +++ b/ovsdb/updates2.go @@ -62,7 +62,31 @@ func (r *RowUpdate2) Merge(new *RowUpdate2) { return } if r.Modify != nil && new.Modify != nil { - r.Modify = new.Modify + currentRowData := *r.Modify + newRowData := *new.Modify + for k, v := range newRowData { + if _, ok := currentRowData[k]; !ok { + currentRowData[k] = v + } else { + switch v.(type) { + case OvsSet: + oSet := currentRowData[k].(OvsSet) + newSet := v.(OvsSet) + oSet.GoSet = append(oSet.GoSet, newSet.GoSet...) + case OvsMap: + oMap := currentRowData[k].(OvsMap) + newMap := v.(OvsMap) + for newK, newV := range newMap.GoMap { + if _, ok := oMap.GoMap[newK]; !ok { + oMap.GoMap[newK] = newV + } + } + default: + panic("ARGH!") + } + } + } + r.Modify = ¤tRowData return } if r.Modify != nil && new.Delete != nil { diff --git a/ovsdb/updates2_test.go b/ovsdb/updates2_test.go new file mode 100644 index 00000000..058ddf44 --- /dev/null +++ b/ovsdb/updates2_test.go @@ -0,0 +1,47 @@ +package ovsdb + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddRowUpdate2Merge(t *testing.T) { + tests := []struct { + name string + initial *RowUpdate2 + new *RowUpdate2 + expected *RowUpdate2 + }{ + { + "insert then modify", + &RowUpdate2{Insert: &Row{"foo": "bar"}}, + &RowUpdate2{Modify: &Row{"foo": "baz"}}, + &RowUpdate2{Insert: &Row{"foo": "baz"}}, + }, + { + "insert then delete", + &RowUpdate2{Insert: &Row{"foo": "bar"}}, + &RowUpdate2{Delete: &Row{"foo": "bar"}}, + &RowUpdate2{Delete: &Row{"foo": "bar"}}, + }, + { + "modify then delete", + &RowUpdate2{Modify: &Row{"foo": "baz"}}, + &RowUpdate2{Delete: &Row{"foo": "baz"}}, + &RowUpdate2{Delete: &Row{"foo": "baz"}}, + }, + { + "modify then modify", + &RowUpdate2{Modify: &Row{"foo": "baz"}}, + &RowUpdate2{Modify: &Row{"bar": "quux"}}, + &RowUpdate2{Modify: &Row{"foo": "baz", "bar": "quux"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.initial.Merge(tt.new) + assert.Equal(t, tt.expected, tt.initial) + }) + } +} diff --git a/server/mutate.go b/server/mutate.go index a41ee10d..d2425874 100644 --- a/server/mutate.go +++ b/server/mutate.go @@ -6,23 +6,23 @@ import ( "github.com/ovn-org/libovsdb/ovsdb" ) -func removeFromSlice(a, b reflect.Value) reflect.Value { +func removeFromSlice(a, b reflect.Value) (reflect.Value, bool) { for i := 0; i < a.Len(); i++ { if a.Index(i).Interface() == b.Interface() { v := reflect.AppendSlice(a.Slice(0, i), a.Slice(i+1, a.Len())) - return v + return v, true } } - return a + return a, false } -func insertToSlice(a, b reflect.Value) reflect.Value { +func insertToSlice(a, b reflect.Value) (reflect.Value, bool) { for i := 0; i < a.Len(); i++ { if a.Index(i).Interface() == b.Interface() { - return a + return a, false } } - return reflect.Append(a, b) + return reflect.Append(a, b), true } func mutate(current interface{}, mutator ovsdb.Mutator, value interface{}) (interface{}, interface{}) { @@ -33,10 +33,9 @@ func mutate(current interface{}, mutator ovsdb.Mutator, value interface{}) (inte switch mutator { case ovsdb.MutateOperationInsert: // for insert, the delta will be the new value added - return mutateInsert(current, value), value + return mutateInsert(current, value) case ovsdb.MutateOperationDelete: - // for delete, the delta will be the value removed - return mutateDelete(current, value), value + return mutateDelete(current, value) case ovsdb.MutateOperationAdd: // for add, the delta is the new value new := mutateAdd(current, value) @@ -58,74 +57,121 @@ func mutate(current interface{}, mutator ovsdb.Mutator, value interface{}) (inte return current, value } -func mutateInsert(current, value interface{}) interface{} { +func mutateInsert(current, value interface{}) (interface{}, interface{}) { switch current.(type) { case int, float64: - return current + return current, current } vc := reflect.ValueOf(current) vv := reflect.ValueOf(value) if vc.Kind() == reflect.Slice && vc.Type() == reflect.SliceOf(vv.Type()) { - v := insertToSlice(vc, vv) - return v.Interface() + v, ok := insertToSlice(vc, vv) + var diff interface{} + if ok { + diff = value + } + return v.Interface(), diff + } + if !vc.IsValid() { + if vv.IsValid() { + return vv.Interface(), vv.Interface() + } + return nil, nil } if vc.Kind() == reflect.Slice && vv.Kind() == reflect.Slice { v := vc + diff := reflect.Indirect(reflect.New(vv.Type())) for i := 0; i < vv.Len(); i++ { - v = insertToSlice(v, vv.Index(i)) + var ok bool + v, ok = insertToSlice(v, vv.Index(i)) + if ok { + diff = reflect.Append(diff, vv.Index(i)) + } } - return v.Interface() + if diff.Len() > 0 { + return v.Interface(), diff.Interface() + } + return v.Interface(), nil } if vc.Kind() == reflect.Map && vv.Kind() == reflect.Map { + diff := reflect.MakeMap(vc.Type()) iter := vv.MapRange() - if vc.IsNil() && vv.Len() > 0 { - return value - } for iter.Next() { k := iter.Key() if !vc.MapIndex(k).IsValid() { vc.SetMapIndex(k, iter.Value()) + diff.SetMapIndex(k, iter.Value()) } } + if diff.Len() > 0 { + return current, diff.Interface() + } + return current, nil } - return current + return current, nil } -func mutateDelete(current, value interface{}) interface{} { +func mutateDelete(current, value interface{}) (interface{}, interface{}) { switch current.(type) { case int, float64: - return current + return current, nil } vc := reflect.ValueOf(current) vv := reflect.ValueOf(value) if vc.Kind() == reflect.Slice && vc.Type() == reflect.SliceOf(vv.Type()) { - v := removeFromSlice(vc, vv) - return v.Interface() + v, ok := removeFromSlice(vc, vv) + diff := value + if !ok { + diff = nil + } + return v.Interface(), diff } if vc.Kind() == reflect.Slice && vv.Kind() == reflect.Slice { v := vc + diff := reflect.Indirect(reflect.New(vv.Type())) for i := 0; i < vv.Len(); i++ { - v = removeFromSlice(v, vv.Index(i)) + var ok bool + v, ok = removeFromSlice(v, vv.Index(i)) + if ok { + diff = reflect.Append(diff, vv.Index(i)) + } } - return v.Interface() + if diff.Len() > 0 { + return v.Interface(), diff.Interface() + } + return v.Interface(), nil } if vc.Kind() == reflect.Map && vv.Type() == reflect.SliceOf(vc.Type().Key()) { + diff := reflect.MakeMap(vc.Type()) for i := 0; i < vv.Len(); i++ { - vc.SetMapIndex(vv.Index(i), reflect.Value{}) + if vc.MapIndex(vv.Index(i)).IsValid() { + diff.SetMapIndex(vv.Index(i), vc.MapIndex(vv.Index(i))) + vc.SetMapIndex(vv.Index(i), reflect.Value{}) + } } + if diff.Len() > 0 { + return current, diff.Interface() + } + return current, nil } if vc.Kind() == reflect.Map && vv.Kind() == reflect.Map { + diff := reflect.MakeMap(vc.Type()) iter := vv.MapRange() for iter.Next() { vvk := iter.Key() vvv := iter.Value() vcv := vc.MapIndex(vvk) if reflect.DeepEqual(vcv.Interface(), vvv.Interface()) { + diff.SetMapIndex(vvk, vcv) vc.SetMapIndex(vvk, reflect.Value{}) } } + if diff.Len() > 0 { + return current, diff.Interface() + } + return current, nil } - return current + return current, nil } func mutateAdd(current, value interface{}) interface{} { diff --git a/server/mutate_test.go b/server/mutate_test.go index 75c36bc6..5d615d0a 100644 --- a/server/mutate_test.go +++ b/server/mutate_test.go @@ -1,10 +1,10 @@ package server import ( - "reflect" "testing" "github.com/ovn-org/libovsdb/ovsdb" + "github.com/stretchr/testify/assert" ) func TestMutateAdd(t *testing.T) { @@ -46,10 +46,9 @@ func TestMutateAdd(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := mutate(tt.current, tt.mutator, tt.value) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mutate() = %v, want %v", got, tt.want) - } + got, diff := mutate(tt.current, tt.mutator, tt.value) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, diff) }) } } @@ -94,10 +93,9 @@ func TestMutateSubtract(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := mutate(tt.current, tt.mutator, tt.value) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mutate() = %v, want %v", got, tt.want) - } + got, diff := mutate(tt.current, tt.mutator, tt.value) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, diff) }) } } @@ -142,10 +140,9 @@ func TestMutateMultiply(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := mutate(tt.current, tt.mutator, tt.value) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mutate() = %v, want %v", got, tt.want) - } + got, diff := mutate(tt.current, tt.mutator, tt.value) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, diff) }) } } @@ -189,10 +186,9 @@ func TestMutateDivide(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := mutate(tt.current, tt.mutator, tt.value) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mutate() = %v, want %v", got, tt.want) - } + got, diff := mutate(tt.current, tt.mutator, tt.value) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, diff) }) } } @@ -222,10 +218,9 @@ func TestMutateModulo(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := mutate(tt.current, tt.mutator, tt.value) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mutate() = %v, want %v", got, tt.want) - } + got, diff := mutate(tt.current, tt.mutator, tt.value) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, diff) }) } } @@ -237,6 +232,7 @@ func TestMutateInsert(t *testing.T) { mutator ovsdb.Mutator value interface{} want interface{} + diff interface{} }{ { "insert single string", @@ -244,6 +240,7 @@ func TestMutateInsert(t *testing.T) { ovsdb.MutateOperationInsert, "baz", []string{"foo", "bar", "baz"}, + "baz", }, { "insert existing string", @@ -251,6 +248,7 @@ func TestMutateInsert(t *testing.T) { ovsdb.MutateOperationInsert, "baz", []string{"foo", "bar", "baz"}, + nil, }, { "insert multiple string", @@ -258,6 +256,7 @@ func TestMutateInsert(t *testing.T) { ovsdb.MutateOperationInsert, []string{"baz", "quux", "foo"}, []string{"foo", "bar", "baz", "quux"}, + []string{"baz", "quux"}, }, { "insert key value pairs", @@ -273,14 +272,20 @@ func TestMutateInsert(t *testing.T) { "foo": "bar", "baz": "quux", }, + map[string]string{ + "baz": "quux", + }, }, { "insert key value pairs on nil map", + nil, + ovsdb.MutateOperationInsert, + map[string]string{ + "foo": "bar", + }, map[string]string{ "foo": "bar", }, - ovsdb.MutateOperationInsert, - nil, map[string]string{ "foo": "bar", }, @@ -288,10 +293,9 @@ func TestMutateInsert(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := mutate(tt.current, tt.mutator, tt.value) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mutate() = %v, want %v", got, tt.want) - } + got, diff := mutate(tt.current, tt.mutator, tt.value) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.diff, diff) }) } } @@ -303,6 +307,7 @@ func TestMutateDelete(t *testing.T) { mutator ovsdb.Mutator value interface{} want interface{} + diff interface{} }{ { "delete single string", @@ -310,6 +315,7 @@ func TestMutateDelete(t *testing.T) { ovsdb.MutateOperationDelete, "bar", []string{"foo"}, + "bar", }, { "delete multiple string", @@ -317,6 +323,7 @@ func TestMutateDelete(t *testing.T) { ovsdb.MutateOperationDelete, []string{"bar", "baz"}, []string{"foo"}, + []string{"bar", "baz"}, }, { "delete key value pairs", @@ -332,6 +339,9 @@ func TestMutateDelete(t *testing.T) { map[string]string{ "foo": "bar", }, + map[string]string{ + "baz": "quux", + }, }, { "delete keys", @@ -344,14 +354,16 @@ func TestMutateDelete(t *testing.T) { map[string]string{ "baz": "quux", }, + map[string]string{ + "foo": "bar", + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := mutate(tt.current, tt.mutator, tt.value) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mutate() = %v, want %v", got, tt.want) - } + got, diff := mutate(tt.current, tt.mutator, tt.value) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.diff, diff) }) } } diff --git a/server/server.go b/server/server.go index 64576b9d..3450a8ff 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/cenkalti/rpc2" "github.com/cenkalti/rpc2/jsonrpc" "github.com/google/uuid" + "github.com/ovn-org/libovsdb/cache" "github.com/ovn-org/libovsdb/model" "github.com/ovn-org/libovsdb/ovsdb" ) @@ -135,6 +136,22 @@ func (o *OvsdbServer) GetSchema(client *rpc2.Client, args []interface{}, reply * return nil } +type Transaction struct { + ID uuid.UUID + Cache *cache.TableCache +} + +func NewTransaction(schema *ovsdb.DatabaseSchema, model *model.DBModel) Transaction { + cache, err := cache.NewTableCache(schema, model, nil) + if err != nil { + panic(err) + } + return Transaction{ + ID: uuid.New(), + Cache: cache, + } +} + // Transact issues a new database transaction and returns the results func (o *OvsdbServer) Transact(client *rpc2.Client, args []json.RawMessage, reply *[]ovsdb.OperationResult) error { if len(args) < 2 { diff --git a/server/server_integration_test.go b/server/server_integration_test.go index a794e78c..24a0a934 100644 --- a/server/server_integration_test.go +++ b/server/server_integration_test.go @@ -21,12 +21,14 @@ import ( // bridgeType is the simplified ORM model of the Bridge table type bridgeType struct { - UUID string `ovsdb:"_uuid"` - Name string `ovsdb:"name"` - OtherConfig map[string]string `ovsdb:"other_config"` - ExternalIds map[string]string `ovsdb:"external_ids"` - Ports []string `ovsdb:"ports"` - Status map[string]string `ovsdb:"status"` + UUID string `ovsdb:"_uuid"` + Name string `ovsdb:"name"` + DatapathType string `ovsdb:"datapath_type"` + DatapathID *string `ovsdb:"datapath_id"` + OtherConfig map[string]string `ovsdb:"other_config"` + ExternalIds map[string]string `ovsdb:"external_ids"` + Ports []string `ovsdb:"ports"` + Status map[string]string `ovsdb:"status"` } // ovsType is the simplified ORM model of the Bridge table @@ -127,9 +129,12 @@ func TestClientServerInsert(t *testing.T) { _, err = ovs.MonitorAll(context.Background()) require.NoError(t, err) + wallace := "wallace" bridgeRow := &bridgeType{ - Name: "foo", - ExternalIds: map[string]string{"go": "awesome", "docker": "made-for-each-other"}, + Name: "foo", + DatapathType: "bar", + DatapathID: &wallace, + ExternalIds: map[string]string{"go": "awesome", "docker": "made-for-each-other"}, } ops, err := ovs.Create(bridgeRow) @@ -145,6 +150,15 @@ func TestClientServerInsert(t *testing.T) { err := ovs.Get(br) return err == nil }, 2*time.Second, 500*time.Millisecond) + + br := &bridgeType{UUID: uuid} + err = ovs.Get(br) + require.NoError(t, err) + + assert.Equal(t, bridgeRow.Name, br.Name) + assert.Equal(t, bridgeRow.ExternalIds, br.ExternalIds) + assert.Equal(t, bridgeRow.DatapathType, br.DatapathType) + assert.Equal(t, *bridgeRow.DatapathID, wallace) } func TestClientServerMonitor(t *testing.T) { @@ -498,4 +512,10 @@ func TestClientServerInsertAndUpdate(t *testing.T) { } return reflect.DeepEqual(br.ExternalIds, bridgeRow.ExternalIds) }, 2*time.Second, 500*time.Millisecond) + + br := &bridgeType{UUID: uuid} + err = ovs.Get(br) + assert.NoError(t, err) + + assert.Equal(t, bridgeRow, br) } diff --git a/server/testdata/ovslite.json b/server/testdata/ovslite.json index 2c38f55c..d8c402f1 100644 --- a/server/testdata/ovslite.json +++ b/server/testdata/ovslite.json @@ -24,6 +24,17 @@ "type": "string", "mutable": false }, + "datapath_type": { + "type": "string" + }, + "datapath_id": { + "type": { + "key": "string", + "min": 0, + "max": 1 + }, + "ephemeral": true + }, "ports": { "type": { "key": { diff --git a/server/transact.go b/server/transact.go index a3f9575b..27169026 100644 --- a/server/transact.go +++ b/server/transact.go @@ -239,14 +239,21 @@ func (o *OvsdbServer) Update(database, table string, where []ovsdb.Condition, ro if err != nil { panic(err) } + oldValue, err := ovsdb.NativeToOvs(colSchema, old) if err != nil { - panic(err) + oldValue = nil } + native, err := ovsdb.OvsToNative(colSchema, value) if err != nil { panic(err) } + + if oldValue == native { + continue + } + err = info.SetField(column, native) if err != nil { panic(err) @@ -257,7 +264,10 @@ func (o *OvsdbServer) Update(database, table string, where []ovsdb.Condition, ro if err != nil { panic(err) } - rowDelta[column] = diff(oldValue, newValue) + diff := diff(oldValue, newValue) + if diff != nil { + rowDelta[column] = diff + } } newRow, err := m.NewRow(table, new) @@ -314,11 +324,11 @@ func (o *OvsdbServer) Mutate(database, table string, where []ovsdb.Condition, mu } for _, old := range rows { - info, err := mapper.NewInfo(schema, old) + oldInfo, err := mapper.NewInfo(schema, old) if err != nil { panic(err) } - uuid, _ := info.FieldByColumn("_uuid") + uuid, _ := oldInfo.FieldByColumn("_uuid") oldRow, err := m.NewRow(table, old) if err != nil { panic(err) @@ -331,17 +341,19 @@ func (o *OvsdbServer) Mutate(database, table string, where []ovsdb.Condition, mu if err != nil { panic(err) } - info, err = mapper.NewInfo(schema, new) + newInfo, err := mapper.NewInfo(schema, new) if err != nil { panic(err) } - err = info.SetField("_uuid", uuid) + err = newInfo.SetField("_uuid", uuid) if err != nil { panic(err) } rowDelta := ovsdb.NewRow() + mutateCols := make(map[string]struct{}) for _, mutation := range mutations { + mutateCols[mutation.Column] = struct{}{} column := schema.Column(mutation.Column) var nativeValue interface{} // Usually a mutation value is of the same type of the value being mutated @@ -361,38 +373,69 @@ func (o *OvsdbServer) Mutate(database, table string, where []ovsdb.Condition, mu if err := ovsdb.ValidateMutation(column, mutation.Mutator, nativeValue); err != nil { panic(err) } - current, err := info.FieldByColumn(mutation.Column) + current, err := newInfo.FieldByColumn(mutation.Column) if err != nil { panic(err) } - newValue, delta := mutate(current, mutation.Mutator, nativeValue) - if err := info.SetField(mutation.Column, newValue); err != nil { + newValue, _ := mutate(current, mutation.Mutator, nativeValue) + if err := newInfo.SetField(mutation.Column, newValue); err != nil { + panic(err) + } + } + for changed := range mutateCols { + colSchema := schema.Column(changed) + oldValueNative, err := oldInfo.FieldByColumn(changed) + if err != nil { panic(err) } - rowDelta[mutation.Column] = delta - newRow, err := m.NewRow(table, new) + newValueNative, err := newInfo.FieldByColumn(changed) if err != nil { panic(err) } - // check indexes - if err := o.db.CheckIndexes(database, table, new); err != nil { - if indexExists, ok := err.(*cache.ErrIndexExists); ok { - e := ovsdb.ConstraintViolation{} - return ovsdb.OperationResult{ - Error: e.Error(), - Details: newIndexExistsDetails(*indexExists), - }, nil - } + + oldValue, err := ovsdb.NativeToOvs(colSchema, oldValueNative) + if err != nil { + panic(err) + } + + newValue, err := ovsdb.NativeToOvs(colSchema, newValueNative) + if err != nil { + panic(err) + } + + delta := diff(oldValue, newValue) + if delta != nil { + rowDelta[changed] = delta + } + } + + // check indexes + if err := o.db.CheckIndexes(database, table, new); err != nil { + if indexExists, ok := err.(*cache.ErrIndexExists); ok { + e := ovsdb.ConstraintViolation{} return ovsdb.OperationResult{ - Error: err.Error(), + Error: e.Error(), + Details: newIndexExistsDetails(*indexExists), }, nil } - tableUpdate.AddRowUpdate(uuid.(string), &ovsdb.RowUpdate2{ - Modify: &newRow, - }) + return ovsdb.OperationResult{ + Error: err.Error(), + }, nil + } + + newRow, err := m.NewRow(table, new) + if err != nil { + panic(err) } + + tableUpdate.AddRowUpdate(uuid.(string), &ovsdb.RowUpdate2{ + Modify: &rowDelta, + Old: &oldRow, + New: &newRow, + }) } + return ovsdb.OperationResult{ Count: len(rows), }, ovsdb.TableUpdates2{ @@ -493,8 +536,11 @@ func diff(a interface{}, b interface{}) interface{} { c = append(c, replacementElem) } } - cSet, _ := ovsdb.NewOvsSet(c) - return cSet + if len(c) > 0 { + cSet, _ := ovsdb.NewOvsSet(c) + return cSet + } + return nil case ovsdb.OvsMap: originalMap := a.(ovsdb.OvsMap) replacementMap := b.(ovsdb.OvsMap) @@ -520,8 +566,11 @@ func diff(a interface{}, b interface{}) interface{} { c[k] = v } } - cMap, _ := ovsdb.NewOvsMap(c) - return cMap + if len(c) > 0 { + cMap, _ := ovsdb.NewOvsMap(c) + return cMap + } + return nil default: return b } diff --git a/server/transact_test.go b/server/transact_test.go index 50be96c5..926950d7 100644 --- a/server/transact_test.go +++ b/server/transact_test.go @@ -40,7 +40,7 @@ func TestMutateOp(t *testing.T) { Name: "foo", ExternalIds: map[string]string{ "foo": "bar", - "baz": "qux", + "baz": "quux", "waldo": "fred", }, } @@ -70,6 +70,8 @@ func TestMutateOp(t *testing.T) { }, ) assert.Equal(t, ovsdb.OperationResult{Count: 1}, gotResult) + err = o.db.Commit("Open_vSwitch", uuid.New(), gotUpdate) + require.NoError(t, err) bridgeSet, err := ovsdb.NewOvsSet([]ovsdb.UUID{{GoUUID: bridgeUUID}}) assert.Nil(t, err) @@ -77,6 +79,13 @@ func TestMutateOp(t *testing.T) { "Open_vSwitch": ovsdb.TableUpdate2{ ovsUUID: &ovsdb.RowUpdate2{ Modify: &ovsdb.Row{ + "bridges": bridgeSet, + }, + Old: &ovsdb.Row{ + // TODO: _uuid should be filtered + "_uuid": ovsdb.UUID{GoUUID: ovsUUID}, + }, + New: &ovsdb.Row{ // TODO: _uuid should be filtered "_uuid": ovsdb.UUID{GoUUID: ovsUUID}, "bridges": bridgeSet, @@ -87,7 +96,7 @@ func TestMutateOp(t *testing.T) { keyDelete, err := ovsdb.NewOvsSet([]string{"foo"}) assert.Nil(t, err) - keyValueDelete, err := ovsdb.NewOvsMap(map[string]string{"baz": "qux"}) + keyValueDelete, err := ovsdb.NewOvsMap(map[string]string{"baz": "quux"}) assert.Nil(t, err) gotResult, gotUpdate = o.Mutate( "Open_vSwitch", @@ -102,21 +111,18 @@ func TestMutateOp(t *testing.T) { ) assert.Equal(t, ovsdb.OperationResult{Count: 1}, gotResult) - // oldExternalIds, err := ovsdb.NewOvsMap(bridge.ExternalIds) - assert.Nil(t, err) - newExternalIds, err := ovsdb.NewOvsMap(map[string]string{"waldo": "fred"}) + oldExternalIds, _ := ovsdb.NewOvsMap(bridge.ExternalIds) + newExternalIds, _ := ovsdb.NewOvsMap(map[string]string{"waldo": "fred"}) + diffExternalIds, _ := ovsdb.NewOvsMap(map[string]string{"foo": "bar", "baz": "quux"}) + assert.Nil(t, err) - assert.Equal(t, ovsdb.TableUpdates2{ - "Bridge": ovsdb.TableUpdate2{ - bridgeUUID: &ovsdb.RowUpdate2{ - Modify: &ovsdb.Row{ - "_uuid": ovsdb.UUID{GoUUID: bridgeUUID}, - "name": "foo", - "external_ids": newExternalIds, - }, - }, - }, - }, gotUpdate) + + gotModify := *gotUpdate["Bridge"][bridgeUUID].Modify + gotOld := *gotUpdate["Bridge"][bridgeUUID].Old + gotNew := *gotUpdate["Bridge"][bridgeUUID].New + assert.Equal(t, diffExternalIds, gotModify["external_ids"]) + assert.Equal(t, oldExternalIds, gotOld["external_ids"]) + assert.Equal(t, newExternalIds, gotNew["external_ids"]) } func TestDiff(t *testing.T) { @@ -157,6 +163,12 @@ func TestDiff(t *testing.T) { originSetDel, setDelDiff, }, + { + "noop set", + originSet, + originSet, + nil, + }, { "add to map", originMap, @@ -175,6 +187,12 @@ func TestDiff(t *testing.T) { originMapReplace, originMapReplaceDiff, }, + { + "noop map", + originMap, + originMap, + nil, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -183,3 +201,138 @@ func TestDiff(t *testing.T) { }) } } + +func TestOvsdbServerInsert(t *testing.T) { + t.Skip("need a helper for comparing rows as map elements aren't in same order") + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + if err != nil { + t.Fatal(err) + } + schema, err := getSchema() + if err != nil { + t.Fatal(err) + } + ovsDB := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + o, err := NewOvsdbServer(ovsDB, DatabaseModel{ + Model: defDB, Schema: schema}) + require.Nil(t, err) + m := mapper.NewMapper(schema) + + gromit := "gromit" + bridge := bridgeType{ + Name: "foo", + DatapathType: "bar", + DatapathID: &gromit, + ExternalIds: map[string]string{ + "foo": "bar", + "baz": "qux", + "waldo": "fred", + }, + } + bridgeUUID := uuid.NewString() + bridgeRow, err := m.NewRow("Bridge", &bridge) + require.Nil(t, err) + + res, updates := o.Insert("Open_vSwitch", "Bridge", bridgeUUID, bridgeRow) + _, err = ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "insert"}}) + require.NoError(t, err) + + err = ovsDB.Commit("Open_vSwitch", uuid.New(), updates) + assert.NoError(t, err) + + bridge.UUID = bridgeUUID + br, err := o.db.Get("Open_vSwitch", "Bridge", bridgeUUID) + assert.NoError(t, err) + assert.Equal(t, &bridge, br) + assert.Equal(t, ovsdb.TableUpdates2{ + "Bridge": { + bridgeUUID: &ovsdb.RowUpdate2{ + Insert: &bridgeRow, + New: &bridgeRow, + }, + }, + }, updates) +} + +func TestOvsdbServerUpdate(t *testing.T) { + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + if err != nil { + t.Fatal(err) + } + schema, err := getSchema() + if err != nil { + t.Fatal(err) + } + ovsDB := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + o, err := NewOvsdbServer(ovsDB, DatabaseModel{ + Model: defDB, Schema: schema}) + require.Nil(t, err) + m := mapper.NewMapper(schema) + + bridge := bridgeType{ + Name: "foo", + ExternalIds: map[string]string{ + "foo": "bar", + "baz": "qux", + "waldo": "fred", + }, + } + bridgeUUID := uuid.NewString() + bridgeRow, err := m.NewRow("Bridge", &bridge) + require.Nil(t, err) + + res, updates := o.Insert("Open_vSwitch", "Bridge", bridgeUUID, bridgeRow) + _, err = ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "insert"}}) + require.NoError(t, err) + + err = ovsDB.Commit("Open_vSwitch", uuid.New(), updates) + assert.NoError(t, err) + + halloween, _ := ovsdb.NewOvsSet([]string{"halloween"}) + tests := []struct { + name string + row ovsdb.Row + expected *ovsdb.RowUpdate2 + }{ + { + "update single field", + ovsdb.Row{"datapath_type": "waldo"}, + &ovsdb.RowUpdate2{ + Modify: &ovsdb.Row{ + "datapath_type": "waldo", + }, + }, + }, + { + "update single optional field", + ovsdb.Row{"datapath_id": "halloween"}, + &ovsdb.RowUpdate2{ + Modify: &ovsdb.Row{ + "datapath_id": halloween, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res, updates := o.Update( + "Open_vSwitch", "Bridge", + []ovsdb.Condition{{ + Column: "_uuid", Function: ovsdb.ConditionEqual, Value: ovsdb.UUID{GoUUID: bridgeUUID}, + }}, tt.row) + errs, err := ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "update"}}) + require.NoErrorf(t, err, "%+v", errs) + + bridge.UUID = bridgeUUID + row, err := o.db.Get("Open_vSwitch", "Bridge", bridgeUUID) + assert.NoError(t, err) + br := row.(*bridgeType) + assert.NotEqual(t, br, bridgeRow) + assert.Equal(t, tt.expected.Modify, updates["Bridge"][bridgeUUID].Modify) + }) + } +} diff --git a/test/ovs/ovs_integration_test.go b/test/ovs/ovs_integration_test.go index 645a9792..06796de0 100644 --- a/test/ovs/ovs_integration_test.go +++ b/test/ovs/ovs_integration_test.go @@ -133,10 +133,26 @@ type bridgeType struct { IPFIX *string `ovsdb:"ipfix"` } -// ovsType is the simplified ORM model of the Bridge table +// ovsType is the ORM model of the OVS table type ovsType struct { - UUID string `ovsdb:"_uuid"` - Bridges []string `ovsdb:"bridges"` + UUID string `ovsdb:"_uuid"` + Bridges []string `ovsdb:"bridges"` + CurCfg int `ovsdb:"cur_cfg"` + DatapathTypes []string `ovsdb:"datapath_types"` + Datapaths map[string]string `ovsdb:"datapaths"` + DbVersion *string `ovsdb:"db_version"` + DpdkInitialized bool `ovsdb:"dpdk_initialized"` + DpdkVersion *string `ovsdb:"dpdk_version"` + ExternalIDs map[string]string `ovsdb:"external_ids"` + IfaceTypes []string `ovsdb:"iface_types"` + ManagerOptions []string `ovsdb:"manager_options"` + NextCfg int `ovsdb:"next_cfg"` + OtherConfig map[string]string `ovsdb:"other_config"` + OVSVersion *string `ovsdb:"ovs_version"` + SSL *string `ovsdb:"ssl"` + Statistics map[string]string `ovsdb:"statistics"` + SystemType *string `ovsdb:"system_type"` + SystemVersion *string `ovsdb:"system_version"` } // ipfixType is a simplified ORM model for the IPFIX table