diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 294f10161f69..6cc760393059 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "api_v2.go", "api_v2_constants.go", "api_v2_databases_metadata.go", + "api_v2_grants.go", "api_v2_ranges.go", "api_v2_sql.go", "api_v2_sql_schema.go", @@ -425,6 +426,7 @@ go_test( size = "enormous", srcs = [ "api_v2_databases_metadata_test.go", + "api_v2_grants_test.go", "api_v2_ranges_test.go", "api_v2_sql_schema_test.go", "api_v2_sql_test.go", diff --git a/pkg/server/api_v2.go b/pkg/server/api_v2.go index 6a0cc326a424..ee5ee07b5134 100644 --- a/pkg/server/api_v2.go +++ b/pkg/server/api_v2.go @@ -51,6 +51,12 @@ import ( "github.com/gorilla/mux" ) +// Path variables. +const ( + dbIdPathVar = "database_id" + tableIdPathVar = "table_id" +) + type ApiV2System interface { health(w http.ResponseWriter, r *http.Request) listNodes(w http.ResponseWriter, r *http.Request) @@ -192,6 +198,8 @@ func registerRoutes( {"table_metadata/", a.GetTableMetadata, true, authserver.RegularRole, true}, {"table_metadata/{table_id:[0-9]+}/", a.GetTableMetadataWithDetails, true, authserver.RegularRole, true}, {"table_metadata/updatejob/", a.TableMetadataJob, true, authserver.RegularRole, true}, + {fmt.Sprintf("grants/databases/{%s:[0-9]+}/", dbIdPathVar), a.getDatabaseGrants, true, authserver.RegularRole, true}, + {fmt.Sprintf("grants/tables/{%s:[0-9]+}/", tableIdPathVar), a.getTableGrants, true, authserver.RegularRole, true}, } // For all routes requiring authentication, have the outer mux (a.mux) diff --git a/pkg/server/api_v2_grants.go b/pkg/server/api_v2_grants.go new file mode 100644 index 000000000000..1cc42decd638 --- /dev/null +++ b/pkg/server/api_v2_grants.go @@ -0,0 +1,476 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "context" + "fmt" + "net/http" + "strconv" + + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server/apiutil" + "github.com/cockroachdb/cockroach/pkg/server/authserver" + "github.com/cockroachdb/cockroach/pkg/server/srverrors" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" + "github.com/cockroachdb/cockroach/pkg/util/safesql" + "github.com/cockroachdb/errors" + "github.com/gorilla/mux" +) + +const ( + granteeSortKey string = "grantee" + privSortKey string = "privilege" +) + +type grantRecord struct { + Grantee string `json:"grantee"` + Privilege string `json:"privilege"` +} + +type databaseGrantsResponseWithPagination struct { + PaginatedResponse[[]grantRecord] + Name string `json:"name"` +} + +// getDatabaseGrants returns a paginated response of grants on the database with the provided id. +// +// --- +// parameters: +// +// - name: database_id +// type: integer +// description: The ID of the database to get grants for. +// in: path +// +// - name: pageNum +// type: integer +// description: The page number to retrieve. +// in: query +// required: false +// +// - name: pageSize +// type: integer +// description: The number of results to return per page. +// in: query +// required: false +// +// - name: sortBy +// type: string +// description: The column to sort by. +// in: query +// required: false +// +// - name: sortOrder +// type: string +// description: The order to sort by. Must be either "asc" or "desc", case insensitive. +// in: query +// required: false +// +// produces: +// - application/json: databaseGrantsResponseWithPagination +// +// responses: +// +// "200": +// description: A paginated response of grants on the provided database. +// "404": +// description: The db does not exist. +// "400": +// description: The request is malformed. +// "500": +// description: An internal server error occurred. +func (a *apiV2Server) getDatabaseGrants(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = a.sqlServer.AnnotateCtx(ctx) + sqlUser := authserver.UserFromHTTPAuthInfoContext(ctx) + if r.Method != http.MethodGet { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + // Validate request parameters. + pathVars := mux.Vars(r) + dbId, err := strconv.Atoi(pathVars[dbIdPathVar]) + if err != nil { + http.Error(w, "invalid database id", http.StatusBadRequest) + return + } + queryValues := r.URL.Query() + pageSize, err := apiutil.GetIntQueryStringVal(queryValues, pageSizeKey) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + pageNum, err := apiutil.GetIntQueryStringVal(queryValues, pageNumKey) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + offset := 0 + if pageNum > 0 { + offset = (pageNum - 1) * pageSize + } + sortBy := getValidGrantsSortParam(queryValues.Get(sortByKey)) + var sortOrder string + if sortBy != "" { + var ok bool + sortOrder, ok = validateSortOrderValue(queryValues.Get(sortOrderKey)) + if !ok { + http.Error(w, "invalid sort order", http.StatusBadRequest) + return + } + } + + // Get the database grants. + resp, err := getDatabaseGrantsResponseWithPagination( + ctx, + a.sqlServer.internalExecutor, + dbId, sqlUser, + pageSize, + offset, + sortBy, + sortOrder) + if err != nil { + srverrors.APIV2InternalError(ctx, err, w) + return + } + + if resp.Name == "" { + http.Error(w, "database does not exist", http.StatusNotFound) + return + } + + resp.PaginationInfo.PageNum = pageNum + resp.PaginationInfo.PageSize = pageSize + + apiutil.WriteJSONResponse(ctx, w, http.StatusOK, resp) +} + +func getDatabaseGrantsResponseWithPagination( + ctx context.Context, + ie *sql.InternalExecutor, + dbId int, + sqlUser username.SQLUsername, + limit int, + offset int, + sortBy string, + sortOrder string, +) (resp databaseGrantsResponseWithPagination, retErr error) { + row, err := ie.QueryRowEx( + ctx, "get-database-name-by-id", nil, /* txn */ + sessiondata.NodeUserSessionDataOverride, ` +SELECT name FROM system.namespace +WHERE "parentID" = 0 AND "parentSchemaID" = 0 AND id = $1`, dbId) + if err != nil { + return resp, err + } + + if row == nil || row[0] == tree.DNull { + // Database id does not exist. + return resp, nil + } + + dbName := string(*row[0].(*tree.DString)) + resp.Name = dbName + + escDbName := tree.NameStringP(&dbName) + + query := safesql.NewQuery() + query.Append(fmt.Sprintf(` +SELECT grantee, privilege_type, count(*) OVER() as total_row_count +FROM %s.crdb_internal.cluster_database_privileges`, escDbName)) + + if sortBy != "" { + query.Append(fmt.Sprintf(" ORDER BY %s %s", sortBy, sortOrder)) + } + + // Pagination arguments. + if limit > 0 { + query.Append(" LIMIT $", limit) + } + if offset > 0 { + query.Append(" OFFSET $", offset) + } + + resp.Results = make([]grantRecord, 0) + it, err := ie.QueryIteratorEx(ctx, "get-database-grants", nil, /* txn */ + sessiondata.InternalExecutorOverride{User: sqlUser}, + query.String(), query.QueryArguments()..., + ) + if err != nil { + return resp, err + } + + defer func(it isql.Rows) { + retErr = errors.CombineErrors(retErr, it.Close()) + }(it) + + ok, err := it.Next(ctx) + + if err != nil || !ok { + // If ok is false, the query returned 0 rows. + return resp, err + } + + scanner := makeResultScanner(it.Types()) + for ; ok; ok, err = it.Next(ctx) { + if err != nil { + return resp, err + } + grant := grantRecord{} + row := it.Cur() + if resp.PaginationInfo.TotalResults == 0 { + resp.PaginationInfo.TotalResults = int64(tree.MustBeDInt(row[2])) + } + if err := scanner.Scan(row, "grantee", &grant.Grantee); err != nil { + return resp, err + } + if err := scanner.Scan(row, "privilege_type", &grant.Privilege); err != nil { + return resp, err + } + resp.Results = append(resp.Results, grant) + } + return resp, nil +} + +// getTableGrants returns a paginated response of grants for the provided table. +// +// --- +// parameters: +// +// - name: table_id +// type: integer +// description: The ID of the table to get grants for. +// in: path +// +// - name: pageNum +// type: integer +// description: The page number to retrieve. +// in: query +// required: false +// +// - name: pageSize +// type: integer +// description: The number of results to return per page. +// in: query +// required: false +// +// - name: sortBy +// type: string +// description: The column to sort by. +// in: query +// required: false +// +// - name: sortOrder +// type: string +// description: The order to sort by. Must be either "asc" or "desc", case insensitive. +// in: query +// required: false +// +// produces: +// - application/json: tableGrantsResponseWithPagination +// +// responses: +// +// "200": +// description: A paginated response of grants on the provided table. +// "404": +// description: The table does not exist. +// "400": +// description: The request is malformed. +// "500": +// description: An internal server error occurred. +func (a *apiV2Server) getTableGrants(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = a.sqlServer.AnnotateCtx(ctx) + sqlUser := authserver.UserFromHTTPAuthInfoContext(ctx) + if r.Method != http.MethodGet { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + // Validate request parameters. + pathVars := mux.Vars(r) + tableId, err := strconv.Atoi(pathVars[tableIdPathVar]) + if err != nil { + http.Error(w, "invalid database id", http.StatusBadRequest) + return + } + queryValues := r.URL.Query() + pageSize, err := apiutil.GetIntQueryStringVal(queryValues, pageSizeKey) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + pageNum, err := apiutil.GetIntQueryStringVal(queryValues, pageNumKey) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + offset := 0 + if pageNum > 0 { + offset = (pageNum - 1) * pageSize + } + sortBy := getValidGrantsSortParam(queryValues.Get(sortByKey)) + var sortOrder string + if sortBy != "" { + var ok bool + sortOrder, ok = validateSortOrderValue(queryValues.Get(sortOrderKey)) + if !ok { + http.Error(w, "invalid sort order", http.StatusBadRequest) + return + } + } + + // Get the table grants. + resp, err := getTableGrantsResponseWithPagination( + ctx, + a.sqlServer.internalExecutor, + tableId, + sqlUser, + pageSize, + offset, + sortBy, + sortOrder) + if err != nil { + srverrors.APIV2InternalError(ctx, err, w) + return + } + if resp.TableName == "" { + http.Error(w, "table does not exist", http.StatusNotFound) + return + } + + resp.PaginationInfo.PageNum = pageNum + resp.PaginationInfo.PageSize = pageSize + + apiutil.WriteJSONResponse(ctx, w, http.StatusOK, resp) +} + +type tableGrantsResponseWithPagination struct { + PaginatedResponse[[]grantRecord] + DatabaseName string `json:"database_name"` + SchemaName string `json:"schema_name"` + TableName string `json:"table_name"` +} + +func getTableGrantsResponseWithPagination( + ctx context.Context, + ie *sql.InternalExecutor, + tableId int, + sqlUser username.SQLUsername, + limit int, + offset int, + sortBy string, + sortOrder string, +) (resp tableGrantsResponseWithPagination, retErr error) { + row, err := ie.QueryRowEx( + ctx, "get-table-name-by-id", nil, /* txn */ + sessiondata.NodeUserSessionDataOverride, ` +SELECT + t.name AS table_name, + sc.name AS schema_name, + db.name AS db_name +FROM system.namespace t + JOIN system.namespace sc ON t."parentSchemaID" = sc.id + JOIN system.namespace db on t."parentID" = db.id +WHERE t.id = $1 +`, tableId) + if err != nil { + return resp, err + } + + if row == nil || row[0] == tree.DNull { + // Table id does not exist. + return resp, nil + } + + resp.TableName = string(*row[0].(*tree.DString)) + resp.SchemaName = string(*row[1].(*tree.DString)) + resp.DatabaseName = string(*row[2].(*tree.DString)) + escDbName := tree.NameStringP(&resp.DatabaseName) + + query := safesql.NewQuery() + query.Append(fmt.Sprintf(` +SELECT grantee, privilege_type, count(*) OVER() as total_row_count +FROM %s.information_schema.table_privileges +WHERE table_name = $ AND table_schema = $`, escDbName), + resp.TableName, resp.SchemaName) + + if sortBy != "" { + query.Append(fmt.Sprintf(" ORDER BY %s %s", sortBy, sortOrder)) + } + + // Pagination arguments. + if limit > 0 { + query.Append(" LIMIT $", limit) + } + if offset > 0 { + query.Append(" OFFSET $", offset) + } + + resp.Results = make([]grantRecord, 0) + it, err := ie.QueryIteratorEx(ctx, "get-table-grants", nil, /* txn */ + sessiondata.InternalExecutorOverride{User: sqlUser}, + query.String(), query.QueryArguments()..., + ) + if err != nil { + return resp, err + } + + defer func(it isql.Rows) { + retErr = errors.CombineErrors(retErr, it.Close()) + }(it) + + ok, err := it.Next(ctx) + if err != nil || !ok { + // If ok is false, the query returned 0 rows. + return resp, err + } + + scanner := makeResultScanner(it.Types()) + for ; ok; ok, err = it.Next(ctx) { + if err != nil { + return resp, err + } + grant := grantRecord{} + row := it.Cur() + if resp.PaginationInfo.TotalResults == 0 { + resp.PaginationInfo.TotalResults = int64(tree.MustBeDInt(row[2])) + } + if err := scanner.Scan(row, "grantee", &grant.Grantee); err != nil { + return resp, err + } + if err := scanner.Scan(row, "privilege_type", &grant.Privilege); err != nil { + return resp, err + } + resp.Results = append(resp.Results, grant) + } + return resp, nil +} + +// getValidGrantsSortParam returns a valid sort parameter for the grants +// query based on the provided sortBy value. If the sortBy value is not +// valid, an empty string is returned. +func getValidGrantsSortParam(sortBy string) string { + switch sortBy { + case granteeSortKey: + return "grantee" + case privSortKey: + return "privilege_type" + default: + return "" + } +} diff --git a/pkg/server/api_v2_grants_test.go b/pkg/server/api_v2_grants_test.go new file mode 100644 index 000000000000..af1661b29920 --- /dev/null +++ b/pkg/server/api_v2_grants_test.go @@ -0,0 +1,586 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/stretchr/testify/require" +) + +func TestGetDatabaseGrants(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Setup server. + ctx := context.Background() + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + conn := sqlutils.MakeSQLRunner(db) + // Create test databases. + conn.Exec(t, `CREATE DATABASE test_db`) + conn.Exec(t, `CREATE DATABASE no_access_db`) + + // Get the database IDs. + var testDBID, noAccessDBID, systemDbId int + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE name = 'test_db'`).Scan(&testDBID) + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE name = 'no_access_db'`).Scan(&noAccessDBID) + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE "parentID" = 0 AND "parentSchemaID" = 0 AND name = 'system'`).Scan(&systemDbId) + + // Create test users. + users := []string{"user1", "user2", "user3"} + for _, user := range users { + conn.Exec(t, fmt.Sprintf("CREATE USER %s", user)) + } + + // Setup clients. + adminClient, err := s.GetAdminHTTPClient() + require.NoError(t, err) + + authenticatedClient, _, err := s.GetAuthenticatedHTTPClientAndCookie(username.MakeSQLUsernameFromPreNormalizedString("no_access_user"), false, 1) + require.NoError(t, err) + + // Grant different privileges to users. + conn.Exec(t, `GRANT CREATE, CONNECT ON DATABASE test_db TO user1`) + conn.Exec(t, `GRANT ZONECONFIG, BACKUP ON DATABASE test_db TO user2`) + conn.Exec(t, `GRANT ALL ON DATABASE test_db TO user3`) + conn.Exec(t, `REVOKE CONNECT ON DATABASE no_access_db FROM public`) + + // Define test cases. + testCases := []struct { + name string + dbID int + pageSize int + pageNum int + sortBy string + sortOrder string + expectedStatus int + expectedTotal int + expected []grantRecord + // If specified, only this client will be used. + // Otherwise, both the authenticated and admin client will be tested. + client *http.Client + }{ + { + name: "Grants on test_db", + dbID: testDBID, + pageSize: 0, + pageNum: 0, + expectedStatus: http.StatusOK, + expectedTotal: 8, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "public", Privilege: "CONNECT"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "user1", Privilege: "CONNECT"}, + {Grantee: "user1", Privilege: "CREATE"}, + {Grantee: "user2", Privilege: "BACKUP"}, + {Grantee: "user2", Privilege: "ZONECONFIG"}, + {Grantee: "user3", Privilege: "ALL"}, + }, + }, + { + name: "Grants on test_db desc order by grantee", + dbID: testDBID, + pageSize: 0, + pageNum: 0, + sortBy: "grantee", + sortOrder: "desc", + expectedStatus: http.StatusOK, + expectedTotal: 8, + expected: []grantRecord{ + {Grantee: "user3", Privilege: "ALL"}, + {Grantee: "user2", Privilege: "BACKUP"}, + {Grantee: "user2", Privilege: "ZONECONFIG"}, + {Grantee: "user1", Privilege: "CONNECT"}, + {Grantee: "user1", Privilege: "CREATE"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "public", Privilege: "CONNECT"}, + {Grantee: "admin", Privilege: "ALL"}, + }, + }, + { + name: "Grants on test_db asc order by privilege", + dbID: testDBID, + pageSize: 0, + pageNum: 0, + sortBy: "privilege", + expectedStatus: http.StatusOK, + expectedTotal: 8, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "user3", Privilege: "ALL"}, + {Grantee: "user2", Privilege: "BACKUP"}, + {Grantee: "public", Privilege: "CONNECT"}, + {Grantee: "user1", Privilege: "CONNECT"}, + {Grantee: "user1", Privilege: "CREATE"}, + {Grantee: "user2", Privilege: "ZONECONFIG"}, + }, + }, + { + name: "Grants on test_db limit 5", + dbID: testDBID, + pageSize: 5, + pageNum: 0, + expectedStatus: http.StatusOK, + expectedTotal: 8, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "public", Privilege: "CONNECT"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "user1", Privilege: "CONNECT"}, + {Grantee: "user1", Privilege: "CREATE"}, + }, + }, + { + name: "No access to no_access_db (no_access_user)", + dbID: noAccessDBID, + pageSize: 10, + pageNum: 0, + expectedStatus: http.StatusOK, + expectedTotal: 0, + client: &authenticatedClient, + expected: []grantRecord{}, + }, + { + name: "Admin access to no_access_db", + dbID: noAccessDBID, + pageSize: 10, + pageNum: 0, + expectedStatus: http.StatusOK, + client: &adminClient, + expectedTotal: 2, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "root", Privilege: "ALL"}, + }, + }, + { + name: "Page size and number combined on test_db", + dbID: testDBID, + pageSize: 2, + pageNum: 3, + expectedStatus: http.StatusOK, + expectedTotal: 8, + expected: []grantRecord{ + {Grantee: "user1", Privilege: "CREATE"}, + {Grantee: "user2", Privilege: "BACKUP"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var queryParams string + if tc.pageSize > 0 { + queryParams += fmt.Sprintf("%s=%d&", pageSizeKey, tc.pageSize) + } + if tc.pageNum > 0 { + queryParams += fmt.Sprintf("%s=%d", pageNumKey, tc.pageNum) + } + if tc.sortBy != "" { + queryParams += fmt.Sprintf("%s=%s&", sortByKey, tc.sortBy) + } + if tc.sortOrder != "" { + queryParams += fmt.Sprintf("%s=%s", sortOrderKey, tc.sortOrder) + } + url := fmt.Sprintf("%s/api/v2/grants/databases/%d/?%s", s.AdminURL(), tc.dbID, queryParams) + req, err := http.NewRequest("GET", url, nil) + require.NoError(t, err) + + clients := []*http.Client{&adminClient, &authenticatedClient} + if tc.client != nil { + clients = []*http.Client{tc.client} + } + + for _, client := range clients { + resp, err := client.Do(req) + require.NoError(t, err) + + require.Equal(t, tc.expectedStatus, resp.StatusCode) + + if tc.expectedStatus == http.StatusOK { + var apiResp databaseGrantsResponseWithPagination + err = json.NewDecoder(resp.Body).Decode(&apiResp) + require.NoError(t, err) + + require.Equal(t, tc.pageSize, apiResp.PaginationInfo.PageSize) + require.Equal(t, tc.pageNum, apiResp.PaginationInfo.PageNum) + require.Equal(t, tc.expectedTotal, int(apiResp.PaginationInfo.TotalResults)) + require.Equal(t, tc.pageSize, apiResp.PaginationInfo.PageSize) + require.Equal(t, tc.pageNum, apiResp.PaginationInfo.PageNum) + require.Len(t, apiResp.Results, len(tc.expected)) + + for i, grant := range apiResp.Results { + require.Equal(t, tc.expected[i].Grantee, grant.Grantee) + require.Equal(t, tc.expected[i].Privilege, grant.Privilege) + } + } + require.NoError(t, resp.Body.Close()) + } + }) + } + + // Test for non-existent database. + t.Run("Non-existent database", func(t *testing.T) { + urlBase := fmt.Sprintf("%s/api/v2/grants/databases/", s.AdminURL()) + urls := []string{ + "not-an-int/", + "99999/?pageSize=10&pageNum=0", + "-1", + } + for _, url := range urls { + req, err := http.NewRequest("GET", urlBase+url, nil) + require.NoError(t, err) + resp, err := adminClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + } + }) + + t.Run("Retrieving grants for databases with special names", func(t *testing.T) { + databases := []string{ + "mixedCaseDatabase", "database with spaces", "database-with-dashes", "databases.with spe. cial characters", + } + for _, db := range databases { + conn.Exec(t, fmt.Sprintf(`CREATE DATABASE "%s"`, db)) + dbID := 0 + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE name = $1`, db).Scan(&dbID) + + url := fmt.Sprintf("%s/api/v2/grants/databases/%d/?limit=10&offset=0", s.AdminURL(), dbID) + req, err := http.NewRequest("GET", url, nil) + require.NoError(t, err) + resp, err := adminClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + } + }) + + t.Run("400 bad request", func(t *testing.T) { + urlBase := fmt.Sprintf("%s/api/v2/grants/databases/", s.AdminURL()) + urls := []string{ + "23/?pageSize=fe", + "23/?pageNum=fe", + "23/?sortBy=grantee&sortOrder=ascending", + } + for _, url := range urls { + req, err := http.NewRequest("GET", urlBase+url, nil) + require.NoError(t, err) + resp, err := adminClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + } + }) +} + +func TestGetTableGrants(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Setup server. + ctx := context.Background() + s := serverutils.StartServerOnly(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + conn := sqlutils.MakeSQLRunner(s.ApplicationLayer().SQLConn(t)) + // Create test tables. + conn.Exec(t, `CREATE DATABASE test_db`) + conn.Exec(t, `CREATE DATABASE no_access_db`) + conn.Exec(t, `CREATE TABLE test_db.test_table (id INT PRIMARY KEY, name STRING)`) + conn.Exec(t, `CREATE TABLE test_db.no_access_table (id INT PRIMARY KEY, name STRING)`) + conn.Exec(t, `CREATE TABLE no_access_db.no_access_table (id INT PRIMARY KEY, name STRING)`) + + // Get the table IDs. + var testTableID, testDbNoAccessTableID, noAccessDbTestTableID int + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE name = 'test_table' AND "parentID" = (SELECT id FROM system.namespace WHERE name = 'test_db')`).Scan(&testTableID) + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE name = 'no_access_table' AND "parentID" = (SELECT id FROM system.namespace WHERE name = 'test_db')`).Scan(&testDbNoAccessTableID) + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE name = 'no_access_table' AND "parentID" = (SELECT id FROM system.namespace WHERE name = 'no_access_db')`).Scan(&noAccessDbTestTableID) + + // Create test users. + users := []string{"user1", "user2", "user3"} + for _, user := range users { + conn.Exec(t, fmt.Sprintf("CREATE USER %s", user)) + } + + // Setup clients. + adminClient, err := s.GetAdminHTTPClient() + require.NoError(t, err) + + authenticatedClient, _, err := s.GetAuthenticatedHTTPClientAndCookie(username.MakeSQLUsernameFromPreNormalizedString("no_access_user"), false, 1) + require.NoError(t, err) + + // Grant different privileges to users on the test_table. + conn.Exec(t, `GRANT SELECT, INSERT ON TABLE test_db.test_table TO user1`) + conn.Exec(t, `GRANT UPDATE, DELETE ON TABLE test_db.test_table TO user2`) + conn.Exec(t, `GRANT ALL ON TABLE test_db.test_table TO user3`) + conn.Exec(t, `REVOKE CONNECT ON DATABASE no_access_db FROM public`) + + // Define test cases. + testCases := []struct { + name string + tableID int + pageSize int + pageNum int + sortBy string + sortOrder string + expectedStatus int + expectedTotal int + expected []grantRecord + client *http.Client + }{ + { + name: "Grants on test_table", + tableID: testTableID, + pageSize: 0, + pageNum: 0, + expectedStatus: http.StatusOK, + expectedTotal: 7, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "user1", Privilege: "INSERT"}, + {Grantee: "user1", Privilege: "SELECT"}, + {Grantee: "user2", Privilege: "DELETE"}, + {Grantee: "user2", Privilege: "UPDATE"}, + {Grantee: "user3", Privilege: "ALL"}, + }, + }, + { + name: "Grants on test_table desc order by grantee", + tableID: testTableID, + pageSize: 0, + pageNum: 0, + sortBy: "grantee", + sortOrder: "desc", + expectedStatus: http.StatusOK, + expectedTotal: 7, + expected: []grantRecord{ + {Grantee: "user3", Privilege: "ALL"}, + {Grantee: "user2", Privilege: "DELETE"}, + {Grantee: "user2", Privilege: "UPDATE"}, + {Grantee: "user1", Privilege: "INSERT"}, + {Grantee: "user1", Privilege: "SELECT"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "admin", Privilege: "ALL"}, + }, + }, + { + name: "Grants on test_table asc order by privilege", + tableID: testTableID, + pageSize: 0, + pageNum: 0, + sortBy: "privilege", + expectedStatus: http.StatusOK, + expectedTotal: 7, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "user3", Privilege: "ALL"}, + {Grantee: "user2", Privilege: "DELETE"}, + {Grantee: "user1", Privilege: "INSERT"}, + {Grantee: "user1", Privilege: "SELECT"}, + {Grantee: "user2", Privilege: "UPDATE"}, + }, + }, + { + name: "Grants on test_table with limit 5", + tableID: testTableID, + pageSize: 5, + pageNum: 0, + expectedStatus: http.StatusOK, + expectedTotal: 7, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "root", Privilege: "ALL"}, + {Grantee: "user1", Privilege: "INSERT"}, + {Grantee: "user1", Privilege: "SELECT"}, + {Grantee: "user2", Privilege: "DELETE"}, + }, + }, + { + name: "Grants on test_db.no_access_table should still be visible", + tableID: testDbNoAccessTableID, + pageSize: 10, + pageNum: 0, + expectedStatus: http.StatusOK, + expectedTotal: 2, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "root", Privilege: "ALL"}, + }, + }, + { + name: "Grants on no_access_db.no_access_table should not be visible", + tableID: noAccessDbTestTableID, + pageSize: 10, + pageNum: 0, + expectedStatus: http.StatusOK, + expectedTotal: 0, + client: &authenticatedClient, + expected: []grantRecord{}, + }, + { + name: "Admin access to no_access_db", + tableID: noAccessDbTestTableID, + pageSize: 10, + pageNum: 0, + expectedStatus: http.StatusOK, + client: &adminClient, + expectedTotal: 2, + expected: []grantRecord{ + {Grantee: "admin", Privilege: "ALL"}, + {Grantee: "root", Privilege: "ALL"}, + }, + }, + { + name: "Page size and number combined on test_table", + tableID: testTableID, + pageSize: 2, + pageNum: 2, + expectedTotal: 7, + expectedStatus: http.StatusOK, + expected: []grantRecord{ + {Grantee: "user1", Privilege: "INSERT"}, + {Grantee: "user1", Privilege: "SELECT"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var queryParams string + if tc.pageSize > 0 { + queryParams += fmt.Sprintf("%s=%d&", pageSizeKey, tc.pageSize) + } + if tc.pageNum > 0 { + queryParams += fmt.Sprintf("%s=%d&", pageNumKey, tc.pageNum) + } + if tc.sortBy != "" { + queryParams += fmt.Sprintf("%s=%s&", sortByKey, tc.sortBy) + } + if tc.sortOrder != "" { + queryParams += fmt.Sprintf("%s=%s", sortOrderKey, tc.sortOrder) + } + url := fmt.Sprintf("%s/api/v2/grants/tables/%d/?%s", s.AdminURL(), tc.tableID, queryParams) + req, err := http.NewRequest("GET", url, nil) + require.NoError(t, err) + + clients := []*http.Client{&adminClient, &authenticatedClient} + if tc.client != nil { + clients = []*http.Client{tc.client} + } + + for _, client := range clients { + resp, err := client.Do(req) + require.NoError(t, err) + + require.Equal(t, tc.expectedStatus, resp.StatusCode) + var apiResp tableGrantsResponseWithPagination + err = json.NewDecoder(resp.Body).Decode(&apiResp) + require.NoError(t, err) + + if tc.expectedStatus != http.StatusOK { + require.Empty(t, apiResp.TableName) + continue + } + + require.Equal(t, tc.pageSize, apiResp.PaginationInfo.PageSize) + require.Equal(t, tc.pageNum, apiResp.PaginationInfo.PageNum) + require.Equal(t, tc.expectedTotal, int(apiResp.PaginationInfo.TotalResults)) + require.Equal(t, tc.pageSize, apiResp.PaginationInfo.PageSize) + require.Equal(t, tc.pageNum, apiResp.PaginationInfo.PageNum) + require.Len(t, apiResp.Results, len(tc.expected)) + + for i, grant := range apiResp.Results { + require.Equal(t, tc.expected[i].Grantee, grant.Grantee) + require.Equal(t, tc.expected[i].Privilege, grant.Privilege) + } + require.NoError(t, resp.Body.Close()) + } + }) + } + + // Test for non-existent table. + t.Run("Non-existent table", func(t *testing.T) { + urlBase := fmt.Sprintf("%s/api/v2/grants/tables/", s.AdminURL()) + urls := []string{ + "not-an-int/", + "99999/?pageSize=10&pageNum=0", + "-1", + } + for _, url := range urls { + req, err := http.NewRequest("GET", urlBase+url, nil) + require.NoError(t, err) + resp, err := adminClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + } + }) + + t.Run("Retrieving grants for tables with special names", func(t *testing.T) { + tables := []string{ + "mixedCaseTable", "tables with spaces", "table-with-dashes", "tables.with spe. cial characters", + } + // Create a special schema too. + conn.Exec(t, `CREATE database "my.db name"`) + conn.Exec(t, `CREATE SCHEMA "my.db name"."my special. schema"`) + for _, table := range tables { + conn.Exec(t, fmt.Sprintf(`CREATE TABLE "my.db name"."my special. schema"."%s" (id INT PRIMARY KEY)`, table)) + tableID := 0 + conn.QueryRow(t, `SELECT id FROM system.namespace WHERE name = $1 AND "parentID" = (SELECT id FROM system.namespace WHERE name = 'my.db name')`, table).Scan(&tableID) + + url := fmt.Sprintf("%s/api/v2/grants/tables/%d/?pageSize=10&pageNum=0", s.AdminURL(), tableID) + req, err := http.NewRequest("GET", url, nil) + require.NoError(t, err) + resp, err := adminClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + apiResp := tableGrantsResponseWithPagination{} + err = json.NewDecoder(resp.Body).Decode(&apiResp) + require.NoError(t, err) + require.Equal(t, table, apiResp.TableName) + require.NoError(t, resp.Body.Close()) + } + }) + + t.Run("400 bad request", func(t *testing.T) { + urlBase := fmt.Sprintf("%s/api/v2/grants/tables/", s.AdminURL()) + urls := []string{ + "23/?pageSize=fe", + "23/?pageNum=fe", + "23/?sortBy=grantee&sortOrder=ascending", + } + for _, url := range urls { + req, err := http.NewRequest("GET", urlBase+url, nil) + require.NoError(t, err) + resp, err := adminClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + } + }) +}