diff --git a/gnmi_server/client_subscribe.go b/gnmi_server/client_subscribe.go index 5d27177a..090c4405 100644 --- a/gnmi_server/client_subscribe.go +++ b/gnmi_server/client_subscribe.go @@ -5,6 +5,7 @@ import ( "io" "net" "sync" + "strings" "github.com/Workiva/go-datastructures/queue" log "github.com/golang/glog" @@ -207,6 +208,10 @@ func (c *Client) Run(stream gnmipb.GNMI_SubscribeServer) (err error) { c.Close() // Wait until all child go routines exited c.w.Wait() + if strings.Contains(err.Error(), "i/o timeout") { + return grpc.Errorf(codes.Internal, "%s", err) + } + return grpc.Errorf(codes.InvalidArgument, "%s", err) } diff --git a/gnmi_server/server_test.go b/gnmi_server/server_test.go index 76b0ddbe..d04e7991 100644 --- a/gnmi_server/server_test.go +++ b/gnmi_server/server_test.go @@ -3255,6 +3255,72 @@ func TestConnectionsKeepAlive(t *testing.T) { } } +func TestConnectionRedisFailure(t *testing.T) { + s := createServer(t, 8081) + go runServer(t, s) + defer s.s.Stop() + + test := struct { + desc string + q client.Query + want []client.Notification + poll int + }{ + desc: "poll query for COUNTERS/Ethernet*", + poll: 10, + q: client.Query{ + Target: "COUNTERS_DB", + Type: client.Poll, + Queries: []client.Path{{"COUNTERS", "Ethernet*"}}, + TLS: &tls.Config{InsecureSkipVerify: true}, + }, + want: []client.Notification{ + client.Connected{}, + client.Sync{}, + }, + } + namespace := sdcfg.GetDbDefaultNamespace() + rclient := getRedisClientN(t, 6, namespace) + defer rclient.Close() + + prepareStateDb(t, namespace) + t.Run(test.desc, func(t *testing.T) { + q := test.q + q.Addrs = []string{"127.0.0.1:8081"} + c := client.New() + + wg := new(sync.WaitGroup) + wg.Add(1) + + sdc.MockFail = 1 + go func() { + defer wg.Done() + if err := c.Subscribe(context.Background(), q); err != nil { + t.Errorf("c.Subscribe(): got error %v, expected nil", err) + } + }() + + wg.Wait() + sdc.MockFail = 0 + resultMap, err := rclient.HGetAll("TELEMETRY_CONNECTIONS").Result() + + if resultMap == nil { + t.Errorf("result Map is nil, expected non nil, err: %v", err) + } + if len(resultMap) != 1 { + t.Errorf("result for TELEMETRY_CONNECTIONS should be 1") + } + + for key, _ := range resultMap { + if !strings.Contains(key, "COUNTERS_DB|COUNTERS|Ethernet*") { + t.Errorf("key is expected to contain correct query, received: %s", key) + } + } + + c.Close() + }) +} + func TestClient(t *testing.T) { var mutexDeInit sync.RWMutex var mutexHB sync.RWMutex diff --git a/sonic_data_client/db_client.go b/sonic_data_client/db_client.go index 09a52d45..0b70232e 100644 --- a/sonic_data_client/db_client.go +++ b/sonic_data_client/db_client.go @@ -80,6 +80,7 @@ var IntervalTicker = func(interval time.Duration) <-chan time.Time { } var NeedMock bool = false +var MockFail int = 0 var intervalTickerMutex sync.Mutex // Define a new function to set the IntervalTicker variable @@ -744,6 +745,12 @@ func tableData2Msi(tblPath *tablePath, useKey bool, op *string, msi *map[string] return nil } + if MockFail == 1 { + MockFail++ + fmt.Printf("mock sleep for redis timeout\n") + time.Sleep(30 * time.Second) + } + for idx, dbkey := range dbkeys { fv, err = redisDb.HGetAll(dbkey).Result() if err != nil {