diff --git a/internal/certificate/repository.go b/internal/certificate/repository.go index 8cf5b0c..29b8204 100644 --- a/internal/certificate/repository.go +++ b/internal/certificate/repository.go @@ -59,7 +59,7 @@ func (r acmCertRepository) FindByFilter(filter CertFilter) ([]Certificate, error var certs []Certificate var certDiscoveryErr error - err := r.client.ListCertificatesPages(input, func(output *acm.ListCertificatesOutput, _ bool) bool { + err := r.client.ListCertificatesPages(input, func(output *acm.ListCertificatesOutput, lastPage bool) bool { for _, acmCertSummary := range output.CertificateSummaryList { acmCert, err := r.client.DescribeCertificate(&acm.DescribeCertificateInput{ CertificateArn: acmCertSummary.CertificateArn, @@ -79,7 +79,8 @@ func (r acmCertRepository) FindByFilter(filter CertFilter) ([]Certificate, error certs = append(certs, dnCert) } } - return true + + return lastPage }) if certDiscoveryErr != nil { diff --git a/internal/certificate/service.go b/internal/certificate/service.go index d907782..39176bd 100644 --- a/internal/certificate/service.go +++ b/internal/certificate/service.go @@ -32,7 +32,7 @@ var ( // Service handle the certificate actions as discovery type Service interface { - DiscoverByHost(string) (Certificate, error) + DiscoverByHost([]string) (Certificate, error) } // NewService creates a new Certificate Service @@ -44,10 +44,10 @@ type acmCertService struct { repo Repository } -// DiscoverByHost tries to discover a certificate given a host -func (a acmCertService) DiscoverByHost(host string) (Certificate, error) { +// DiscoverByHost tries to discover a certificate given hosts +func (a acmCertService) DiscoverByHost(hosts []string) (Certificate, error) { - certs, err := a.repo.FindByFilter(matchingDomainFilter(host)) + certs, err := a.repo.FindByFilter(matchingDomainFilter(hosts)) if err != nil { return Certificate{}, fmt.Errorf("discovery certificate: %v", err) @@ -60,25 +60,33 @@ func (a acmCertService) DiscoverByHost(host string) (Certificate, error) { return certs[0], nil } -func matchingDomainFilter(host string) CertFilter { +func matchingDomainFilter(hosts []string) CertFilter { return func(c Certificate) bool { - if host == c.DomainName() { - return true + for _, host := range hosts { + if !certMatches(host, c) { + return false + } } + return true + } +} - for _, alterName := range c.AlternativeNames() { - hs := strings.Split(host, ".") - hostDomain := strings.Join(hs[1:], ".") - - if strings.HasPrefix(alterName, "*.") { - alterName = strings.ReplaceAll(alterName, "*.", "") - } +func certMatches(distHost string, c Certificate) bool { + for _, certHost := range append(c.AlternativeNames(), c.DomainName()) { + if distHost == certHost { + return true + } + hs := strings.Split(distHost, ".") + hostDomain := strings.Join(hs[1:], ".") - if alterName == hostDomain { - return true - } + if strings.HasPrefix(certHost, "*.") { + certHost = strings.ReplaceAll(certHost, "*.", "") } - return false + if certHost == hostDomain { + return true + } } + + return false } diff --git a/internal/certificate/service_test.go b/internal/certificate/service_test.go index 7e0e546..a32e678 100644 --- a/internal/certificate/service_test.go +++ b/internal/certificate/service_test.go @@ -34,24 +34,76 @@ type CertificateServiceTestSuite struct { suite.Suite } -func (s *CertificateServiceTestSuite) TestMatchDomainFilter_MainDomain() { - cert := New( - "arn:xpto", - "foo.xpto.com", - []string{"foo.xpto.com", "*.foo.xpto.com"}, - ) - - filter := matchingDomainFilter("foo.xpto.com") - s.True(filter(cert)) +func (s *CertificateServiceTestSuite) TestMatchDomainFilters_Matches() { + testCases := []struct { + name string + certDomainName string + certAlternativeDomains []string + distrDomains []string + }{ + { + name: "Matching distribution domains using all certificate alternative domains", + certDomainName: "foo.com", + certAlternativeDomains: []string{"foo.com", "*.foo.com"}, + distrDomains: []string{"www.foo.com", "foo.com"}, + }, + { + name: "Matching distribution domains using all certificate alternative domains (alternative order on distr domains)", + certDomainName: "foo.com", + certAlternativeDomains: []string{"foo.com", "*.foo.com"}, + distrDomains: []string{"foo.com", "www.foo.com"}, + }, + { + name: "Matching distribution domains using all certificate alternative domains (alternative order on cert domains)", + certDomainName: "foo.com", + certAlternativeDomains: []string{"*.foo.com", "foo.com"}, + distrDomains: []string{"foo.com", "www.foo.com"}, + }, + { + name: "Matching distribution domains when certificate has additional alternative domains", + certDomainName: "bar.com", + certAlternativeDomains: []string{"*.foo.com", "bar.com", "*.baz.com"}, + distrDomains: []string{"www.foo.com", "bar.com"}, + }, + { + name: "Matching distribution domains exactly with certificates domains", + certDomainName: "bar.com", + certAlternativeDomains: []string{"baz.com"}, + distrDomains: []string{"bar.com", "baz.com"}, + }, + } + + for _, tc := range testCases { + cert := New("arn:foo", tc.certDomainName, tc.certAlternativeDomains) + filter := matchingDomainFilter(tc.distrDomains) + s.Truef(filter(cert), "testCase: %s", tc.name) + } } -func (s *CertificateServiceTestSuite) TestMatchDomainFilter_SubDomain() { - cert := New( - "arn:xpto", - "foo.xpto.com", - []string{"foo.xpto.com", "*.foo.xpto.com"}, - ) +func (s *CertificateServiceTestSuite) TestMatchDomainFilters_DoesntMatch() { + testCases := []struct { + name string + certDomainName string + certAlternativeDomains []string + distrDomains []string + }{ + { + name: "Doesn't Match anything", + certDomainName: "bar.com", + certAlternativeDomains: []string{"bar.com", "*.bar.com"}, + distrDomains: []string{"www.foo.com", "foo.com"}, + }, + { + name: "Doesn't Match one domain", + certDomainName: "*.xpto.com", + certAlternativeDomains: []string{"*.xpto.com"}, + distrDomains: []string{"www.xpto.com", "xpto.com"}, + }, + } - filter := matchingDomainFilter("baz.foo.xpto.com") - s.True(filter(cert)) + for _, tc := range testCases { + cert := New("arn:foo", tc.certDomainName, tc.certAlternativeDomains) + filter := matchingDomainFilter(tc.distrDomains) + s.Falsef(filter(cert), "testCase: %s", tc.name) + } } diff --git a/internal/cloudfront/service.go b/internal/cloudfront/service.go index 623bccd..9ebb087 100644 --- a/internal/cloudfront/service.go +++ b/internal/cloudfront/service.go @@ -217,7 +217,6 @@ func (s *Service) newDistribution(ingresses []k8s.CDNIngress, group string, shar group, s.Config, ) - var err error var cert certificate.Certificate if s.Config.TLSIsEnabled() { @@ -259,17 +258,17 @@ func (s *Service) newDistribution(ingresses []k8s.CDNIngress, group string, shar // discoverCert returns the first found ACM Certificate that matches any Alternate Domain Name of the input Ingresses func (s *Service) discoverCert(ingresses []k8s.CDNIngress) (certificate.Certificate, error) { - errs := &multierror.Error{} + var alternateDomains []string for _, ing := range ingresses { - for _, dn := range ing.AlternateDomainNames { - cert, err := s.CertService.DiscoverByHost(dn) - if err == nil { - return cert, nil - } - errs = multierror.Append(errs, fmt.Errorf("%q: %v", dn, err)) - } + alternateDomains = append(alternateDomains, ing.AlternateDomainNames...) } - return certificate.Certificate{}, errs.ErrorOrNil() + + cert, err := s.CertService.DiscoverByHost(alternateDomains) + if err != nil { + return certificate.Certificate{}, fmt.Errorf("%v: %v", alternateDomains, err) + } + + return cert, nil } func (s *Service) s3Prefix(group string) string {