From b04f811bf4ca8de572f34424d9bc17b9125d80f0 Mon Sep 17 00:00:00 2001 From: Kun Chang Date: Tue, 5 Mar 2024 14:24:35 +0800 Subject: [PATCH] Add environment variable option to set postgres ssl mode Signed-off-by: Kun Chang --- pkg/db/v1beta1/common/const.go | 2 + pkg/db/v1beta1/postgres/postgres.go | 6 ++- pkg/db/v1beta1/postgres/postgres_test.go | 60 ++++++++++++++++++++++-- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/pkg/db/v1beta1/common/const.go b/pkg/db/v1beta1/common/const.go index 9754f9a65dd..dce3b3e5a9c 100644 --- a/pkg/db/v1beta1/common/const.go +++ b/pkg/db/v1beta1/common/const.go @@ -41,11 +41,13 @@ const ( PostgreSQLDBHostEnvName = "KATIB_POSTGRESQL_DB_HOST" PostgreSQLDBPortEnvName = "KATIB_POSTGRESQL_DB_PORT" PostgreSQLDatabase = "KATIB_POSTGRESQL_DB_DATABASE" + PostgreSSLMode = "KATIB_POSTGRESQL_SSL_MODE" DefaultPostgreSQLUser = "katib" DefaultPostgreSQLDatabase = "katib" DefaultPostgreSQLHost = "katib-postgres" DefaultPostgreSQLPort = "5432" + DefaultPostgreSSLMode = "disable" SkipDbInitializationEnvName = "SKIP_DB_INITIALIZATION" ) diff --git a/pkg/db/v1beta1/postgres/postgres.go b/pkg/db/v1beta1/postgres/postgres.go index c89a9598636..af60d0b0efc 100644 --- a/pkg/db/v1beta1/postgres/postgres.go +++ b/pkg/db/v1beta1/postgres/postgres.go @@ -48,10 +48,12 @@ func getDbName() string { common.PostgreSQLDBPortEnvName, common.DefaultPostgreSQLPort) dbName := env.GetEnvOrDefault(common.PostgreSQLDatabase, common.DefaultPostgreSQLDatabase) + sslMode := env.GetEnvOrDefault(common.PostgreSSLMode, + common.DefaultPostgreSSLMode) psqlInfo := fmt.Sprintf("host=%s port=%s user=%s "+ - "password=%s dbname=%s sslmode=disable", - dbHost, dbPort, dbUser, dbPass, dbName) + "password=%s dbname=%s sslmode=%s", + dbHost, dbPort, dbUser, dbPass, dbName, sslMode) return psqlInfo } diff --git a/pkg/db/v1beta1/postgres/postgres_test.go b/pkg/db/v1beta1/postgres/postgres_test.go index e4bba33164b..f478107aab2 100644 --- a/pkg/db/v1beta1/postgres/postgres_test.go +++ b/pkg/db/v1beta1/postgres/postgres_test.go @@ -22,6 +22,7 @@ import ( "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/google/go-cmp/cmp" _ "github.com/lib/pq" api_pb "github.com/kubeflow/katib/pkg/apis/manager/v1beta1" @@ -129,11 +130,60 @@ func TestDeleteObservationLog(t *testing.T) { } func TestGetDbName(t *testing.T) { - // dbName := "root:@tcp(katib-mysql:3306)/katib?timeout=5s" - dbName := "host=katib-postgres port=5432 user=katib password= dbname=katib sslmode=disable" - - if getDbName() != dbName { - t.Errorf("getDbName returns wrong value %v", getDbName()) + cases := map[string]struct { + updateEnvs map[string]string + wantName string + }{ + "All parameters are default": { + wantName: "host=katib-postgres port=5432 user=katib password= dbname=katib sslmode=disable", + }, + "Set DB_USER": { + updateEnvs: map[string]string{ + common.DBUserEnvName: "testUser", + }, + wantName: "host=katib-postgres port=5432 user=testUser password= dbname=katib sslmode=disable", + }, + "Set KATIB_POSTGRESQL_DB_HOST": { + updateEnvs: map[string]string{ + common.PostgreSQLDBHostEnvName: "testHost", + }, + wantName: "host=testHost port=5432 user=katib password= dbname=katib sslmode=disable", + }, + "Set KATIB_POSTGRESQL_DB_PORT": { + updateEnvs: map[string]string{ + common.PostgreSQLDBPortEnvName: "1234", + }, + wantName: "host=katib-postgres port=1234 user=katib password= dbname=katib sslmode=disable", + }, + "Set KATIB_POSTGRESQL_DB_DATABASE": { + updateEnvs: map[string]string{ + common.PostgreSQLDatabase: "testDB", + }, + wantName: "host=katib-postgres port=5432 user=katib password= dbname=testDB sslmode=disable", + }, + "Set DB_PASSWORD": { + updateEnvs: map[string]string{ + common.DBPasswordEnvName: "testPassword", + }, + wantName: "host=katib-postgres port=5432 user=katib password=testPassword dbname=katib sslmode=disable", + }, + "Set KATIB_POSTGRESQL_SSL_MODE": { + updateEnvs: map[string]string{ + common.PostgreSSLMode: "require", + }, + wantName: "host=katib-postgres port=5432 user=katib password= dbname=katib sslmode=require", + }, } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + for k, v := range tc.updateEnvs { + t.Setenv(k, v) + } + gotName := getDbName() + if diff := cmp.Diff(tc.wantName, gotName); len(diff) != 0 { + t.Errorf("Unexpected DBName (-want,+got):\n%s", diff) + } + }) + } }