From 475753f3266cf0056d4fb23c8ce2d885cd511942 Mon Sep 17 00:00:00 2001 From: Marco Dinis Date: Wed, 18 Dec 2024 09:16:59 +0000 Subject: [PATCH] AWS OIDC: List Deployed Database Services - implementation (#49331) * AWS OIDC: List Deployed Database Services - implementation This PR implements the List Deployed Database Services. This will be used to let the user know which deployed database services were deployed during the Enroll New Resource / RDS flows. * validate region for dashboard url --- lib/auth/integration/integrationv1/awsoidc.go | 52 +++ .../integration/integrationv1/awsoidc_test.go | 10 + lib/integrations/awsoidc/deployservice.go | 15 +- .../awsoidc/listdeployeddatabaseservice.go | 194 ++++++++++ .../listdeployeddatabaseservice_test.go | 360 ++++++++++++++++++ 5 files changed, 628 insertions(+), 3 deletions(-) create mode 100644 lib/integrations/awsoidc/listdeployeddatabaseservice.go create mode 100644 lib/integrations/awsoidc/listdeployeddatabaseservice_test.go diff --git a/lib/auth/integration/integrationv1/awsoidc.go b/lib/auth/integration/integrationv1/awsoidc.go index dfb1b154f5934..bcdff34276968 100644 --- a/lib/auth/integration/integrationv1/awsoidc.go +++ b/lib/auth/integration/integrationv1/awsoidc.go @@ -495,6 +495,58 @@ func (s *AWSOIDCService) DeployDatabaseService(ctx context.Context, req *integra }, nil } +// ListDeployedDatabaseServices lists Database Services deployed into Amazon ECS. +func (s *AWSOIDCService) ListDeployedDatabaseServices(ctx context.Context, req *integrationpb.ListDeployedDatabaseServicesRequest) (*integrationpb.ListDeployedDatabaseServicesResponse, error) { + authCtx, err := s.authorizer.Authorize(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := authCtx.CheckAccessToKind(types.KindIntegration, types.VerbUse); err != nil { + return nil, trace.Wrap(err) + } + + clusterName, err := s.cache.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + + awsClientReq, err := s.awsClientReq(ctx, req.Integration, req.Region) + if err != nil { + return nil, trace.Wrap(err) + } + + listDatabaseServicesClient, err := awsoidc.NewListDeployedDatabaseServicesClient(ctx, awsClientReq) + if err != nil { + return nil, trace.Wrap(err) + } + + listDatabaseServicesResponse, err := awsoidc.ListDeployedDatabaseServices(ctx, listDatabaseServicesClient, awsoidc.ListDeployedDatabaseServicesRequest{ + Integration: req.Integration, + TeleportClusterName: clusterName.GetClusterName(), + Region: req.Region, + NextToken: req.NextToken, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + deployedDatabaseServices := make([]*integrationpb.DeployedDatabaseService, 0, len(listDatabaseServicesResponse.DeployedDatabaseServices)) + for _, deployedService := range listDatabaseServicesResponse.DeployedDatabaseServices { + deployedDatabaseServices = append(deployedDatabaseServices, &integrationpb.DeployedDatabaseService{ + Name: deployedService.Name, + ServiceDashboardUrl: deployedService.ServiceDashboardURL, + ContainerEntryPoint: deployedService.ContainerEntryPoint, + ContainerCommand: deployedService.ContainerCommand, + }) + } + + return &integrationpb.ListDeployedDatabaseServicesResponse{ + DeployedDatabaseServices: deployedDatabaseServices, + NextToken: listDatabaseServicesResponse.NextToken, + }, nil +} + // EnrollEKSClusters enrolls EKS clusters into Teleport by installing teleport-kube-agent chart on the clusters. func (s *AWSOIDCService) EnrollEKSClusters(ctx context.Context, req *integrationpb.EnrollEKSClustersRequest) (*integrationpb.EnrollEKSClustersResponse, error) { authCtx, err := s.authorizer.Authorize(ctx) diff --git a/lib/auth/integration/integrationv1/awsoidc_test.go b/lib/auth/integration/integrationv1/awsoidc_test.go index f6cd0e925a48f..6a2497229ab38 100644 --- a/lib/auth/integration/integrationv1/awsoidc_test.go +++ b/lib/auth/integration/integrationv1/awsoidc_test.go @@ -423,6 +423,16 @@ func TestRBAC(t *testing.T) { return err }, }, + { + name: "ListDeployedDatabaseServices", + fn: func() error { + _, err := awsoidService.ListDeployedDatabaseServices(userCtx, &integrationv1.ListDeployedDatabaseServicesRequest{ + Integration: integrationName, + Region: "my-region", + }) + return err + }, + }, } { t.Run(tt.name, func(t *testing.T) { err := tt.fn() diff --git a/lib/integrations/awsoidc/deployservice.go b/lib/integrations/awsoidc/deployservice.go index b9fbc4b99c458..17bfe3a470954 100644 --- a/lib/integrations/awsoidc/deployservice.go +++ b/lib/integrations/awsoidc/deployservice.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" + apiaws "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" "github.com/gravitational/teleport/lib/utils/teleportassets" @@ -445,16 +446,24 @@ func DeployService(ctx context.Context, clt DeployServiceClient, req DeployServi return nil, trace.Wrap(err) } - serviceDashboardURL := fmt.Sprintf("https://%s.console.aws.amazon.com/ecs/v2/clusters/%s/services/%s", req.Region, aws.ToString(req.ClusterName), aws.ToString(req.ServiceName)) - return &DeployServiceResponse{ ClusterARN: aws.ToString(cluster.ClusterArn), ServiceARN: aws.ToString(service.ServiceArn), TaskDefinitionARN: taskDefinitionARN, - ServiceDashboardURL: serviceDashboardURL, + ServiceDashboardURL: serviceDashboardURL(req.Region, aws.ToString(req.ClusterName), aws.ToString(service.ServiceName)), }, nil } +// serviceDashboardURL builds the ECS Service dashboard URL using the AWS Region, the ECS Cluster and Service Names. +// It returns an empty string if region is not valid. +func serviceDashboardURL(region, clusterName, serviceName string) string { + if err := apiaws.IsValidRegion(region); err != nil { + return "" + } + + return fmt.Sprintf("https://%s.console.aws.amazon.com/ecs/v2/clusters/%s/services/%s", region, clusterName, serviceName) +} + type upsertTaskRequest struct { TaskName string TaskRoleARN string diff --git a/lib/integrations/awsoidc/listdeployeddatabaseservice.go b/lib/integrations/awsoidc/listdeployeddatabaseservice.go new file mode 100644 index 0000000000000..c2894902f78fe --- /dev/null +++ b/lib/integrations/awsoidc/listdeployeddatabaseservice.go @@ -0,0 +1,194 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsoidc + +import ( + "context" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" +) + +// ListDeployedDatabaseServicesRequest contains the required fields to list the deployed database services in Amazon ECS. +type ListDeployedDatabaseServicesRequest struct { + // Region is the AWS Region. + Region string + // Integration is the AWS OIDC Integration name + Integration string + // TeleportClusterName is the name of the Teleport Cluster. + // Used to uniquely identify the ECS Cluster in Amazon. + TeleportClusterName string + // NextToken is the token to be used to fetch the next page. + // If empty, the first page is fetched. + NextToken string +} + +func (req *ListDeployedDatabaseServicesRequest) checkAndSetDefaults() error { + if req.Region == "" { + return trace.BadParameter("region is required") + } + + if req.Integration == "" { + return trace.BadParameter("integration is required") + } + + if req.TeleportClusterName == "" { + return trace.BadParameter("teleport cluster name is required") + } + + return nil +} + +// ListDeployedDatabaseServicesResponse contains a page of Deployed Database Services. +type ListDeployedDatabaseServicesResponse struct { + // DeployedDatabaseServices contains the page of Deployed Database Services. + DeployedDatabaseServices []DeployedDatabaseService `json:"deployedDatabaseServices"` + + // NextToken is used for pagination. + // If non-empty, it can be used to request the next page. + NextToken string `json:"nextToken"` +} + +// DeployedDatabaseService contains a database service that was deployed to Amazon ECS. +type DeployedDatabaseService struct { + // Name is the ECS Service name. + Name string + // ServiceDashboardURL is the Amazon Web Console URL for this ECS Service. + ServiceDashboardURL string + // ContainerEntryPoint is the entry point for the container 0 that is running in the ECS Task. + ContainerEntryPoint []string + // ContainerCommand is the list of arguments that are passed into the ContainerEntryPoint. + ContainerCommand []string +} + +// ListDeployedDatabaseServicesClient describes the required methods to list AWS VPCs. +type ListDeployedDatabaseServicesClient interface { + // ListServices returns a list of services. + ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) + // DescribeServices returns ECS Services details. + DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) + // DescribeTaskDefinition returns an ECS Task Definition. + DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) +} + +type defaultListDeployedDatabaseServicesClient struct { + *ecs.Client +} + +// NewListDeployedDatabaseServicesClient creates a new ListDeployedDatabaseServicesClient using an AWSClientRequest. +func NewListDeployedDatabaseServicesClient(ctx context.Context, req *AWSClientRequest) (ListDeployedDatabaseServicesClient, error) { + ecsClient, err := newECSClient(ctx, req) + if err != nil { + return nil, trace.Wrap(err) + } + + return &defaultListDeployedDatabaseServicesClient{ + Client: ecsClient, + }, nil +} + +// ListDeployedDatabaseServices calls the following AWS API: +// https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_ListServices.html +// https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeServices.html +// https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html +// It returns a list of ECS Services running Teleport Database Service and an optional NextToken that can be used to fetch the next page. +func ListDeployedDatabaseServices(ctx context.Context, clt ListDeployedDatabaseServicesClient, req ListDeployedDatabaseServicesRequest) (*ListDeployedDatabaseServicesResponse, error) { + if err := req.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + clusterName := normalizeECSClusterName(req.TeleportClusterName) + + log := slog.With( + "integration", req.Integration, + "aws_region", req.Region, + "ecs_cluster", clusterName, + ) + + // Do not increase this value because ecs.DescribeServices only allows up to 10 services per API call. + maxServicesPerPage := aws.Int32(10) + listServicesInput := &ecs.ListServicesInput{ + Cluster: &clusterName, + MaxResults: maxServicesPerPage, + LaunchType: ecstypes.LaunchTypeFargate, + } + if req.NextToken != "" { + listServicesInput.NextToken = &req.NextToken + } + + listServicesOutput, err := clt.ListServices(ctx, listServicesInput) + if err != nil { + return nil, trace.Wrap(err) + } + + describeServicesOutput, err := clt.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Services: listServicesOutput.ServiceArns, + Include: []ecstypes.ServiceField{ecstypes.ServiceFieldTags}, + Cluster: &clusterName, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + ownershipTags := tags.DefaultResourceCreationTags(req.TeleportClusterName, req.Integration) + + deployedDatabaseServices := []DeployedDatabaseService{} + for _, ecsService := range describeServicesOutput.Services { + log := log.With("ecs_service", aws.ToString(ecsService.ServiceName)) + if !ownershipTags.MatchesECSTags(ecsService.Tags) { + log.WarnContext(ctx, "Missing ownership tags in ECS Service, skipping") + continue + } + + taskDefinitionOut, err := clt.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{ + TaskDefinition: ecsService.TaskDefinition, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + if len(taskDefinitionOut.TaskDefinition.ContainerDefinitions) == 0 { + log.WarnContext(ctx, "Task has no containers defined, skipping", + "ecs_task_family", aws.ToString(taskDefinitionOut.TaskDefinition.Family), + "ecs_task_revision", taskDefinitionOut.TaskDefinition.Revision, + ) + continue + } + + entryPoint := taskDefinitionOut.TaskDefinition.ContainerDefinitions[0].EntryPoint + command := taskDefinitionOut.TaskDefinition.ContainerDefinitions[0].Command + + deployedDatabaseServices = append(deployedDatabaseServices, DeployedDatabaseService{ + Name: aws.ToString(ecsService.ServiceName), + ServiceDashboardURL: serviceDashboardURL(req.Region, clusterName, aws.ToString(ecsService.ServiceName)), + ContainerEntryPoint: entryPoint, + ContainerCommand: command, + }) + } + + return &ListDeployedDatabaseServicesResponse{ + DeployedDatabaseServices: deployedDatabaseServices, + NextToken: aws.ToString(listServicesOutput.NextToken), + }, nil +} diff --git a/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go b/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go new file mode 100644 index 0000000000000..67f332d495c2b --- /dev/null +++ b/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go @@ -0,0 +1,360 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsoidc + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestListDeployedDatabaseServicesRequest(t *testing.T) { + isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { + require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) + } + + baseReqFn := func() ListDeployedDatabaseServicesRequest { + return ListDeployedDatabaseServicesRequest{ + TeleportClusterName: "mycluster", + Region: "eu-west-2", + Integration: "my-integration", + } + } + + for _, tt := range []struct { + name string + req func() ListDeployedDatabaseServicesRequest + errCheck require.ErrorAssertionFunc + reqWithDefaults ListDeployedDatabaseServicesRequest + }{ + { + name: "no fields", + req: func() ListDeployedDatabaseServicesRequest { + return ListDeployedDatabaseServicesRequest{} + }, + errCheck: isBadParamErrFn, + }, + { + name: "missing teleport cluster name", + req: func() ListDeployedDatabaseServicesRequest { + r := baseReqFn() + r.TeleportClusterName = "" + return r + }, + errCheck: isBadParamErrFn, + }, + { + name: "missing region", + req: func() ListDeployedDatabaseServicesRequest { + r := baseReqFn() + r.Region = "" + return r + }, + errCheck: isBadParamErrFn, + }, + { + name: "missing integration", + req: func() ListDeployedDatabaseServicesRequest { + r := baseReqFn() + r.Integration = "" + return r + }, + errCheck: isBadParamErrFn, + }, + } { + t.Run(tt.name, func(t *testing.T) { + r := tt.req() + err := r.checkAndSetDefaults() + tt.errCheck(t, err) + + if err != nil { + return + } + + require.Empty(t, cmp.Diff(tt.reqWithDefaults, r)) + }) + } +} + +type mockListECSClient struct { + pageSize int + + clusterName string + services []*ecstypes.Service + mapServices map[string]ecstypes.Service + taskDefinition map[string]*ecstypes.TaskDefinition +} + +func (m *mockListECSClient) ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { + ret := &ecs.ListServicesOutput{} + if aws.ToString(params.Cluster) != m.clusterName { + return ret, nil + } + + requestedPage := 1 + + totalEndpoints := len(m.services) + + if params.NextToken != nil { + currentMarker, err := strconv.Atoi(*params.NextToken) + if err != nil { + return nil, trace.Wrap(err) + } + requestedPage = currentMarker + } + + sliceStart := m.pageSize * (requestedPage - 1) + sliceEnd := m.pageSize * requestedPage + if sliceEnd > totalEndpoints { + sliceEnd = totalEndpoints + } + + for _, service := range m.services[sliceStart:sliceEnd] { + ret.ServiceArns = append(ret.ServiceArns, aws.ToString(service.ServiceArn)) + } + + if sliceEnd < totalEndpoints { + nextToken := strconv.Itoa(requestedPage + 1) + ret.NextToken = &nextToken + } + + return ret, nil +} + +func (m *mockListECSClient) DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { + ret := &ecs.DescribeServicesOutput{} + if aws.ToString(params.Cluster) != m.clusterName { + return ret, nil + } + + for _, serviceARN := range params.Services { + ret.Services = append(ret.Services, m.mapServices[serviceARN]) + } + return ret, nil +} + +func (m *mockListECSClient) DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) { + ret := &ecs.DescribeTaskDefinitionOutput{} + ret.TaskDefinition = m.taskDefinition[aws.ToString(params.TaskDefinition)] + + return ret, nil +} + +func dummyServiceTask(idx int) (ecstypes.Service, *ecstypes.TaskDefinition) { + taskName := fmt.Sprintf("task-family-name-%d", idx) + serviceARN := fmt.Sprintf("arn:eks:service-%d", idx) + + ecsTask := &ecstypes.TaskDefinition{ + Family: aws.String(taskName), + ContainerDefinitions: []ecstypes.ContainerDefinition{{ + EntryPoint: []string{"teleport"}, + Command: []string{"start"}, + }}, + } + + ecsService := ecstypes.Service{ + ServiceArn: aws.String(serviceARN), + ServiceName: aws.String(fmt.Sprintf("database-service-vpc-%d", idx)), + TaskDefinition: aws.String(taskName), + Tags: []ecstypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("my-cluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("my-integration")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + }, + } + + return ecsService, ecsTask +} + +func TestListDeployedDatabaseServices(t *testing.T) { + ctx := context.Background() + + const pageSize = 100 + t.Run("pagination", func(t *testing.T) { + totalServices := 203 + + allServices := make([]*ecstypes.Service, 0, totalServices) + mapServices := make(map[string]ecstypes.Service, totalServices) + allTasks := make(map[string]*ecstypes.TaskDefinition, totalServices) + for i := 0; i < totalServices; i++ { + ecsService, ecsTask := dummyServiceTask(i) + allTasks[aws.ToString(ecsTask.Family)] = ecsTask + mapServices[aws.ToString(ecsService.ServiceArn)] = ecsService + allServices = append(allServices, &ecsService) + } + + mockListClient := &mockListECSClient{ + pageSize: pageSize, + clusterName: "my-cluster-teleport", + mapServices: mapServices, + services: allServices, + taskDefinition: allTasks, + } + + // First page must return pageSize number of Endpoints + resp, err := ListDeployedDatabaseServices(ctx, mockListClient, ListDeployedDatabaseServicesRequest{ + Integration: "my-integration", + TeleportClusterName: "my-cluster", + Region: "us-east-1", + }) + require.NoError(t, err) + require.NotEmpty(t, resp.NextToken) + require.Len(t, resp.DeployedDatabaseServices, pageSize) + require.Equal(t, "database-service-vpc-0", resp.DeployedDatabaseServices[0].Name) + require.Equal(t, "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-0", resp.DeployedDatabaseServices[0].ServiceDashboardURL) + require.Equal(t, []string{"teleport"}, resp.DeployedDatabaseServices[0].ContainerEntryPoint) + require.Equal(t, []string{"start"}, resp.DeployedDatabaseServices[0].ContainerCommand) + + // Second page must return pageSize number of Endpoints + nextPageToken := resp.NextToken + resp, err = ListDeployedDatabaseServices(ctx, mockListClient, ListDeployedDatabaseServicesRequest{ + Integration: "my-integration", + TeleportClusterName: "my-cluster", + Region: "us-east-1", + NextToken: nextPageToken, + }) + require.NoError(t, err) + require.NotEmpty(t, resp.NextToken) + require.Len(t, resp.DeployedDatabaseServices, pageSize) + require.Equal(t, "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-100", resp.DeployedDatabaseServices[0].ServiceDashboardURL) + + // Third page must return only the remaining Endpoints and an empty nextToken + nextPageToken = resp.NextToken + resp, err = ListDeployedDatabaseServices(ctx, mockListClient, ListDeployedDatabaseServicesRequest{ + Integration: "my-integration", + TeleportClusterName: "my-cluster", + Region: "us-east-1", + NextToken: nextPageToken, + }) + require.NoError(t, err) + require.Empty(t, resp.NextToken) + require.Len(t, resp.DeployedDatabaseServices, 3) + }) + + for _, tt := range []struct { + name string + req ListDeployedDatabaseServicesRequest + mockClient func() *mockListECSClient + errCheck require.ErrorAssertionFunc + respCheck func(*testing.T, *ListDeployedDatabaseServicesResponse) + }{ + { + name: "ignores ECS Services without ownership tags", + req: ListDeployedDatabaseServicesRequest{ + Integration: "my-integration", + TeleportClusterName: "my-cluster", + Region: "us-east-1", + }, + mockClient: func() *mockListECSClient { + ret := &mockListECSClient{ + pageSize: 10, + clusterName: "my-cluster-teleport", + } + ecsService, ecsTask := dummyServiceTask(0) + + ecsServiceAnotherIntegration, ecsTaskAnotherIntegration := dummyServiceTask(1) + ecsServiceAnotherIntegration.Tags = []ecstypes.Tag{{Key: aws.String("teleport.dev/integration"), Value: aws.String("another-integration")}} + + ret.taskDefinition = map[string]*ecstypes.TaskDefinition{ + aws.ToString(ecsTask.Family): ecsTask, + aws.ToString(ecsTaskAnotherIntegration.Family): ecsTaskAnotherIntegration, + } + ret.mapServices = map[string]ecstypes.Service{ + aws.ToString(ecsService.ServiceArn): ecsService, + aws.ToString(ecsServiceAnotherIntegration.ServiceArn): ecsServiceAnotherIntegration, + } + ret.services = append(ret.services, &ecsService) + ret.services = append(ret.services, &ecsServiceAnotherIntegration) + return ret + }, + respCheck: func(t *testing.T, resp *ListDeployedDatabaseServicesResponse) { + require.Len(t, resp.DeployedDatabaseServices, 1, "expected 1 service, got %d", len(resp.DeployedDatabaseServices)) + require.Empty(t, resp.NextToken, "expected an empty NextToken") + + expectedService := DeployedDatabaseService{ + Name: "database-service-vpc-0", + ServiceDashboardURL: "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-0", + ContainerEntryPoint: []string{"teleport"}, + ContainerCommand: []string{"start"}, + } + require.Empty(t, cmp.Diff(expectedService, resp.DeployedDatabaseServices[0])) + }, + errCheck: require.NoError, + }, + { + name: "ignores ECS Services without containers", + req: ListDeployedDatabaseServicesRequest{ + Integration: "my-integration", + TeleportClusterName: "my-cluster", + Region: "us-east-1", + }, + mockClient: func() *mockListECSClient { + ret := &mockListECSClient{ + pageSize: 10, + clusterName: "my-cluster-teleport", + } + ecsService, ecsTask := dummyServiceTask(0) + + ecsServiceWithoutContainers, ecsTaskWithoutContainers := dummyServiceTask(1) + ecsTaskWithoutContainers.ContainerDefinitions = []ecstypes.ContainerDefinition{} + + ret.taskDefinition = map[string]*ecstypes.TaskDefinition{ + aws.ToString(ecsTask.Family): ecsTask, + aws.ToString(ecsTaskWithoutContainers.Family): ecsTaskWithoutContainers, + } + ret.mapServices = map[string]ecstypes.Service{ + aws.ToString(ecsService.ServiceArn): ecsService, + aws.ToString(ecsServiceWithoutContainers.ServiceArn): ecsServiceWithoutContainers, + } + ret.services = append(ret.services, &ecsService) + ret.services = append(ret.services, &ecsServiceWithoutContainers) + return ret + }, + respCheck: func(t *testing.T, resp *ListDeployedDatabaseServicesResponse) { + require.Len(t, resp.DeployedDatabaseServices, 1, "expected 1 service, got %d", len(resp.DeployedDatabaseServices)) + require.Empty(t, resp.NextToken, "expected an empty NextToken") + + expectedService := DeployedDatabaseService{ + Name: "database-service-vpc-0", + ServiceDashboardURL: "https://us-east-1.console.aws.amazon.com/ecs/v2/clusters/my-cluster-teleport/services/database-service-vpc-0", + ContainerEntryPoint: []string{"teleport"}, + ContainerCommand: []string{"start"}, + } + require.Empty(t, cmp.Diff(expectedService, resp.DeployedDatabaseServices[0])) + }, + errCheck: require.NoError, + }, + } { + t.Run(tt.name, func(t *testing.T) { + resp, err := ListDeployedDatabaseServices(ctx, tt.mockClient(), tt.req) + tt.errCheck(t, err) + if tt.respCheck != nil { + tt.respCheck(t, resp) + } + }) + } +}