From 752f1fceffa2d8fa7138638dd49c2d0aec96bc7b Mon Sep 17 00:00:00 2001 From: ganglv <88995770+ganglyu@users.noreply.github.com> Date: Tue, 19 Dec 2023 09:37:06 +0800 Subject: [PATCH] Use swsscommon API to read database configuration (#176) Why I did it sonic_db_config/db_config.go reads from database configuration json to get database configuration, but swsscommon has provided the same feature. How I did it Use swsscommon API to read database configuration. Pending on below PR to clear database configuration for unit test. sonic-net/sonic-swss-common#843 How to verify it Run unit test. --- dialout/dialout_client/dialout_client.go | 29 +- dialout/dialout_client/dialout_client_test.go | 30 +- gnmi_server/connection_manager.go | 16 +- gnmi_server/server_test.go | 63 ++-- gnmi_server/transl_sub_test.go | 8 +- sonic_data_client/db_client.go | 61 +++- sonic_data_client/events_client.go | 16 +- sonic_data_client/mixed_db_client.go | 7 +- sonic_data_client/virtual_db.go | 18 +- sonic_db_config/db_config.go | 299 +++++++++--------- sonic_db_config/db_config_test.go | 148 +++++++-- 11 files changed, 446 insertions(+), 249 deletions(-) diff --git a/dialout/dialout_client/dialout_client.go b/dialout/dialout_client/dialout_client.go index a5eb625b..8c3678e5 100644 --- a/dialout/dialout_client/dialout_client.go +++ b/dialout/dialout_client/dialout_client.go @@ -460,7 +460,11 @@ func setupDestGroupClients(ctx context.Context, destGroupName string) { // start/stop/update telemetry publist client as requested // TODO: more validation on db data func processTelemetryClientConfig(ctx context.Context, redisDb *redis.Client, key string, op string) error { - separator, _ := sdc.GetTableKeySeparator("CONFIG_DB", sdcfg.GetDbDefaultNamespace()) + ns, _ := sdcfg.GetDbDefaultNamespace() + separator, err := sdc.GetTableKeySeparator("CONFIG_DB", ns) + if err != nil { + return err + } tableKey := "TELEMETRY_CLIENT" + separator + key fv, err := redisDb.HGetAll(tableKey).Result() if err != nil { @@ -640,28 +644,43 @@ func processTelemetryClientConfig(ctx context.Context, redisDb *redis.Client, ke // read configDB data for telemetry client and start publishing service for client subscription func DialOutRun(ctx context.Context, ccfg *ClientConfig) error { clientCfg = ccfg - dbn := sdcfg.GetDbId("CONFIG_DB", sdcfg.GetDbDefaultNamespace()) + ns, _ := sdcfg.GetDbDefaultNamespace() + dbn, err := sdcfg.GetDbId("CONFIG_DB", ns) + if err != nil { + return err + } var redisDb *redis.Client if sdc.UseRedisLocalTcpPort == false { + addr, err := sdcfg.GetDbSock("CONFIG_DB", ns) + if err != nil { + return err + } redisDb = redis.NewClient(&redis.Options{ Network: "unix", - Addr: sdcfg.GetDbSock("CONFIG_DB", sdcfg.GetDbDefaultNamespace()), + Addr: addr, Password: "", // no password set DB: dbn, DialTimeout: 0, }) } else { + addr, err := sdcfg.GetDbTcpAddr("CONFIG_DB", ns) + if err != nil { + return err + } redisDb = redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("CONFIG_DB", sdcfg.GetDbDefaultNamespace()), + Addr: addr, Password: "", // no password set DB: dbn, DialTimeout: 0, }) } - separator, _ := sdc.GetTableKeySeparator("CONFIG_DB", sdcfg.GetDbDefaultNamespace()) + separator, err := sdc.GetTableKeySeparator("CONFIG_DB", ns) + if err != nil { + return err + } pattern := "__keyspace@" + strconv.Itoa(int(dbn)) + "__:TELEMETRY_CLIENT" + separator prefixLen := len(pattern) pattern += "*" diff --git a/dialout/dialout_client/dialout_client_test.go b/dialout/dialout_client/dialout_client_test.go index 5d0bb3be..db37aaa4 100644 --- a/dialout/dialout_client/dialout_client_test.go +++ b/dialout/dialout_client/dialout_client_test.go @@ -95,14 +95,23 @@ func runServer(t *testing.T, s *sds.Server) { } func getRedisClient(t *testing.T) *redis.Client { + ns, _ := sdcfg.GetDbDefaultNamespace() + addr, err := sdcfg.GetDbTcpAddr("COUNTERS_DB", ns) + if err != nil { + t.Fatal("failed to get addr ", err) + } + db, err := sdcfg.GetDbId("COUNTERS_DB", ns) + if err != nil { + t.Fatal("failed to get db ", err) + } rclient := redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("COUNTERS_DB", sdcfg.GetDbDefaultNamespace()), + Addr: addr, Password: "", // no password set - DB: sdcfg.GetDbId("COUNTERS_DB", sdcfg.GetDbDefaultNamespace()), + DB: db, DialTimeout: 0, }) - _, err := rclient.Ping().Result() + _, err = rclient.Ping().Result() if err != nil { t.Fatal("failed to connect to redis server ", err) } @@ -124,14 +133,23 @@ func exe_cmd(t *testing.T, cmd string) { } func getConfigDbClient(t *testing.T) *redis.Client { + ns, _ := sdcfg.GetDbDefaultNamespace() + addr, err := sdcfg.GetDbTcpAddr("CONFIG_DB", ns) + if err != nil { + t.Fatal("failed to get addr ", err) + } + db, err := sdcfg.GetDbId("CONFIG_DB", ns) + if err != nil { + t.Fatal("failed to get db ", err) + } rclient := redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("CONFIG_DB", sdcfg.GetDbDefaultNamespace()), + Addr: addr, Password: "", // no password set - DB: sdcfg.GetDbId("CONFIG_DB", sdcfg.GetDbDefaultNamespace()), + DB: db, DialTimeout: 0, }) - _, err := rclient.Ping().Result() + _, err = rclient.Ping().Result() if err != nil { t.Fatalf("failed to connect to redis server %v", err) } diff --git a/gnmi_server/connection_manager.go b/gnmi_server/connection_manager.go index b939c733..cc0b2cde 100644 --- a/gnmi_server/connection_manager.go +++ b/gnmi_server/connection_manager.go @@ -26,12 +26,22 @@ func (cm *ConnectionManager) GetThreshold() int { } func (cm *ConnectionManager) PrepareRedis() { - ns := sdcfg.GetDbDefaultNamespace() + ns, _ := sdcfg.GetDbDefaultNamespace() + addr, err := sdcfg.GetDbTcpAddr("STATE_DB", ns) + if err != nil { + log.Errorf("Addr err: %v", err) + return + } + db, err := sdcfg.GetDbId("STATE_DB", ns) + if err != nil { + log.Errorf("DB err: %v", err) + return + } rclient = redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("STATE_DB", ns), + Addr: addr, Password: "", - DB: sdcfg.GetDbId("STATE_DB", ns), + DB: db, DialTimeout: 0, }) diff --git a/gnmi_server/server_test.go b/gnmi_server/server_test.go index 9df8bc72..6af1d58e 100644 --- a/gnmi_server/server_test.go +++ b/gnmi_server/server_test.go @@ -402,14 +402,18 @@ func runServer(t *testing.T, s *Server) { } func getRedisClientN(t *testing.T, n int, namespace string) *redis.Client { + addr, err := sdcfg.GetDbTcpAddr("COUNTERS_DB", namespace) + if err != nil { + t.Fatalf("failed to get addr %v", err) + } rclient := redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("COUNTERS_DB", namespace), + Addr: addr, Password: "", // no password set DB: n, DialTimeout: 0, }) - _, err := rclient.Ping().Result() + _, err = rclient.Ping().Result() if err != nil { t.Fatalf("failed to connect to redis server %v", err) } @@ -417,15 +421,22 @@ func getRedisClientN(t *testing.T, n int, namespace string) *redis.Client { } func getRedisClient(t *testing.T, namespace string) *redis.Client { - + addr, err := sdcfg.GetDbTcpAddr("COUNTERS_DB", namespace) + if err != nil { + t.Fatalf("failed to get addr %v", err) + } + db, err := sdcfg.GetDbId("COUNTERS_DB", namespace) + if err != nil { + t.Fatalf("failed to get db %v", err) + } rclient := redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("COUNTERS_DB", namespace), + Addr: addr, Password: "", // no password set - DB: sdcfg.GetDbId("COUNTERS_DB", namespace), + DB: db, DialTimeout: 0, }) - _, err := rclient.Ping().Result() + _, err = rclient.Ping().Result() if err != nil { t.Fatalf("failed to connect to redis server %v", err) } @@ -433,15 +444,22 @@ func getRedisClient(t *testing.T, namespace string) *redis.Client { } func getConfigDbClient(t *testing.T, namespace string) *redis.Client { - + addr, err := sdcfg.GetDbTcpAddr("CONFIG_DB", namespace) + if err != nil { + t.Fatalf("failed to get addr %v", err) + } + db, err := sdcfg.GetDbId("CONFIG_DB", namespace) + if err != nil { + t.Fatalf("failed to get db %v", err) + } rclient := redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("CONFIG_DB", namespace), + Addr: addr, Password: "", // no password set - DB: sdcfg.GetDbId("CONFIG_DB", namespace), + DB: db, DialTimeout: 0, }) - _, err := rclient.Ping().Result() + _, err = rclient.Ping().Result() if err != nil { t.Fatalf("failed to connect to redis server %v", err) } @@ -678,7 +696,8 @@ func prepareDb(t *testing.T, namespace string) { } func prepareDbTranslib(t *testing.T) { - rclient := getRedisClient(t, sdcfg.GetDbDefaultNamespace()) + ns, _ := sdcfg.GetDbDefaultNamespace() + rclient := getRedisClient(t, ns) rclient.FlushDB() rclient.Close() @@ -698,7 +717,7 @@ func prepareDbTranslib(t *testing.T) { t.Fatalf("read file %v err: %v", fileName, err) } for n, v := range rj { - rclient := getRedisClientN(t, n, sdcfg.GetDbDefaultNamespace()) + rclient := getRedisClientN(t, n, ns) loadDBNotStrict(t, rclient, v) rclient.Close() } @@ -1208,7 +1227,8 @@ func runGnmiTestGet(t *testing.T, namespace string) { stateDBPath := "STATE_DB" - if namespace != sdcfg.GetDbDefaultNamespace() { + ns, _ := sdcfg.GetDbDefaultNamespace() + if namespace != ns { stateDBPath = "STATE_DB" + "/" + namespace } @@ -1444,9 +1464,10 @@ func TestGnmiGet(t *testing.T) { s := createServer(t, 8081) go runServer(t, s) - prepareDb(t, sdcfg.GetDbDefaultNamespace()) + ns, _ := sdcfg.GetDbDefaultNamespace() + prepareDb(t, ns) - runGnmiTestGet(t, sdcfg.GetDbDefaultNamespace()) + runGnmiTestGet(t, ns) s.s.Stop() } @@ -2715,7 +2736,8 @@ func TestGnmiSubscribe(t *testing.T) { s := createServer(t, 8081) go runServer(t, s) - runTestSubscribe(t, sdcfg.GetDbDefaultNamespace()) + ns, _ := sdcfg.GetDbDefaultNamespace() + runTestSubscribe(t, ns) s.s.Stop() } @@ -3125,7 +3147,7 @@ func TestTableKeyOnDeletion(t *testing.T) { var neighStateTableDeletedJson61 interface{} json.Unmarshal(neighStateTableDeletedByte61, &neighStateTableDeletedJson61) - namespace := sdcfg.GetDbDefaultNamespace() + namespace, _ := sdcfg.GetDbDefaultNamespace() rclient := getRedisClientN(t, 6, namespace) defer rclient.Close() prepareStateDb(t, namespace) @@ -3411,7 +3433,7 @@ func TestConnectionDataSet(t *testing.T) { }, }, } - namespace := sdcfg.GetDbDefaultNamespace() + namespace, _ := sdcfg.GetDbDefaultNamespace() rclient := getRedisClientN(t, 6, namespace) defer rclient.Close() @@ -3828,8 +3850,9 @@ print('%s') s := createServer(t, 8080) go runServer(t, s) defer s.s.Stop() - initFullConfigDb(t, sdcfg.GetDbDefaultNamespace()) - initFullCountersDb(t, sdcfg.GetDbDefaultNamespace()) + ns, _ := sdcfg.GetDbDefaultNamespace() + initFullConfigDb(t, ns) + initFullCountersDb(t, ns) path, _ := os.Getwd() path = filepath.Dir(path) diff --git a/gnmi_server/transl_sub_test.go b/gnmi_server/transl_sub_test.go index 3f90a0b1..3f0f59d8 100644 --- a/gnmi_server/transl_sub_test.go +++ b/gnmi_server/transl_sub_test.go @@ -890,9 +890,13 @@ type DbDataMap map[string]map[string]map[string]interface{} func updateDb(t *testing.T, data DbDataMap) { t.Helper() + ns, _ := dbconfig.GetDbDefaultNamespace() for dbName, tableData := range data { - n := dbconfig.GetDbId(dbName, dbconfig.GetDbDefaultNamespace()) - redis := getRedisClientN(t, n, dbconfig.GetDbDefaultNamespace()) + n, err := dbconfig.GetDbId(dbName, ns) + if err != nil { + t.Fatalf("GetDbId failed: %v", err) + } + redis := getRedisClientN(t, n, ns) defer redis.Close() for key, fields := range tableData { if fields == nil { diff --git a/sonic_data_client/db_client.go b/sonic_data_client/db_client.go index 8df9dc0c..ddca58b8 100644 --- a/sonic_data_client/db_client.go +++ b/sonic_data_client/db_client.go @@ -432,22 +432,29 @@ func GetTableKeySeparator(target string, ns string) (string, error) { return "", fmt.Errorf("%v not a valid path target", target) } - var separator string = sdcfg.GetDbSeparator(target, ns) - return separator, nil + separator, err := sdcfg.GetDbSeparator(target, ns) + return separator, err } -func GetRedisClientsForDb(target string) map[string]*redis.Client { - redis_client_map := make(map[string]*redis.Client) - if sdcfg.CheckDbMultiNamespace() { - ns_list := sdcfg.GetDbNonDefaultNamespaces() +func GetRedisClientsForDb(target string) (redis_client_map map[string]*redis.Client, err error) { + redis_client_map = make(map[string]*redis.Client) + ok, err := sdcfg.CheckDbMultiNamespace() + if err != nil { + return redis_client_map, err + } + if ok { + ns_list, err := sdcfg.GetDbNonDefaultNamespaces() + if err != nil { + return redis_client_map, err + } for _, ns := range ns_list { redis_client_map[ns] = Target2RedisDb[ns][target] } } else { - ns := sdcfg.GetDbDefaultNamespace() + ns, _ := sdcfg.GetDbDefaultNamespace() redis_client_map[ns] = Target2RedisDb[ns][target] } - return redis_client_map + return redis_client_map, nil } // This function get target present in GNMI Request and @@ -457,7 +464,7 @@ func IsTargetDb(target string) (string, bool, string, bool) { targetname := strings.Split(target, "/") dbName := targetname[0] dbNameSpaceExist := false - dbNamespace := sdcfg.GetDbDefaultNamespace() + dbNamespace, _ := sdcfg.GetDbDefaultNamespace() if len(targetname) > 2 { log.V(1).Infof("target format is not correct") @@ -478,18 +485,26 @@ func IsTargetDb(target string) (string, bool, string, bool) { } // For testing only -func useRedisTcpClient() { +func useRedisTcpClient() error { if !UseRedisLocalTcpPort { - return + return nil + } + AllNamespaces, err := sdcfg.GetDbAllNamespaces() + if err != nil { + return err } - for _, dbNamespace := range sdcfg.GetDbAllNamespaces() { + for _, dbNamespace := range AllNamespaces { Target2RedisDb[dbNamespace] = make(map[string]*redis.Client) for dbName, dbn := range spb.Target_value { if dbName != "OTHERS" { + addr, err := sdcfg.GetDbTcpAddr(dbName, dbNamespace) + if err != nil { + return err + } // DB connector for direct redis operation redisDb := redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr(dbName, dbNamespace), + Addr: addr, Password: "", // no password set DB: int(dbn), DialTimeout: 0, @@ -498,18 +513,29 @@ func useRedisTcpClient() { } } } + return nil } // Client package prepare redis clients to all DBs automatically func init() { - for _, dbNamespace := range sdcfg.GetDbAllNamespaces() { + AllNamespaces, err := sdcfg.GetDbAllNamespaces() + if err != nil { + log.Errorf("init error: %v", err) + return + } + for _, dbNamespace := range AllNamespaces { Target2RedisDb[dbNamespace] = make(map[string]*redis.Client) for dbName, dbn := range spb.Target_value { if dbName != "OTHERS" { + addr, err := sdcfg.GetDbSock(dbName, dbNamespace) + if err != nil { + log.Errorf("init error: %v", err) + return + } // DB connector for direct redis operation redisDb := redis.NewClient(&redis.Options{ Network: "unix", - Addr: sdcfg.GetDbSock(dbName, dbNamespace), + Addr: addr, Password: "", // no password set DB: int(dbn), DialTimeout: 0, @@ -557,7 +583,10 @@ func populateDbtablePath(prefix, path *gnmipb.Path, pathG2S *map[*gnmipb.Path][] } // Verify Namespace is valid - dbNamespace, ok := sdcfg.GetDbNamespaceFromTarget(targetDbNameSpace) + dbNamespace, ok, err := sdcfg.GetDbNamespaceFromTarget(targetDbNameSpace) + if err != nil { + return fmt.Errorf("Failed to get namespace %v", err) + } if !ok { return fmt.Errorf("Invalid target dbNameSpace %v", targetDbNameSpace) } diff --git a/sonic_data_client/events_client.go b/sonic_data_client/events_client.go index 224ccb4d..e8fa2648 100644 --- a/sonic_data_client/events_client.go +++ b/sonic_data_client/events_client.go @@ -226,13 +226,23 @@ func update_stats(evtc *EventClient) { /* Populate counters from DB for cumulative counters. */ if !evtc.isStopped() { - ns := sdcfg.GetDbDefaultNamespace() + ns, _ := sdcfg.GetDbDefaultNamespace() + addr, err := sdcfg.GetDbTcpAddr("COUNTERS_DB", ns) + if err != nil { + log.Errorf("Address error: %v", err) + return + } + dbId, err := sdcfg.GetDbId("COUNTERS_DB", ns) + if err != nil { + log.Errorf("DB error: %v", err) + return + } rclient = redis.NewClient(&redis.Options{ Network: "tcp", - Addr: sdcfg.GetDbTcpAddr("COUNTERS_DB", ns), + Addr: addr, Password: "", // no password set, - DB: sdcfg.GetDbId("COUNTERS_DB", ns), + DB: dbId, DialTimeout:0, }) diff --git a/sonic_data_client/mixed_db_client.go b/sonic_data_client/mixed_db_client.go index 8878d15f..0247cfd3 100644 --- a/sonic_data_client/mixed_db_client.go +++ b/sonic_data_client/mixed_db_client.go @@ -249,7 +249,7 @@ func NewMixedDbClient(paths []*gnmipb.Path, prefix *gnmipb.Path, origin string, return nil, err } } - _, ok, _, _ := IsTargetDb(client.target); + _, ok, _, _ := IsTargetDb(client.target) if !ok { return nil, status.Errorf(codes.Unimplemented, "Invalid target: %s", client.target) } @@ -311,7 +311,10 @@ func (c *MixedDbClient) populateDbtablePath(path *gnmipb.Path, value *gnmipb.Typ } // Verify Namespace is valid - dbNamespace, ok := sdcfg.GetDbNamespaceFromTarget(targetDbNameSpace) + dbNamespace, ok, err := sdcfg.GetDbNamespaceFromTarget(targetDbNameSpace) + if err != nil { + return err + } if !ok { return fmt.Errorf("Invalid target dbNameSpace %v", targetDbNameSpace) } diff --git a/sonic_data_client/virtual_db.go b/sonic_data_client/virtual_db.go index 6d333678..c3ded154 100644 --- a/sonic_data_client/virtual_db.go +++ b/sonic_data_client/virtual_db.go @@ -123,7 +123,11 @@ func getPfcwdMap() (map[string]map[string]string, error) { var pfcwdName_map = make(map[string]map[string]string) dbName := "CONFIG_DB" - for namespace, redisDb := range GetRedisClientsForDb(dbName) { + redis_client_map, err := GetRedisClientsForDb(dbName) + if err != nil { + return nil, err + } + for namespace, redisDb := range redis_client_map { separator, _ := GetTableKeySeparator(dbName, namespace) _, err := redisDb.Ping().Result() if err != nil { @@ -216,7 +220,11 @@ func getAliasMap() (map[string]string, map[string]string, map[string]string, err var port2namespace_map = make(map[string]string) dbName := "CONFIG_DB" - for namespace, redisDb := range GetRedisClientsForDb(dbName) { + redis_client_map, err := GetRedisClientsForDb(dbName) + if err != nil { + return nil, nil, nil, err + } + for namespace, redisDb := range redis_client_map { separator, _ := GetTableKeySeparator(dbName, namespace) _, err := redisDb.Ping().Result() if err != nil { @@ -260,7 +268,11 @@ func addmap(a map[string]string, b map[string]string) { func getCountersMap(tableName string) (map[string]string, error) { counter_map := make(map[string]string) dbName := "COUNTERS_DB" - for namespace, redisDb := range GetRedisClientsForDb(dbName) { + redis_client_map, err := GetRedisClientsForDb(dbName) + if err != nil { + return nil, err + } + for namespace, redisDb := range redis_client_map { fv, err := redisDb.HGetAll(tableName).Result() if err != nil { log.V(2).Infof("redis HGetAll failed for COUNTERS_DB in namespace %v, tableName: %s", namespace, tableName) diff --git a/sonic_db_config/db_config.go b/sonic_db_config/db_config.go index fc4729b3..ad7c2f3c 100644 --- a/sonic_db_config/db_config.go +++ b/sonic_db_config/db_config.go @@ -3,13 +3,10 @@ package dbconfig import ( - "encoding/json" - "errors" - "fmt" - io "io/ioutil" "os" - "path/filepath" + "fmt" "strconv" + "github.com/sonic-net/sonic-gnmi/swsscommon" ) const ( @@ -18,223 +15,209 @@ const ( SONIC_DEFAULT_NAMESPACE string = "" ) -var sonic_db_config = make(map[string]map[string]interface{}) var sonic_db_init bool -var sonic_db_multi_namespace bool -func GetDbDefaultNamespace() string { - return SONIC_DEFAULT_NAMESPACE +// Convert exception to error +func CatchException(err *error) { + if r := recover(); r != nil { + *err = fmt.Errorf("%v", r) + } } -func CheckDbMultiNamespace() bool { + +func GetDbDefaultNamespace() (ns string, err error) { + return SONIC_DEFAULT_NAMESPACE, nil +} + +func CheckDbMultiNamespace() (ret bool, err error) { if !sonic_db_init { - DbInit() + err = DbInit() + if err != nil { + return false, err + } } - return sonic_db_multi_namespace + defer CatchException(&err) + ns_vec := swsscommon.SonicDBConfigGetNamespaces() + length := int(ns_vec.Size()) + // If there are more than one namespaces, this means that SONiC is using multinamespace + return length > 1, err } -func GetDbNonDefaultNamespaces() []string { + +func GetDbNonDefaultNamespaces() (ns_list []string, err error) { if !sonic_db_init { - DbInit() + err = DbInit() + if err != nil { + return ns_list, err + } } - ns_list := make([]string, 0, len(sonic_db_config)) - for ns := range sonic_db_config { + defer CatchException(&err) + ns_vec := swsscommon.SonicDBConfigGetNamespaces() + // Translate from vector to array + length := int(ns_vec.Size()) + for i := 0; i < length; i += 1 { + ns := ns_vec.Get(i) if ns == SONIC_DEFAULT_NAMESPACE { continue } ns_list = append(ns_list, ns) } - return ns_list + return ns_list, err } -func GetDbAllNamespaces() []string { + +func GetDbAllNamespaces() (ns_list []string, err error) { if !sonic_db_init { - DbInit() + err = DbInit() + if err != nil { + return ns_list, err + } } - ns_list := make([]string, len(sonic_db_config)) - i := 0 - for ns := range sonic_db_config { - ns_list[i] = ns - i++ + defer CatchException(&err) + ns_vec := swsscommon.SonicDBConfigGetNamespaces() + // Translate from vector to array + length := int(ns_vec.Size()) + for i := 0; i < length; i += 1 { + ns := ns_vec.Get(i) + ns_list = append(ns_list, ns) } - return ns_list + return ns_list, err } -func GetDbNamespaceFromTarget(target string) (string, bool) { - if target == GetDbDefaultNamespace() { - return target, true +func GetDbNamespaceFromTarget(target string) (ns string, ret bool, err error) { + ns, _ = GetDbDefaultNamespace() + if target == ns { + return target, true, nil + } + ns_list, err := GetDbNonDefaultNamespaces() + if err != nil { + return "", false, err } - ns_list := GetDbNonDefaultNamespaces() for _, ns := range ns_list { if target == ns { - return target, true + return target, true, nil } } - return "", false -} -func GetDbList(ns string) map[string]interface{} { - if !sonic_db_init { - DbInit() - } - db_list, ok := sonic_db_config[ns]["DATABASES"].(map[string]interface{}) - if !ok { - panic(fmt.Errorf("DATABASES' is not valid key in database_config.json file for namespace `%v` !", ns)) - } - return db_list + return "", false, nil } -func GetDbInst(db_name string, ns string) map[string]interface{} { +func GetDbList(ns string) (db_list []string, err error) { if !sonic_db_init { - DbInit() - } - db, ok := sonic_db_config[ns]["DATABASES"].(map[string]interface{})[db_name] - if !ok { - panic(fmt.Errorf("database name '%v' is not valid in database_config.json file for namespace `%v`!", db_name, ns)) - } - inst_name, ok := db.(map[string]interface{})["instance"] - if !ok { - panic(fmt.Errorf("'instance' is not a valid field in database_config.json file for namespace `%v`!", ns)) + err = DbInit() + if err != nil { + return db_list, err + } } - inst, ok := sonic_db_config[ns]["INSTANCES"].(map[string]interface{})[inst_name.(string)] - if !ok { - panic(fmt.Errorf("instance name '%v' is not valid in database_config.json file for namespace `%v`!", inst_name, ns)) + defer CatchException(&err) + db_vec := swsscommon.SonicDBConfigGetDbList() + // Translate from vector to array + length := int(db_vec.Size()) + for i := 0; i < length; i += 1 { + ns := db_vec.Get(i) + db_list = append(db_list, ns) } - return inst.(map[string]interface{}) + return db_list, err } -func GetDbSeparator(db_name string, ns string) string { +func GetDbSeparator(db_name string, ns string) (separator string, err error) { if !sonic_db_init { - DbInit() - } - db_list := GetDbList(ns) - separator, ok := db_list[db_name].(map[string]interface{})["separator"] - if !ok { - panic(fmt.Errorf("'separator' is not a valid field in database_config.json file!")) + err = DbInit() + if err != nil { + return "", err + } } - return separator.(string) + defer CatchException(&err) + separator = swsscommon.SonicDBConfigGetSeparator(db_name, ns) + return separator, err } -func GetDbId(db_name string, ns string) int { +func GetDbId(db_name string, ns string) (id int, err error) { if !sonic_db_init { - DbInit() - } - db_list := GetDbList(ns) - id, ok := db_list[db_name].(map[string]interface{})["id"] - if !ok { - panic(fmt.Errorf("'id' is not a valid field in database_config.json file!")) + err = DbInit() + if err != nil { + return -1, err + } } - return int(id.(float64)) + defer CatchException(&err) + id = swsscommon.SonicDBConfigGetDbId(db_name, ns) + return id, err } -func GetDbSock(db_name string, ns string) string { +func GetDbSock(db_name string, ns string) (unix_socket_path string, err error) { if !sonic_db_init { - DbInit() - } - inst := GetDbInst(db_name, ns) - unix_socket_path, ok := inst["unix_socket_path"] - if !ok { - panic(fmt.Errorf("'unix_socket_path' is not a valid field in database_config.json file!")) + err = DbInit() + if err != nil { + return "", err + } } - return unix_socket_path.(string) + defer CatchException(&err) + unix_socket_path = swsscommon.SonicDBConfigGetDbSock(db_name, ns) + return unix_socket_path, err } -func GetDbHostName(db_name string, ns string) string { +func GetDbHostName(db_name string, ns string) (hostname string, err error) { if !sonic_db_init { - DbInit() - } - inst := GetDbInst(db_name, ns) - hostname, ok := inst["hostname"] - if !ok { - panic(fmt.Errorf("'hostname' is not a valid field in database_config.json file!")) + err = DbInit() + if err != nil { + return "", err + } } - return hostname.(string) + defer CatchException(&err) + hostname = swsscommon.SonicDBConfigGetDbHostname(db_name, ns) + return hostname, err } -func GetDbPort(db_name string, ns string) int { +func GetDbPort(db_name string, ns string) (port int, err error) { if !sonic_db_init { - DbInit() - } - inst := GetDbInst(db_name, ns) - port, ok := inst["port"] - if !ok { - panic(fmt.Errorf("'port' is not a valid field in database_config.json file!")) + err = DbInit() + if err != nil { + return -1, err + } } - return int(port.(float64)) + defer CatchException(&err) + port = swsscommon.SonicDBConfigGetDbPort(db_name, ns) + return port, err } -func GetDbTcpAddr(db_name string, ns string) string { +func GetDbTcpAddr(db_name string, ns string) (addr string, err error) { if !sonic_db_init { - DbInit() - } - hostname := GetDbHostName(db_name, ns) - port := GetDbPort(db_name, ns) - return hostname + ":" + strconv.Itoa(port) -} - -func DbGetNamespaceAndConfigFile(ns_to_cfgfile_map map[string]string) { - data, err := io.ReadFile(SONIC_DB_GLOBAL_CONFIG_FILE) - if err == nil { - //Ref:https://stackoverflow.com/questions/18537257/how-to-get-the-directory-of-the-currently-running-file - dir, err := filepath.Abs(filepath.Dir(SONIC_DB_GLOBAL_CONFIG_FILE)) - if err != nil { - panic(err) - } - sonic_db_global_config := make(map[string]interface{}) - err = json.Unmarshal([]byte(data), &sonic_db_global_config) + err = DbInit() if err != nil { - panic(err) - } - for _, entry := range sonic_db_global_config["INCLUDES"].([]interface{}) { - ns, ok := entry.(map[string]interface{})["namespace"] - if !ok { - ns = SONIC_DEFAULT_NAMESPACE - } - _, ok = ns_to_cfgfile_map[ns.(string)] - if ok { - panic(fmt.Errorf("Global Database config file is not valid(multiple include for same namespace!")) - } - //Ref:https://www.geeksforgeeks.org/filepath-join-function-in-golang-with-examples/ - db_include_file := filepath.Join(dir, entry.(map[string]interface{})["include"].(string)) - ns_to_cfgfile_map[ns.(string)] = db_include_file - } - if len(ns_to_cfgfile_map) > 1 { - sonic_db_multi_namespace = true - } else { - sonic_db_multi_namespace = false - } - - } else if errors.Is(err, os.ErrNotExist) { - // Ref: https://stackoverflow.com/questions/23452157/how-do-i-check-for-specific-types-of-error-among-those-returned-by-ioutil-readfi - ns_to_cfgfile_map[SONIC_DEFAULT_NAMESPACE] = SONIC_DB_CONFIG_FILE - // Tests can override the file path via an env variable - if f, ok := os.LookupEnv("DB_CONFIG_PATH"); ok { - ns_to_cfgfile_map[SONIC_DEFAULT_NAMESPACE] = f + return "", err } - sonic_db_multi_namespace = false - } else { - panic(err) } + hostname, err := GetDbHostName(db_name, ns) + if err != nil { + return "", err + } + port, err := GetDbPort(db_name, ns) + if err != nil { + return "", err + } + return hostname + ":" + strconv.Itoa(port), err } -func DbInit() { +func DbInit() (err error) { if sonic_db_init { - return + return nil } - ns_to_cfgfile_map := make(map[string]string) - // Ref: https://stackoverflow.com/questions/14928826/passing-pointers-to-maps-in-golang - DbGetNamespaceAndConfigFile(ns_to_cfgfile_map) - for ns, db_cfg_file := range ns_to_cfgfile_map { - data, err := io.ReadFile(db_cfg_file) - if err != nil { - panic(err) + defer CatchException(&err) + if _, ierr := os.Stat(SONIC_DB_GLOBAL_CONFIG_FILE); ierr == nil || os.IsExist(ierr) { + // If there's global config file, invoke SonicDBConfigInitializeGlobalConfig + if !swsscommon.SonicDBConfigIsGlobalInit() { + swsscommon.SonicDBConfigInitializeGlobalConfig() } - db_config := make(map[string]interface{}) - err = json.Unmarshal([]byte(data), &db_config) - if err != nil { - panic(err) + } else { + // If there's no global config file, invoke SonicDBConfigInitialize + if !swsscommon.SonicDBConfigIsInit() { + swsscommon.SonicDBConfigInitialize() } - sonic_db_config[ns] = db_config } sonic_db_init = true + return err } -func Init() { +func Init() (err error) { sonic_db_init = false + defer CatchException(&err) + // Clear database configuration + swsscommon.SonicDBConfigReset() + return err } diff --git a/sonic_db_config/db_config_test.go b/sonic_db_config/db_config_test.go index 841d18d7..da85b866 100644 --- a/sonic_db_config/db_config_test.go +++ b/sonic_db_config/db_config_test.go @@ -1,41 +1,43 @@ package dbconfig import ( - "os" + "fmt" "testing" - "github.com/sonic-net/sonic-gnmi/test_utils" + "github.com/agiledragon/gomonkey/v2" ) func TestGetDb(t *testing.T) { + ns, _ := GetDbDefaultNamespace() t.Run("Id", func(t *testing.T) { - db_id := GetDbId("CONFIG_DB", GetDbDefaultNamespace()) + db_id, _ := GetDbId("CONFIG_DB", ns) if db_id != 4 { t.Fatalf(`Id("") = %d, want 4, error`, db_id) } }) t.Run("Sock", func(t *testing.T) { - sock_path := GetDbSock("CONFIG_DB", GetDbDefaultNamespace()) + sock_path, _ := GetDbSock("CONFIG_DB", ns) if sock_path != "/var/run/redis/redis.sock" { t.Fatalf(`Sock("") = %q, want "/var/run/redis/redis.sock", error`, sock_path) } }) t.Run("AllNamespaces", func(t *testing.T) { - ns_list := GetDbAllNamespaces() + ns_list, _ := GetDbAllNamespaces() if len(ns_list) != 1 { t.Fatalf(`AllNamespaces("") = %q, want "1", error`, len(ns_list)) } - if ns_list[0] != GetDbDefaultNamespace() { + if ns_list[0] != ns { t.Fatalf(`AllNamespaces("") = %q, want default, error`, ns_list[0]) } }) t.Run("TcpAddr", func(t *testing.T) { - tcp_addr := GetDbTcpAddr("CONFIG_DB", GetDbDefaultNamespace()) + tcp_addr, _ := GetDbTcpAddr("CONFIG_DB", ns) if tcp_addr != "127.0.0.1:6379" { t.Fatalf(`TcpAddr("") = %q, want 127.0.0.1:6379, error`, tcp_addr) } }) } + func TestGetDbMultiNs(t *testing.T) { Init() err := test_utils.SetupMultiNamespace() @@ -50,52 +52,136 @@ func TestGetDbMultiNs(t *testing.T) { } }) + ns, _ := GetDbDefaultNamespace() t.Run("Id", func(t *testing.T) { - db_id := GetDbId("CONFIG_DB", "asic0") + db_id, _ := GetDbId("CONFIG_DB", "asic0") if db_id != 4 { t.Fatalf(`Id("") = %d, want 4, error`, db_id) } }) t.Run("Sock", func(t *testing.T) { - sock_path := GetDbSock("CONFIG_DB", "asic0") + sock_path, _ := GetDbSock("CONFIG_DB", "asic0") if sock_path != "/var/run/redis0/redis.sock" { t.Fatalf(`Sock("") = %q, want "/var/run/redis0/redis.sock", error`, sock_path) } }) t.Run("AllNamespaces", func(t *testing.T) { - ns_list := GetDbAllNamespaces() + ns_list, _ := GetDbAllNamespaces() if len(ns_list) != 2 { t.Fatalf(`AllNamespaces("") = %q, want "2", error`, len(ns_list)) } - if !((ns_list[0] == GetDbDefaultNamespace() && ns_list[1] == "asic0") || (ns_list[0] == "asic0" && ns_list[1] == GetDbDefaultNamespace())) { + if !((ns_list[0] == ns && ns_list[1] == "asic0") || (ns_list[0] == "asic0" && ns_list[1] == ns)) { t.Fatalf(`AllNamespaces("") = %q %q, want default and asic0, error`, ns_list[0], ns_list[1]) } }) t.Run("TcpAddr", func(t *testing.T) { - tcp_addr := GetDbTcpAddr("CONFIG_DB", "asic0") + tcp_addr, _ := GetDbTcpAddr("CONFIG_DB", "asic0") if tcp_addr != "127.0.0.1:6379" { t.Fatalf(`TcpAddr("") = %q, want 127.0.0.1:6379, error`, tcp_addr) } }) -} - -func TestOverrideDbConfigFile(t *testing.T) { - Init() - // Override database_config.json path to a garbage value by setting - // env DB_CONFIG_PATH and verify that GetDbId() panics - if err := os.Setenv("DB_CONFIG_PATH", "/tmp/.unknown_database_config_file.json"); err != nil { - t.Fatalf("os.Setenv failed: %v", err) - } - t.Cleanup(func() { - os.Unsetenv("DB_CONFIG_PATH") + t.Run("AllAPI", func(t *testing.T) { + Init() + _, err = CheckDbMultiNamespace() + if err != nil { + t.Fatalf(`err %v`, err) + } Init() + _, err = GetDbNonDefaultNamespaces() + if err != nil { + t.Fatalf(`err %v`, err) + } + Init() + _, err = GetDbList("asic0") + if err != nil { + t.Fatalf(`err %v`, err) + } + Init() + _, err = GetDbSeparator("CONFIG_DB", "asic0") + if err != nil { + t.Fatalf(`err %v`, err) + } + Init() + _, err = GetDbSock("CONFIG_DB", "asic0") + if err != nil { + t.Fatalf(`err %v`, err) + } + Init() + _, err = GetDbHostName("CONFIG_DB", "asic0") + if err != nil { + t.Fatalf(`err %v`, err) + } + Init() + _, err = GetDbPort("CONFIG_DB", "asic0") + if err != nil { + t.Fatalf(`err %v`, err) + } + Init() + _, err = GetDbTcpAddr("CONFIG_DB", "asic0") + if err != nil { + t.Fatalf(`err %v`, err) + } + err = DbInit() + if err != nil { + t.Fatalf(`err %v`, err) + } + }) + t.Run("AllAPIError", func(t *testing.T) { + mock1 := gomonkey.ApplyFunc(DbInit, func() (err error) { + return fmt.Errorf("Test api error") + }) + defer mock1.Reset() + var err error + Init() + _, err = CheckDbMultiNamespace() + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbNonDefaultNamespaces() + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbAllNamespaces() + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbList("asic0") + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbSeparator("CONFIG_DB", "asic0") + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbId("CONFIG_DB", "asic0") + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbSock("CONFIG_DB", "asic0") + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbHostName("CONFIG_DB", "asic0") + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbPort("CONFIG_DB", "asic0") + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } + Init() + _, err = GetDbTcpAddr("CONFIG_DB", "asic0") + if err == nil || err.Error() != "Test api error" { + t.Fatalf(`No expected error`) + } }) - defer func() { - r := recover() - if err, _ := r.(error); !os.IsNotExist(err) { - t.Fatalf("Unexpected panic: %v", r) - } - }() - _ = GetDbId("CONFIG_DB", GetDbDefaultNamespace()) - t.Fatal("GetDbId() should have paniced") } +