diff --git a/cmd/gen.go b/cmd/gen.go index 5849a5ce0..b563d710a 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -86,13 +86,15 @@ var ( return err } } - return types.Run(ctx, flags.ProjectRef, flags.DbConfig, lang.Value, schema, postgrestV9Compat, swiftAccessControl.Value, afero.NewOsFs()) + return types.Run(ctx, flags.ProjectRef, flags.DbConfig, lang.Value, schema, setDefault, postgrestV9Compat, swiftAccessControl.Value, afero.NewOsFs()) }, Example: ` supabase gen types --local supabase gen types --linked --lang=go supabase gen types --project-id abc-def-123 --schema public --schema private supabase gen types --db-url 'postgresql://...' --schema public --schema auth`, } + + setDefault bool ) func init() { @@ -106,6 +108,7 @@ func init() { typeFlags.StringSliceVarP(&schema, "schema", "s", []string{}, "Comma separated list of schema to include.") typeFlags.Var(&swiftAccessControl, "swift-access-control", "Access control for Swift generated types.") typeFlags.BoolVar(&postgrestV9Compat, "postgrest-v9-compat", false, "Generate types compatible with PostgREST v9 and below. Only use together with --db-url.") + typeFlags.BoolVar(&setDefault, "set-default", false, "Set the specified schema as the default for helper types when using a single non-public schema") genCmd.AddCommand(genTypesCmd) keyFlags := genKeysCmd.Flags() keyFlags.StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.") diff --git a/internal/gen/types/types.go b/internal/gen/types/types.go index 9ffc7dc72..ff7266259 100644 --- a/internal/gen/types/types.go +++ b/internal/gen/types/types.go @@ -27,7 +27,7 @@ const ( SwiftInternalAccessControl = "internal" ) -func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang string, schemas []string, postgrestV9Compat bool, swiftAccessControl string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang string, schemas []string, setDefault bool, postgrestV9Compat bool, swiftAccessControl string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { originalURL := utils.ToPostgresURL(dbConfig) // Add default schemas if --schema flag is not specified if len(schemas) == 0 { @@ -35,6 +35,11 @@ func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang str } included := strings.Join(schemas, ",") + var defaultSchemaEnv string + if setDefault && len(schemas) == 1 && schemas[0] != "public" { + defaultSchemaEnv = schemas[0] + } + if projectId != "" { if lang != LangTypescript { return errors.Errorf("Unable to generate %s types for selected project. Try using --db-url flag instead.", lang) @@ -84,18 +89,24 @@ func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang str escaped += "&sslmode=require" } + envVars := []string{ + "PG_META_DB_URL=" + escaped, + "PG_META_GENERATE_TYPES=" + lang, + "PG_META_GENERATE_TYPES_INCLUDED_SCHEMAS=" + included, + "PG_META_GENERATE_TYPES_SWIFT_ACCESS_CONTROL=" + swiftAccessControl, + fmt.Sprintf("PG_META_GENERATE_TYPES_DETECT_ONE_TO_ONE_RELATIONSHIPS=%v", !postgrestV9Compat), + } + + if defaultSchemaEnv != "" { + envVars = append(envVars, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA="+defaultSchemaEnv) + } + return utils.DockerRunOnceWithConfig( ctx, container.Config{ Image: utils.Config.Studio.PgmetaImage, - Env: []string{ - "PG_META_DB_URL=" + escaped, - "PG_META_GENERATE_TYPES=" + lang, - "PG_META_GENERATE_TYPES_INCLUDED_SCHEMAS=" + included, - "PG_META_GENERATE_TYPES_SWIFT_ACCESS_CONTROL=" + swiftAccessControl, - fmt.Sprintf("PG_META_GENERATE_TYPES_DETECT_ONE_TO_ONE_RELATIONSHIPS=%v", !postgrestV9Compat), - }, - Cmd: []string{"node", "dist/server/server.js"}, + Env: envVars, + Cmd: []string{"node", "dist/server/server.js"}, }, hostConfig, network.NetworkingConfig{}, diff --git a/internal/gen/types/types_test.go b/internal/gen/types/types_test.go index bf7ee4067..fc65f84be 100644 --- a/internal/gen/types/types_test.go +++ b/internal/gen/types/types_test.go @@ -48,7 +48,7 @@ func TestGenLocalCommand(t *testing.T) { conn := pgtest.NewConn() defer conn.Close(t) // Run test - assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys, conn.Intercept)) + assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, false, true, "", fsys, conn.Intercept)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -63,7 +63,7 @@ func TestGenLocalCommand(t *testing.T) { Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId). Reply(http.StatusServiceUnavailable) // Run test - assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys)) + assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, false, true, "", fsys)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -83,7 +83,7 @@ func TestGenLocalCommand(t *testing.T) { Get("/v" + utils.Docker.ClientVersion() + "/images"). Reply(http.StatusServiceUnavailable) // Run test - assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys)) + assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, false, true, "", fsys)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -106,7 +106,7 @@ func TestGenLocalCommand(t *testing.T) { conn := pgtest.NewConn() defer conn.Close(t) // Run test - assert.NoError(t, Run(context.Background(), "", dbConfig, LangSwift, []string{}, true, SwiftInternalAccessControl, fsys, conn.Intercept)) + assert.NoError(t, Run(context.Background(), "", dbConfig, LangSwift, []string{}, false, true, SwiftInternalAccessControl, fsys, conn.Intercept)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -129,7 +129,7 @@ func TestGenLinkedCommand(t *testing.T) { Reply(200). JSON(api.TypescriptResponse{Types: ""}) // Run test - assert.NoError(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys)) + assert.NoError(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, false, true, "", fsys)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -144,7 +144,7 @@ func TestGenLinkedCommand(t *testing.T) { Get("/v1/projects/" + projectId + "/types/typescript"). ReplyError(errNetwork) // Run test - err := Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys) + err := Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, false, true, "", fsys) // Validate api assert.ErrorIs(t, err, errNetwork) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -159,7 +159,7 @@ func TestGenLinkedCommand(t *testing.T) { Get("/v1/projects/" + projectId + "/types/typescript"). Reply(http.StatusServiceUnavailable) // Run test - assert.Error(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys)) + assert.Error(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, false, true, "", fsys)) }) } @@ -184,8 +184,142 @@ func TestGenRemoteCommand(t *testing.T) { conn := pgtest.NewConn() defer conn.Close(t) // Run test - assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, true, "", afero.NewMemMapFs(), conn.Intercept)) + assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, false, true, "", afero.NewMemMapFs(), conn.Intercept)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) } + +func TestGenWithSetDefault(t *testing.T) { + utils.DbId = "test-db" + utils.Config.Hostname = "localhost" + utils.Config.Db.Port = 5432 + + dbConfig := pgconn.Config{ + Host: utils.Config.Hostname, + Port: utils.Config.Db.Port, + User: "admin", + Password: "password", + } + + t.Run("sets default schema env var with single non-public schema", func(t *testing.T) { + const containerId = "test-pgmeta" + imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage) + fsys := afero.NewMemMapFs() + + require.NoError(t, apitest.MockDocker(utils.Docker)) + defer gock.OffAll() + + gock.New(utils.Docker.DaemonHost()). + Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId). + Reply(http.StatusOK). + JSON(container.InspectResponse{}) + + var capturedEnv []string + apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv) + require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n")) + + conn := pgtest.NewConn() + defer conn.Close(t) + + err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"private"}, true, true, "", fsys, conn.Intercept) + assert.NoError(t, err) + + found := false + for _, env := range capturedEnv { + if env == "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA=private" { + found = true + break + } + } + assert.True(t, found, "Expected PG_META_GENERATE_TYPES_DEFAULT_SCHEMA=private to be set in environment variables") + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("does not set default schema env var without flag", func(t *testing.T) { + const containerId = "test-pgmeta" + imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage) + fsys := afero.NewMemMapFs() + + require.NoError(t, apitest.MockDocker(utils.Docker)) + defer gock.OffAll() + + gock.New(utils.Docker.DaemonHost()). + Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId). + Reply(http.StatusOK). + JSON(container.InspectResponse{}) + + var capturedEnv []string + apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv) + require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n")) + + conn := pgtest.NewConn() + defer conn.Close(t) + + err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"private"}, false, true, "", fsys, conn.Intercept) + assert.NoError(t, err) + + for _, env := range capturedEnv { + assert.NotContains(t, env, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA", "Should not set default schema env var when flag is false") + } + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("does not set default schema env var with multiple schemas", func(t *testing.T) { + const containerId = "test-pgmeta" + imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage) + fsys := afero.NewMemMapFs() + + require.NoError(t, apitest.MockDocker(utils.Docker)) + defer gock.OffAll() + + gock.New(utils.Docker.DaemonHost()). + Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId). + Reply(http.StatusOK). + JSON(container.InspectResponse{}) + + var capturedEnv []string + apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv) + require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n")) + + conn := pgtest.NewConn() + defer conn.Close(t) + + err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"public", "private"}, true, true, "", fsys, conn.Intercept) + assert.NoError(t, err) + + for _, env := range capturedEnv { + assert.NotContains(t, env, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA", "Should not set default schema env var with multiple schemas") + } + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("does not set default schema env var with public schema", func(t *testing.T) { + const containerId = "test-pgmeta" + imageUrl := utils.GetRegistryImageUrl(utils.Config.Studio.PgmetaImage) + fsys := afero.NewMemMapFs() + + require.NoError(t, apitest.MockDocker(utils.Docker)) + defer gock.OffAll() + + gock.New(utils.Docker.DaemonHost()). + Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId). + Reply(http.StatusOK). + JSON(container.InspectResponse{}) + + var capturedEnv []string + apitest.MockDockerStartWithEnvCapture(utils.Docker, imageUrl, containerId, &capturedEnv) + require.NoError(t, apitest.MockDockerLogs(utils.Docker, containerId, "hello world\n")) + + conn := pgtest.NewConn() + defer conn.Close(t) + + err := Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, true, true, "", fsys, conn.Intercept) + assert.NoError(t, err) + + for _, env := range capturedEnv { + assert.NotContains(t, env, "PG_META_GENERATE_TYPES_DEFAULT_SCHEMA", "Should not set default schema env var with public schema") + } + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) +} diff --git a/internal/testing/apitest/docker.go b/internal/testing/apitest/docker.go index 17a14355d..cc7f8d6b2 100644 --- a/internal/testing/apitest/docker.go +++ b/internal/testing/apitest/docker.go @@ -2,6 +2,7 @@ package apitest import ( "bytes" + "encoding/json" "fmt" "net/http" @@ -113,6 +114,41 @@ func MockDockerLogsExitCode(docker *client.Client, containerID string, exitCode return setupDockerLogs(docker, containerID, "", exitCode) } +// MockDockerStartWithEnvCapture extends MockDockerStart to capture environment variables +// passed to container creation. This is useful for testing environment variable logic. +func MockDockerStartWithEnvCapture(docker *client.Client, imageID, containerID string, capturedEnv *[]string) { + gock.New(docker.DaemonHost()). + Get("/v" + docker.ClientVersion() + "/images/" + imageID + "/json"). + Reply(http.StatusOK). + JSON(image.InspectResponse{}) + gock.New(docker.DaemonHost()). + Post("/v" + docker.ClientVersion() + "/networks/create"). + Reply(http.StatusCreated). + JSON(network.CreateResponse{}) + gock.New(docker.DaemonHost()). + Post("/v" + docker.ClientVersion() + "/volumes/create"). + Persist(). + Reply(http.StatusCreated). + JSON(volume.Volume{}) + gock.New(docker.DaemonHost()). + Post("/v" + docker.ClientVersion() + "/containers/create"). + AddMatcher(func(req *http.Request, ereq *gock.Request) (bool, error) { + var config struct { + Env []string `json:"Env"` + } + if err := json.NewDecoder(req.Body).Decode(&config); err != nil { + return false, err + } + *capturedEnv = config.Env + return true, nil + }). + Reply(http.StatusOK). + JSON(container.CreateResponse{ID: containerID}) + gock.New(docker.DaemonHost()). + Post("/v" + docker.ClientVersion() + "/containers/" + containerID + "/start"). + Reply(http.StatusAccepted) +} + func ListUnmatchedRequests() []string { result := make([]string, len(gock.GetUnmatchedRequests())) for i, r := range gock.GetUnmatchedRequests() {