diff --git a/client/client.go b/client/client.go index 8c9a1051..53f368fc 100644 --- a/client/client.go +++ b/client/client.go @@ -87,6 +87,7 @@ type ovsdbClient struct { metrics metrics connected bool rpcClient *rpc2.Client + conn net.Conn rpcMutex sync.RWMutex // endpoints contains all possible endpoints; the first element is // the active endpoint if connected=true @@ -427,6 +428,7 @@ func (o *ovsdbClient) createRPC2Client(conn net.Conn) { if o.options.inactivityTimeout > 0 { o.trafficSeen = make(chan struct{}) } + o.conn = conn o.rpcClient = rpc2.NewClientWithCodec(jsonrpc.NewJSONCodec(conn)) o.rpcClient.SetBlocking(true) o.rpcClient.Handle("echo", func(_ *rpc2.Client, args []interface{}, reply *[]interface{}) error { @@ -748,7 +750,7 @@ func (o *ovsdbClient) update3(params []json.RawMessage, reply *[]interface{}) er func (o *ovsdbClient) getSchema(ctx context.Context, dbName string) (ovsdb.DatabaseSchema, error) { args := ovsdb.NewGetSchemaArgs(dbName) var reply ovsdb.DatabaseSchema - err := o.rpcClient.CallWithContext(ctx, "get_schema", args, &reply) + err := o.CallWithContext(ctx, "get_schema", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return ovsdb.DatabaseSchema{}, ErrNotConnected @@ -763,7 +765,7 @@ func (o *ovsdbClient) getSchema(ctx context.Context, dbName string) (ovsdb.Datab // Should only be called when mutex is held func (o *ovsdbClient) listDbs(ctx context.Context) ([]string, error) { var dbs []string - err := o.rpcClient.CallWithContext(ctx, "list_dbs", nil, &dbs) + err := o.CallWithContext(ctx, "list_dbs", nil, &dbs) if err != nil { if err == rpc2.ErrShutdown { return nil, ErrNotConnected @@ -836,7 +838,7 @@ func (o *ovsdbClient) transact(ctx context.Context, dbName string, skipChWrite b if dbgLogger.Enabled() { dbgLogger.Info("transacting operations", "operations", fmt.Sprintf("%+v", operation)) } - err := o.rpcClient.CallWithContext(ctx, "transact", args, &reply) + err := o.CallWithContext(ctx, "transact", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return nil, ErrNotConnected @@ -869,7 +871,7 @@ func (o *ovsdbClient) MonitorCancel(ctx context.Context, cookie MonitorCookie) e if o.rpcClient == nil { return ErrNotConnected } - err := o.rpcClient.CallWithContext(ctx, "monitor_cancel", args, &reply) + err := o.CallWithContext(ctx, "monitor_cancel", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return ErrNotConnected @@ -981,15 +983,15 @@ func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconne switch monitor.Method { case ovsdb.MonitorRPC: var reply ovsdb.TableUpdates - err = o.rpcClient.CallWithContext(ctx, monitor.Method, args, &reply) + err = o.CallWithContext(ctx, monitor.Method, args, &reply) tableUpdates = reply case ovsdb.ConditionalMonitorRPC: var reply ovsdb.TableUpdates2 - err = o.rpcClient.CallWithContext(ctx, monitor.Method, args, &reply) + err = o.CallWithContext(ctx, monitor.Method, args, &reply) tableUpdates = reply case ovsdb.ConditionalMonitorSinceRPC: var reply ovsdb.MonitorCondSinceReply - err = o.rpcClient.CallWithContext(ctx, monitor.Method, args, &reply) + err = o.CallWithContext(ctx, monitor.Method, args, &reply) if err == nil && reply.Found { monitor.LastTransactionID = reply.LastTransactionID lastTransactionFound = true @@ -1080,7 +1082,7 @@ func (o *ovsdbClient) Echo(ctx context.Context) error { if o.rpcClient == nil { return ErrNotConnected } - err := o.rpcClient.CallWithContext(ctx, "echo", args, &reply) + err := o.CallWithContext(ctx, "echo", args, &reply) if err != nil { if err == rpc2.ErrShutdown { return ErrNotConnected @@ -1439,3 +1441,19 @@ func (o *ovsdbClient) WhereAll(m model.Model, conditions ...model.Condition) Con func (o *ovsdbClient) WhereCache(predicate interface{}) ConditionalAPI { return o.primaryDB().api.WhereCache(predicate) } + +// CallWithContext invokes the named function, waits for it to complete, and +// returns its error status, or an error from Context timeout. +func (o *ovsdbClient) CallWithContext(ctx context.Context, method string, args interface{}, reply interface{}) error { + // Set up read/write deadline for tcp connection before making + // a rpc request to the server. + if tcpConn, ok := o.conn.(*net.TCPConn); ok { + if o.options.timeout > 0 { + err := tcpConn.SetDeadline(time.Now().Add(o.options.timeout * 3)) + if err != nil { + return err + } + } + } + return o.rpcClient.CallWithContext(ctx, method, args, reply) +}