diff --git a/pkg/cloud/services/network/routetables.go b/pkg/cloud/services/network/routetables.go index 638a87fafb..736a047e5c 100644 --- a/pkg/cloud/services/network/routetables.go +++ b/pkg/cloud/services/network/routetables.go @@ -60,30 +60,10 @@ func (s *Service) reconcileRouteTables() error { for i := range subnets { sn := &subnets[i] // We need to compile the minimum routes for this subnet first, so we can compare it or create them. - var routes []*ec2.Route - if sn.IsPublic { - if s.scope.VPC().InternetGatewayID == nil { - return errors.Errorf("failed to create routing tables: internet gateway for %q is nil", s.scope.VPC().ID) - } - routes = append(routes, s.getGatewayPublicRoute()) - if sn.IsIPv6 { - routes = append(routes, s.getGatewayPublicIPv6Route()) - } - } else { - natGatewayID, err := s.getNatGatewayForSubnet(sn) - if err != nil { - return err - } - routes = append(routes, s.getNatGatewayPrivateRoute(natGatewayID)) - if sn.IsIPv6 { - if !s.scope.VPC().IsIPv6Enabled() { - // Safety net because EgressOnlyInternetGateway needs the ID from the ipv6 block. - // if, for whatever reason by this point that is not available, we don't want to - // panic because of a nil pointer access. This should never occur. Famous last words though. - return errors.Errorf("ipv6 block missing for ipv6 enabled subnet, can't create egress only internet gateway") - } - routes = append(routes, s.getEgressOnlyInternetGateway()) - } + routes, err := s.getRoutesForSubnet(sn) + if err != nil { + record.Warnf(s.scope.InfraCluster(), "FailedRouteTableRoutes", "Failed to get routes for managed RouteTable for subnet %s: %v", sn.ID, err) + return errors.Wrapf(err, "failed to discover routes on route table %s", sn.ID) } if rt, ok := subnetRouteMap[sn.GetResourceID()]; ok { @@ -145,7 +125,7 @@ func (s *Service) reconcileRouteTables() error { return nil } -func (s *Service) fixMismatchedRouting(specRoute *ec2.Route, currentRoute *ec2.Route, rt *ec2.RouteTable) error { +func (s *Service) fixMismatchedRouting(specRoute *ec2.CreateRouteInput, currentRoute *ec2.Route, rt *ec2.RouteTable) error { var input *ec2.ReplaceRouteInput if specRoute.DestinationCidrBlock != nil { if (currentRoute.DestinationCidrBlock != nil && @@ -280,7 +260,7 @@ func (s *Service) describeVpcRouteTables() ([]*ec2.RouteTable, error) { return out.RouteTables, nil } -func (s *Service) createRouteTableWithRoutes(routes []*ec2.Route, isPublic bool, zone string) (*infrav1.RouteTable, error) { +func (s *Service) createRouteTableWithRoutes(routes []*ec2.CreateRouteInput, isPublic bool, zone string) (*infrav1.RouteTable, error) { out, err := s.EC2Client.CreateRouteTableWithContext(context.TODO(), &ec2.CreateRouteTableInput{ VpcId: aws.String(s.scope.VPC().ID), TagSpecifications: []*ec2.TagSpecification{ @@ -296,17 +276,8 @@ func (s *Service) createRouteTableWithRoutes(routes []*ec2.Route, isPublic bool, for i := range routes { route := routes[i] if err := wait.WaitForWithRetryable(wait.NewBackoff(), func() (bool, error) { - if _, err := s.EC2Client.CreateRouteWithContext(context.TODO(), &ec2.CreateRouteInput{ - RouteTableId: out.RouteTable.RouteTableId, - DestinationCidrBlock: route.DestinationCidrBlock, - DestinationIpv6CidrBlock: route.DestinationIpv6CidrBlock, - EgressOnlyInternetGatewayId: route.EgressOnlyInternetGatewayId, - GatewayId: route.GatewayId, - InstanceId: route.InstanceId, - NatGatewayId: route.NatGatewayId, - NetworkInterfaceId: route.NetworkInterfaceId, - VpcPeeringConnectionId: route.VpcPeeringConnectionId, - }); err != nil { + route.RouteTableId = out.RouteTable.RouteTableId + if _, err := s.EC2Client.CreateRouteWithContext(context.TODO(), route); err != nil { return false, err } return true, nil @@ -341,29 +312,29 @@ func (s *Service) associateRouteTable(rt *infrav1.RouteTable, subnetID string) e return nil } -func (s *Service) getNatGatewayPrivateRoute(natGatewayID string) *ec2.Route { - return &ec2.Route{ +func (s *Service) getNatGatewayPrivateRoute(natGatewayID string) *ec2.CreateRouteInput { + return &ec2.CreateRouteInput{ NatGatewayId: aws.String(natGatewayID), DestinationCidrBlock: aws.String(services.AnyIPv4CidrBlock), } } -func (s *Service) getEgressOnlyInternetGateway() *ec2.Route { - return &ec2.Route{ +func (s *Service) getEgressOnlyInternetGateway() *ec2.CreateRouteInput { + return &ec2.CreateRouteInput{ DestinationIpv6CidrBlock: aws.String(services.AnyIPv6CidrBlock), EgressOnlyInternetGatewayId: s.scope.VPC().IPv6.EgressOnlyInternetGatewayID, } } -func (s *Service) getGatewayPublicRoute() *ec2.Route { - return &ec2.Route{ +func (s *Service) getGatewayPublicRoute() *ec2.CreateRouteInput { + return &ec2.CreateRouteInput{ DestinationCidrBlock: aws.String(services.AnyIPv4CidrBlock), GatewayId: aws.String(*s.scope.VPC().InternetGatewayID), } } -func (s *Service) getGatewayPublicIPv6Route() *ec2.Route { - return &ec2.Route{ +func (s *Service) getGatewayPublicIPv6Route() *ec2.CreateRouteInput { + return &ec2.CreateRouteInput{ DestinationIpv6CidrBlock: aws.String(services.AnyIPv6CidrBlock), GatewayId: aws.String(*s.scope.VPC().InternetGatewayID), } @@ -394,3 +365,45 @@ func (s *Service) getRouteTableTagParams(id string, public bool, zone string) in Additional: additionalTags, } } + +func (s *Service) getRoutesToPublicSubnet(sn *infrav1.SubnetSpec) ([]*ec2.CreateRouteInput, error) { + var routes []*ec2.CreateRouteInput + + if s.scope.VPC().InternetGatewayID == nil { + return routes, errors.Errorf("failed to create routing tables: internet gateway for %q is nil", s.scope.VPC().ID) + } + + routes = append(routes, s.getGatewayPublicRoute()) + if sn.IsIPv6 { + routes = append(routes, s.getGatewayPublicIPv6Route()) + } + + return routes, nil +} + +func (s *Service) getRoutesToPrivateSubnet(sn *infrav1.SubnetSpec) (routes []*ec2.CreateRouteInput, err error) { + natGatewayID, err := s.getNatGatewayForSubnet(sn) + if err != nil { + return routes, err + } + + routes = append(routes, s.getNatGatewayPrivateRoute(natGatewayID)) + if sn.IsIPv6 { + if !s.scope.VPC().IsIPv6Enabled() { + // Safety net because EgressOnlyInternetGateway needs the ID from the ipv6 block. + // if, for whatever reason by this point that is not available, we don't want to + // panic because of a nil pointer access. This should never occur. Famous last words though. + return routes, errors.Errorf("ipv6 block missing for ipv6 enabled subnet, can't create route for egress only internet gateway") + } + routes = append(routes, s.getEgressOnlyInternetGateway()) + } + + return routes, nil +} + +func (s *Service) getRoutesForSubnet(sn *infrav1.SubnetSpec) ([]*ec2.CreateRouteInput, error) { + if sn.IsPublic { + return s.getRoutesToPublicSubnet(sn) + } + return s.getRoutesToPrivateSubnet(sn) +} diff --git a/pkg/cloud/services/network/routetables_test.go b/pkg/cloud/services/network/routetables_test.go index 0f38cbacf0..83d5a886b6 100644 --- a/pkg/cloud/services/network/routetables_test.go +++ b/pkg/cloud/services/network/routetables_test.go @@ -25,10 +25,12 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" . "github.com/onsi/gomega" "github.com/pkg/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client/fake" infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2" @@ -873,3 +875,244 @@ func (r routeTableInputMatcher) String() string { func matchRouteTableInput(input *ec2.CreateRouteTableInput) gomock.Matcher { return routeTableInputMatcher{routeTableInput: input} } + +func TestService_getRoutesForSubnet(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + defaultSubnets := infrav1.Subnets{ + { + ResourceID: "subnet-az-2z-private", + AvailabilityZone: "us-east-2z", + IsPublic: false, + }, + { + ResourceID: "subnet-az-2z-public", + AvailabilityZone: "us-east-2z", + IsPublic: true, + NatGatewayID: ptr.To("nat-gw-fromZone-us-east-2z"), + }, + { + ResourceID: "subnet-az-1a-private", + AvailabilityZone: "us-east-1a", + IsPublic: false, + }, + { + ResourceID: "subnet-az-1a-public", + AvailabilityZone: "us-east-1a", + IsPublic: true, + NatGatewayID: ptr.To("nat-gw-fromZone-us-east-1a"), + }, + } + + vpcName := "vpc-test-for-routes" + defaultNetwork := infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: vpcName, + InternetGatewayID: aws.String("vpc-igw"), + IPv6: &infrav1.IPv6{ + CidrBlock: "2001:db8:1234:1::/64", + EgressOnlyInternetGatewayID: aws.String("vpc-eigw"), + }, + }, + Subnets: defaultSubnets, + } + + tests := []struct { + name string + specOverrideNet *infrav1.NetworkSpec + specOverrideSubnets *infrav1.Subnets + inputSubnet *infrav1.SubnetSpec + want []*ec2.CreateRouteInput + wantErr bool + wantErrMessage string + }{ + { + name: "empty subnet should have empty routes", + specOverrideSubnets: &infrav1.Subnets{}, + inputSubnet: &infrav1.SubnetSpec{ + ID: "subnet-1-private", + }, + want: []*ec2.CreateRouteInput{}, + wantErrMessage: `no nat gateways available in "" for private subnet "subnet-1-private", current state: map[]`, + }, + { + name: "empty subnet should have empty routes", + inputSubnet: &infrav1.SubnetSpec{}, + want: []*ec2.CreateRouteInput{}, + wantErrMessage: `no nat gateways available in "" for private subnet "", current state: map[us-east-1a:[nat-gw-fromZone-us-east-1a] us-east-2z:[nat-gw-fromZone-us-east-2z]]`, + }, + // public subnets ipv4 + { + name: "public ipv4 subnet, availability zone, must have ipv4 default route to igw", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-az-1a-public", + AvailabilityZone: "us-east-1a", + IsIPv6: false, + IsPublic: true, + }, + want: []*ec2.CreateRouteInput{ + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + GatewayId: aws.String("vpc-igw"), + }, + }, + }, + { + name: "public ipv6 subnet, availability zone, must have ipv6 default route to igw", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-az-1a-public", + AvailabilityZone: "us-east-1a", + IsPublic: true, + IsIPv6: true, + }, + want: []*ec2.CreateRouteInput{ + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + GatewayId: aws.String("vpc-igw"), + }, + { + DestinationIpv6CidrBlock: aws.String("::/0"), + GatewayId: aws.String("vpc-igw"), + }, + }, + }, + // public subnet ipv4, GW not found. + { + name: "public ipv4 subnet, availability zone, must return error when no internet gateway available", + specOverrideNet: func() *infrav1.NetworkSpec { + net := defaultNetwork.DeepCopy() + net.VPC.InternetGatewayID = nil + return net + }(), + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-az-1a-public", + AvailabilityZone: "us-east-1a", + IsPublic: true, + }, + wantErrMessage: `failed to create routing tables: internet gateway for "vpc-test-for-routes" is nil`, + }, + // private subnets + { + name: "private ipv4 subnet, availability zone, must have ipv4 default route to nat gateway", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-az-1a-private", + AvailabilityZone: "us-east-1a", + IsPublic: false, + }, + want: []*ec2.CreateRouteInput{ + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + NatGatewayId: aws.String("nat-gw-fromZone-us-east-1a"), + }, + }, + }, + // egress-only subnet ipv6 + { + name: "egress-only ipv6 subnet, availability zone, must have ipv6 default route to egress-only gateway", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-az-1a-private", + AvailabilityZone: "us-east-1a", + IsIPv6: true, + IsPublic: false, + }, + want: []*ec2.CreateRouteInput{ + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + NatGatewayId: aws.String("nat-gw-fromZone-us-east-1a"), + }, + { + DestinationIpv6CidrBlock: aws.String("::/0"), + EgressOnlyInternetGatewayId: aws.String("vpc-eigw"), + }, + }, + }, + { + name: "private ipv6 subnet, availability zone, non-ipv6 block, must return error", + specOverrideNet: func() *infrav1.NetworkSpec { + net := defaultNetwork.DeepCopy() + net.VPC.IPv6 = nil + return net + }(), + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-az-1a-private", + AvailabilityZone: "us-east-1a", + IsIPv6: true, + IsPublic: false, + }, + wantErrMessage: `ipv6 block missing for ipv6 enabled subnet, can't create route for egress only internet gateway`, + }, + // private subnet, gateway not found + { + name: "private ipv4 subnet, availability zone, must return error when invalid gateway", + specOverrideNet: func() *infrav1.NetworkSpec { + net := defaultNetwork.DeepCopy() + for i := range net.Subnets { + if net.Subnets[i].AvailabilityZone == "us-east-1a" && net.Subnets[i].IsPublic { + net.Subnets[i].NatGatewayID = nil + } + } + return net + }(), + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-az-1a-private", + AvailabilityZone: "us-east-1a", + IsPublic: false, + }, + wantErrMessage: `no nat gateways available in "us-east-1a" for private subnet "subnet-az-1a-private", current state: map[us-east-2z:[nat-gw-fromZone-us-east-2z]]`, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = infrav1.AddToScheme(scheme) + client := fake.NewClientBuilder().WithScheme(scheme).Build() + cluster := scope.ClusterScopeParams{ + Client: client, + Cluster: &clusterv1.Cluster{ + ObjectMeta: metav1.ObjectMeta{Name: "test-cluster-routes"}, + }, + AWSCluster: &infrav1.AWSCluster{ + ObjectMeta: metav1.ObjectMeta{Name: "test"}, + Spec: infrav1.AWSClusterSpec{}, + }, + } + cluster.AWSCluster.Spec.NetworkSpec = defaultNetwork + if tc.specOverrideNet != nil { + cluster.AWSCluster.Spec.NetworkSpec = *tc.specOverrideNet + } + if tc.specOverrideSubnets != nil { + cluster.AWSCluster.Spec.NetworkSpec.Subnets = *tc.specOverrideSubnets + } + + scope, err := scope.NewClusterScope(cluster) + if err != nil { + t.Errorf("Service.getRoutesForSubnet() error setting up the test case: %v", err) + } + + s := NewService(scope) + got, err := s.getRoutesForSubnet(tc.inputSubnet) + + wantErr := tc.wantErr + if len(tc.wantErrMessage) > 0 { + wantErr = true + } + if wantErr && err == nil { + t.Fatal("expected error but got no error") + } + if err != nil { + if !wantErr { + t.Fatalf("got an unexpected error: %v", err) + } + if wantErr && len(tc.wantErrMessage) > 0 && err.Error() != tc.wantErrMessage { + t.Fatalf("got an unexpected error message:\nwant: %v\n got: %v\n", tc.wantErrMessage, err) + } + } + if len(tc.want) > 0 { + if !cmp.Equal(got, tc.want) { + t.Errorf("got unexpect routes:\n%v", cmp.Diff(got, tc.want)) + } + } + }) + } +}