diff --git a/client/client.go b/client/client.go index b71da2a8..10ea757e 100644 --- a/client/client.go +++ b/client/client.go @@ -773,14 +773,23 @@ func (o *ovsdbClient) listDbs(ctx context.Context) ([]string, error) { return dbs, err } +// logFromContext returns a Logger from ctx or return the default logger +func (o *ovsdbClient) logFromContext(ctx context.Context) *logr.Logger { + if logger, err := logr.FromContext(ctx); err == nil { + return &logger + } + return o.logger +} + // Transact performs the provided Operations on the database // RFC 7047 : transact func (o *ovsdbClient) Transact(ctx context.Context, operation ...ovsdb.Operation) ([]ovsdb.OperationResult, error) { + logger := o.logFromContext(ctx) o.rpcMutex.RLock() if o.rpcClient == nil || !o.connected { o.rpcMutex.RUnlock() if o.options.reconnect { - o.logger.V(5).Info("blocking transaction until reconnected", "operations", + logger.V(5).Info("blocking transaction until reconnected", "operations", fmt.Sprintf("%+v", operation)) ticker := time.NewTicker(50 * time.Millisecond) defer ticker.Stop() @@ -806,6 +815,7 @@ func (o *ovsdbClient) Transact(ctx context.Context, operation ...ovsdb.Operation } func (o *ovsdbClient) transact(ctx context.Context, dbName string, skipChWrite bool, operation ...ovsdb.Operation) ([]ovsdb.OperationResult, error) { + logger := o.logFromContext(ctx) var reply []ovsdb.OperationResult db := o.databases[dbName] db.modelMutex.RLock() @@ -822,7 +832,7 @@ func (o *ovsdbClient) transact(ctx context.Context, dbName string, skipChWrite b if o.rpcClient == nil { return nil, ErrNotConnected } - dbgLogger := o.logger.WithValues("database", dbName).V(4) + dbgLogger := logger.WithValues("database", dbName).V(4) if dbgLogger.Enabled() { dbgLogger.Info("transacting operations", "operations", fmt.Sprintf("%+v", operation)) } diff --git a/client/client_test.go b/client/client_test.go index c8a8148e..1f66651a 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,9 +1,11 @@ package client import ( + "bytes" "context" "encoding/json" "fmt" + "log" "math/rand" "os" "reflect" @@ -14,6 +16,8 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/cenkalti/rpc2" + "github.com/go-logr/logr" + "github.com/go-logr/stdr" "github.com/google/uuid" "github.com/ovn-org/libovsdb/cache" db "github.com/ovn-org/libovsdb/database" @@ -22,6 +26,7 @@ import ( "github.com/ovn-org/libovsdb/ovsdb" "github.com/ovn-org/libovsdb/ovsdb/serverdb" "github.com/ovn-org/libovsdb/server" + "github.com/ovn-org/libovsdb/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -800,6 +805,85 @@ func TestOperationWhenNeverConnected(t *testing.T) { } } +func TestTransactionLogger(t *testing.T) { + stdr.SetVerbosity(5) + + var defSchema ovsdb.DatabaseSchema + err := json.Unmarshal([]byte(schema), &defSchema) + require.NoError(t, err) + _, sock := newOVSDBServer(t, defDB, defSchema) + // Create client for this server's Server database + endpoint := fmt.Sprintf("unix:%s", sock) + + var defaultBuf bytes.Buffer + defaultL := stdr.New(log.New(&defaultBuf, "", log.LstdFlags)).WithName("default") + + // Create client to test transaction logger + ovs, err := newOVSDBClient(defDB, + WithEndpoint(endpoint), + WithLogger(&defaultL)) + require.NoError(t, err) + + err = ovs.Connect(context.Background()) + require.NoError(t, err) + t.Cleanup(ovs.Close) + + var s ovsdb.DatabaseSchema + err = json.Unmarshal([]byte(schema), &s) + require.NoError(t, err) + + dbModel, err := test.GetModel() + require.NoError(t, err) + m := mapper.NewMapper(dbModel.Schema) + + bridge1 := test.BridgeType{ + Name: "foo", + ExternalIds: map[string]string{ + "foo": "bar", + "baz": "quux", + "waldo": "fred", + }, + } + bridgeInfo1, err := dbModel.NewModelInfo(&bridge1) + require.NoError(t, err) + bridgeRow1, err := m.NewRow(bridgeInfo1) + require.Nil(t, err) + bridgeUUID1 := uuid.NewString() + operation1 := ovsdb.Operation{ + Op: ovsdb.OperationInsert, + Table: "Bridge", + UUID: bridgeUUID1, + Row: bridgeRow1, + } + _, _ = ovs.Transact(context.TODO(), operation1) + assert.Contains(t, defaultBuf.String(), "default") + + bridge2 := test.BridgeType{ + Name: "bar", + ExternalIds: map[string]string{ + "foo": "bar", + "baz": "quux", + "waldo": "fred", + }, + } + bridgeInfo2, err := dbModel.NewModelInfo(&bridge2) + require.NoError(t, err) + bridgeRow2, err := m.NewRow(bridgeInfo2) + require.Nil(t, err) + bridgeUUID2 := uuid.NewString() + operation2 := ovsdb.Operation{ + Op: ovsdb.OperationInsert, + Table: "Bridge", + UUID: bridgeUUID2, + Row: bridgeRow2, + } + var customBuf bytes.Buffer + customL := stdr.New(log.New(&customBuf, "", log.LstdFlags)).WithName("custom") + ctx := logr.NewContext(context.TODO(), customL) + _, _ = ovs.Transact(ctx, operation2) + assert.Contains(t, customBuf.String(), "custom") +} + func TestOperationWhenNotConnected(t *testing.T) { ovs, err := newOVSDBClient(defDB) require.NoError(t, err)