diff --git a/.dockerignore b/.dockerignore index 1021c82..1b668dc 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,12 +1,12 @@ -.github/ +config/*.yml docs/ examples/ pkg/ test/ terraform/ +.github/ *.md - .dockerignore .git -.gitignore \ No newline at end of file +.gitignore diff --git a/test/config/config.test.local.sandbox.yml b/config/config.test.local.sandbox.yml similarity index 100% rename from test/config/config.test.local.sandbox.yml rename to config/config.test.local.sandbox.yml diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index c996a68..88198a2 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -172,7 +172,7 @@ Run the `baseca` Container ```sh docker run -p 9090:9090 -e database_credentials=secret -v ~/.aws/:/home/baseca/.aws/:ro \ - -v /path/to/baseca/config:/home/baseca/config ghcr.io/coinbase/baseca:VERSION_SHA + -v /path/to/local/baseca/config:/home/baseca/config ghcr.io/coinbase/baseca:VERSION_SHA ``` ### 3b. Compile `baseca` as Executable (Option B) @@ -340,6 +340,10 @@ func main() { SigningAlgorithm: x509.SHA512WithRSA, PublicKeyAlgorithm: x509.RSA, KeySize: 4096, + DistinguishedName: baseca.DistinguishedName{ + Organization: []string{"Coinbase"}, + // Additional Fields + }, Output: baseca.Output{ PrivateKey: "/tmp/private.key", // baseca Generate Private Key Output Location Certificate: "/tmp/certificate.crt", // baseca Signed Leaf Certificate Output Location diff --git a/internal/config/load.go b/internal/config/load.go index 338df25..a47f4d4 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -3,6 +3,9 @@ package config import ( "errors" "fmt" + "os" + "path/filepath" + "runtime" "github.com/coinbase/baseca/internal/logger" "github.com/mitchellh/mapstructure" @@ -10,6 +13,10 @@ import ( "go.uber.org/zap" ) +const ( + configuration = "config.test.local.sandbox.yml" +) + type configProvider struct { v *viper.Viper } @@ -70,3 +77,48 @@ func (cp *configProvider) Get(path string, cfg any) error { func (cp *configProvider) Exists(path string) bool { return cp.v.Get(path) != nil } + +func GetTestConfigurationPath() (*Config, error) { + _, filename, _, ok := runtime.Caller(0) + if !ok { + fmt.Println("Error: Unable to get current file path") + } + + baseDir := filepath.Dir(filename) + for { + if _, err := os.Stat(filepath.Join(baseDir, "go.mod")); err == nil { + break + } + + parentDir := filepath.Dir(baseDir) + if parentDir == baseDir { + fmt.Println("Error: Unable to find base directory") + break + } + + baseDir = parentDir + } + + path := fmt.Sprintf("%s/config/%s", baseDir, configuration) + config, err := provideConfig(path) + if err != nil { + return nil, err + } + return config, nil +} + +func provideConfig(path string) (*Config, error) { + ctxLogger := logger.ContextLogger{Logger: logger.DefaultLogger} + + v, err := BuildViper(path) + if err != nil { + ctxLogger.Error(err.Error()) + } + + config, err := LoadConfig(v) + if err != nil { + return nil, err + } + + return config, err +} diff --git a/internal/lib/util/validator/domain.go b/internal/lib/util/validator/domain.go index 6d76369..e0b93ef 100644 --- a/internal/lib/util/validator/domain.go +++ b/internal/lib/util/validator/domain.go @@ -9,7 +9,7 @@ import ( ) const ( - _dns_regular_expression = `^[a-zA-Z*.]+$` + _dns_regular_expression = `^[a-zA-Z0-9*._-]+$` ) var valid_domains []string diff --git a/internal/v1/accounts/accounts_test.go b/internal/v1/accounts/accounts_test.go index 01b0f44..e3027f9 100644 --- a/internal/v1/accounts/accounts_test.go +++ b/internal/v1/accounts/accounts_test.go @@ -3,12 +3,12 @@ package accounts import ( mock_store "github.com/coinbase/baseca/db/mock" db "github.com/coinbase/baseca/db/sqlc" + "github.com/coinbase/baseca/internal/config" "github.com/coinbase/baseca/internal/lib/util/validator" - "github.com/coinbase/baseca/test" ) func buildAccountsConfig(store *mock_store.MockStore) (*Service, error) { - config, err := test.GetTestConfigurationPath() + config, err := config.GetTestConfigurationPath() if err != nil { return nil, err } diff --git a/internal/v1/accounts/service_test.go b/internal/v1/accounts/service_test.go index 198ccc9..69c8039 100644 --- a/internal/v1/accounts/service_test.go +++ b/internal/v1/accounts/service_test.go @@ -178,7 +178,7 @@ func TestCreateServiceAccount(t *testing.T) { req: &apiv1.CreateServiceAccountRequest{ ServiceAccount: "example", Environment: "sandbox", - SubjectAlternativeNames: []string{"000.example.com"}, + SubjectAlternativeNames: []string{"{}.example.com"}, ExtendedKey: "EndEntityServerAuthCertificate", CertificateAuthorities: []string{"sandbox_use1"}, SubordinateCa: "infrastructure", diff --git a/internal/v1/certificate/certificate_test.go b/internal/v1/certificate/certificate_test.go index 9bab8c5..ae611bb 100644 --- a/internal/v1/certificate/certificate_test.go +++ b/internal/v1/certificate/certificate_test.go @@ -18,8 +18,8 @@ import ( acm_pca "github.com/coinbase/baseca/internal/client/acmpca" "github.com/coinbase/baseca/internal/client/firehose" redis_client "github.com/coinbase/baseca/internal/client/redis" + "github.com/coinbase/baseca/internal/config" "github.com/coinbase/baseca/internal/lib/util/validator" - "github.com/coinbase/baseca/test" "github.com/go-redis/redis/v8" "github.com/stretchr/testify/mock" ) @@ -112,7 +112,7 @@ func (m *mockedPrivateCaClient) GetCertificateAuthorityCertificate(ctx context.C } func buildCertificateConfig(store *mock_store.MockStore) (*Certificate, error) { - config, err := test.GetTestConfigurationPath() + config, err := config.GetTestConfigurationPath() if err != nil { return nil, err } diff --git a/internal/v1/middleware/authentication.go b/internal/v1/middleware/authentication.go index ff956a3..c201259 100644 --- a/internal/v1/middleware/authentication.go +++ b/internal/v1/middleware/authentication.go @@ -30,16 +30,15 @@ func (m *Middleware) ServerAuthenticationInterceptor(ctx context.Context, req an var ok bool methods := map[string]string{ - "/grpc.health.v1.Health/Check": _pass_auth, - "/baseca.v1.Account/LoginUser": _pass_auth, - "/baseca.v1.Account/UpdateUserCredentials": _pass_auth, - "/baseca.v1.Certificate/SignCSR": _service_auth, - "/baseca.v1.Certificate/OperationsSignCSR": _provisioner_auth, - "/baseca.v1.Certificate/QueryCertificateMetadata": _provisioner_auth, - "/baseca.v1.Certificate/GetSignedIntermediateCertificate": _provisioner_auth, - "/baseca.v1.Service/ProvisionServiceAccount": _provisioner_auth, - "/baseca.v1.Service/GetServiceAccountByMetadata": _provisioner_auth, - "/baseca.v1.Service/DeleteProvisionedServiceAccount": _provisioner_auth, + "/grpc.health.v1.Health/Check": _pass_auth, + "/baseca.v1.Account/LoginUser": _pass_auth, + "/baseca.v1.Account/UpdateUserCredentials": _pass_auth, + "/baseca.v1.Certificate/SignCSR": _service_auth, + "/baseca.v1.Certificate/OperationsSignCSR": _provisioner_auth, + "/baseca.v1.Certificate/QueryCertificateMetadata": _provisioner_auth, + "/baseca.v1.Service/ProvisionServiceAccount": _provisioner_auth, + "/baseca.v1.Service/GetServiceAccountByMetadata": _provisioner_auth, + "/baseca.v1.Service/DeleteProvisionedServiceAccount": _provisioner_auth, } if auth, ok = methods[info.FullMethod]; !ok { diff --git a/internal/v1/users/users_test.go b/internal/v1/users/users_test.go index 0c0f300..ef1af3c 100644 --- a/internal/v1/users/users_test.go +++ b/internal/v1/users/users_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" mock_store "github.com/coinbase/baseca/db/mock" db "github.com/coinbase/baseca/db/sqlc" + "github.com/coinbase/baseca/internal/config" lib "github.com/coinbase/baseca/internal/lib/authentication" - "github.com/coinbase/baseca/test" "github.com/stretchr/testify/mock" ) @@ -26,7 +26,7 @@ func (m *mockedKmsClient) Verify(ctx context.Context, params *kms.VerifyInput, o } func buildUsersConfig(store *mock_store.MockStore) (*User, error) { - config, err := test.GetTestConfigurationPath() + config, err := config.GetTestConfigurationPath() if err != nil { return nil, err } diff --git a/pkg/client/certificate.go b/pkg/client/certificate.go index b7f0277..a5e6ea4 100644 --- a/pkg/client/certificate.go +++ b/pkg/client/certificate.go @@ -2,14 +2,13 @@ package baseca import ( "context" - "fmt" - "os" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" "github.com/coinbase/baseca/pkg/types" + "github.com/coinbase/baseca/pkg/util" ) -func (c *client) IssueCertificate(certificateRequest CertificateRequest) (*apiv1.SignedCertificate, error) { +func (c *Client) IssueCertificate(certificateRequest CertificateRequest) (*apiv1.SignedCertificate, error) { signingRequest, err := GenerateCSR(certificateRequest) if err != nil { return nil, err @@ -24,7 +23,7 @@ func (c *client) IssueCertificate(certificateRequest CertificateRequest) (*apiv1 return nil, err } - err = parseCertificateFormat(signedCertificate, types.SignedCertificate{ + err = util.ParseCertificateFormat(signedCertificate, types.SignedCertificate{ CertificatePath: certificateRequest.Output.Certificate, IntermediateCertificateChainPath: certificateRequest.Output.IntermediateCertificateChain, RootCertificateChainPath: certificateRequest.Output.RootCertificateChain, @@ -36,34 +35,3 @@ func (c *client) IssueCertificate(certificateRequest CertificateRequest) (*apiv1 return signedCertificate, nil } - -func parseCertificateFormat(certificate *apiv1.SignedCertificate, parameter types.SignedCertificate) error { - // Leaf Certificate Path - if len(parameter.CertificatePath) != 0 { - certificate := []byte(certificate.Certificate) - if err := os.WriteFile(parameter.CertificatePath, certificate, os.ModePerm); err != nil { - return fmt.Errorf("error writing certificate to [%s]", parameter.CertificatePath) - } - } - - // Intermediate Certificate Chain Path - if len(parameter.IntermediateCertificateChainPath) != 0 { - certificate := []byte(certificate.IntermediateCertificateChain) - if err := os.WriteFile(parameter.IntermediateCertificateChainPath, certificate, os.ModePerm); err != nil { - return fmt.Errorf("error writing certificate to [%s]", parameter.IntermediateCertificateChainPath) - } - } - - // Root Certificate Chain Path - if len(parameter.RootCertificateChainPath) != 0 { - certificate := []byte(certificate.CertificateChain) - if err := os.WriteFile(parameter.RootCertificateChainPath, certificate, os.ModePerm); err != nil { - return fmt.Errorf("error writing certificate chain to [%s]", parameter.RootCertificateChainPath) - } - } - return nil -} - -func (c *client) QueryCertificateMetadata(req *apiv1.QueryCertificateMetadataRequest) (*apiv1.CertificatesParameter, error) { - return c.Certificate.QueryCertificateMetadata(context.Background(), req) -} diff --git a/pkg/client/client.go b/pkg/client/client.go index 23a6541..a19e74f 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -3,7 +3,6 @@ package baseca import ( "context" "crypto/tls" - "crypto/x509" "fmt" "strings" "sync" @@ -17,103 +16,59 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -// var Endpoints Environment - -var Attestation Provider = Provider{ - Local: "NONE", - AWS: "AWS", -} - -var Env = Environment{ - Local: "Local", - Sandbox: "Sandbox", - Development: "Development", - Staging: "Staging", - PreProduction: "PreProduction", - Production: "Production", -} - -type Environment struct { - Local string - Sandbox string - Development string - Staging string - PreProduction string - Production string -} - -type Configuration struct { - URL string - Environment string -} - -type Provider struct { - Local string - AWS string -} - -type Output struct { - CertificateSigningRequest string - Certificate string - IntermediateCertificateChain string - RootCertificateChain string - PrivateKey string -} - -type CertificateRequest struct { - CommonName string - SubjectAlternateNames []string - DistinguishedName DistinguishedName - SigningAlgorithm x509.SignatureAlgorithm - PublicKeyAlgorithm x509.PublicKeyAlgorithm - KeySize int - Output Output -} - -type DistinguishedName struct { - Country []string - Province []string - Locality []string - Organization []string - OrganizationalUnit []string -} - -type Authentication struct { - ClientId string - ClientToken string -} - -type client struct { - endpoint string - authentication Authentication - attestation string - Certificate apiv1.CertificateClient - Service apiv1.ServiceClient -} - const ( _client_id_header = "X-BASECA-CLIENT-ID" _client_token_header = "X-BASECA-CLIENT-TOKEN" // #nosec G101 False Positive _aws_iid_metadata = "X-BASECA-INSTANCE-METADATA" + _account_auth_header = "AUTHORIZATION" ) +type Client struct { + Endpoint string + Authentication Authentication + Attestation string + Certificate apiv1.CertificateClient + Service apiv1.ServiceClient +} + +type AccountClient interface { + LoginUser(ctx context.Context, in *apiv1.LoginUserRequest, opts ...grpc.CallOption) (*apiv1.LoginUserResponse, error) + DeleteUser(ctx context.Context, in *apiv1.UsernameRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) + GetUser(ctx context.Context, in *apiv1.UsernameRequest, opts ...grpc.CallOption) (*apiv1.User, error) + ListUsers(ctx context.Context, in *apiv1.QueryParameter, opts ...grpc.CallOption) (*apiv1.Users, error) + CreateUser(ctx context.Context, in *apiv1.CreateUserRequest, opts ...grpc.CallOption) (*apiv1.User, error) + UpdateUserCredentials(ctx context.Context, in *apiv1.UpdateCredentialsRequest, opts ...grpc.CallOption) (*apiv1.User, error) + UpdateUserPermissions(ctx context.Context, in *apiv1.UpdatePermissionsRequest, opts ...grpc.CallOption) (*apiv1.User, error) +} + type CertificateClient interface { SignCSR(ctx context.Context, in *apiv1.CertificateSigningRequest, opts ...grpc.CallOption) (*apiv1.SignedCertificate, error) + GetCertificate(ctx context.Context, in *apiv1.CertificateSerialNumber, opts ...grpc.CallOption) (*apiv1.CertificateParameter, error) + ListCertificates(ctx context.Context, in *apiv1.ListCertificatesRequest, opts ...grpc.CallOption) (*apiv1.CertificatesParameter, error) + RevokeCertificate(ctx context.Context, in *apiv1.RevokeCertificateRequest, opts ...grpc.CallOption) (*apiv1.RevokeCertificateResponse, error) OperationsSignCSR(ctx context.Context, in *apiv1.OperationsSignRequest, opts ...grpc.CallOption) (*apiv1.SignedCertificate, error) QueryCertificateMetadata(ctx context.Context, in *apiv1.QueryCertificateMetadataRequest, opts ...grpc.CallOption) (*apiv1.CertificatesParameter, error) } type ServiceClient interface { + CreateServiceAccount(ctx context.Context, in *apiv1.CreateServiceAccountRequest, opts ...grpc.CallOption) (*apiv1.CreateServiceAccountResponse, error) + CreateProvisionerAccount(ctx context.Context, in *apiv1.CreateProvisionerAccountRequest, opts ...grpc.CallOption) (*apiv1.CreateProvisionerAccountResponse, error) + GetProvisionerAccount(ctx context.Context, in *apiv1.AccountId, opts ...grpc.CallOption) (*apiv1.ProvisionerAccount, error) + ListProvisionerAccounts(ctx context.Context, in *apiv1.QueryParameter, opts ...grpc.CallOption) (*apiv1.ProvisionerAccounts, error) ProvisionServiceAccount(ctx context.Context, in *apiv1.ProvisionServiceAccountRequest, opts ...grpc.CallOption) (*apiv1.ProvisionServiceAccountResponse, error) - GetServiceAccountByMetadata(ctx context.Context, in *apiv1.GetServiceAccountMetadataRequest, opts ...grpc.CallOption) (*apiv1.ServiceAccounts, error) + ListServiceAccounts(ctx context.Context, in *apiv1.QueryParameter, opts ...grpc.CallOption) (*apiv1.ServiceAccounts, error) + GetServiceAccount(ctx context.Context, in *apiv1.AccountId, opts ...grpc.CallOption) (*apiv1.ServiceAccount, error) + GetServiceAccountMetadata(ctx context.Context, in *apiv1.GetServiceAccountMetadataRequest, opts ...grpc.CallOption) (*apiv1.ServiceAccounts, error) + DeleteServiceAccount(ctx context.Context, in *apiv1.AccountId, opts ...grpc.CallOption) (*emptypb.Empty, error) + DeleteProvisionerAccount(ctx context.Context, in *apiv1.AccountId, opts ...grpc.CallOption) (*emptypb.Empty, error) DeleteProvisionedServiceAccount(ctx context.Context, in *apiv1.AccountId, opts ...grpc.CallOption) (*emptypb.Empty, error) } -func LoadDefaultConfiguration(configuration Configuration, attestation string, authentication Authentication) (*client, error) { - c := client{ - endpoint: configuration.URL, - authentication: authentication, - attestation: attestation, +func LoadDefaultConfiguration(configuration Configuration, attestation string, authentication Authentication) (*Client, error) { + c := Client{ + Endpoint: configuration.URL, + Authentication: authentication, + Attestation: attestation, } if configuration.Environment == Env.Local { @@ -137,14 +92,21 @@ func LoadDefaultConfiguration(configuration Configuration, attestation string, a } } -func (c *client) methodInterceptor() grpc.UnaryClientInterceptor { +func (c *Client) methodInterceptor() grpc.UnaryClientInterceptor { methodOptions := map[string]grpc.UnaryClientInterceptor{ - "/baseca.v1.Certificate/SignCSR": c.clientAuthUnaryInterceptor, - "/baseca.v1.Certificate/OperationsSignCSR": c.clientAuthUnaryInterceptor, - "/baseca.v1.Certificate/QueryCertificateMetadata": c.clientAuthUnaryInterceptor, + // Certificate Interface + "/baseca.v1.Certificate/SignCSR": c.clientAuthUnaryInterceptor, + "/baseca.v1.Certificate/OperationsSignCSR": c.clientAuthUnaryInterceptor, + "/baseca.v1.Certificate/QueryCertificateMetadata": c.clientAuthUnaryInterceptor, + + // Service Interface "/baseca.v1.Service/ProvisionServiceAccount": c.clientAuthUnaryInterceptor, "/baseca.v1.Service/GetServiceAccountByMetadata": c.clientAuthUnaryInterceptor, "/baseca.v1.Service/DeleteProvisionedServiceAccount": c.clientAuthUnaryInterceptor, + + // Account Interface + "/baseca.v1.Account/LoginUser": c.accountAuthUnaryInterceptor, + // TODO: Add Additional RPC Methods } return mapMethodInterceptor(methodOptions) } @@ -176,11 +138,11 @@ func returnMethodInterceptor(chainMap sync.Map, method string) (grpc.UnaryClient return nil, false } -func (c *client) clientAuthUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx = metadata.AppendToOutgoingContext(ctx, _client_id_header, c.authentication.ClientId) - ctx = metadata.AppendToOutgoingContext(ctx, _client_token_header, c.authentication.ClientToken) +func (c *Client) clientAuthUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx = metadata.AppendToOutgoingContext(ctx, _client_id_header, c.Authentication.ClientId) + ctx = metadata.AppendToOutgoingContext(ctx, _client_token_header, c.Authentication.ClientToken) - if c.attestation == Attestation.AWS { + if c.Attestation == Attestation.AWS { instance_metadata, err := aws_iid.BuildInstanceMetadata() if err != nil { return fmt.Errorf("error generating aws_iid node attestation") @@ -191,3 +153,10 @@ func (c *client) clientAuthUnaryInterceptor(ctx context.Context, method string, err := invoker(ctx, method, req, reply, cc, opts...) return err } + +func (c *Client) accountAuthUnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx = metadata.AppendToOutgoingContext(ctx, _account_auth_header, fmt.Sprintf("Bearer %s", c.Authentication.AuthToken)) + + err := invoker(ctx, method, req, reply, cc, opts...) + return err +} diff --git a/pkg/client/csr.go b/pkg/client/csr.go index 6ed2c99..f469571 100644 --- a/pkg/client/csr.go +++ b/pkg/client/csr.go @@ -2,27 +2,44 @@ package baseca import ( "bytes" - "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" - "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "errors" "fmt" "os" + "github.com/coinbase/baseca/pkg/crypto" "github.com/coinbase/baseca/pkg/types" ) func GenerateCSR(csr CertificateRequest) (*types.SigningRequest, error) { + var generator crypto.CSRGenerator + switch csr.PublicKeyAlgorithm { case x509.RSA: - if csr.KeySize < 2048 { - return nil, errors.New("invalid key size, rsa minimum valid bits 2048]") + if _, ok := types.PublicKeyAlgorithms["RSA"].KeySize[csr.KeySize]; !ok { + return nil, fmt.Errorf("rsa invalid key size %d", csr.KeySize) + } + if _, ok := types.PublicKeyAlgorithms["RSA"].SigningAlgorithm[csr.SigningAlgorithm]; !ok { + return nil, fmt.Errorf("rsa invalid signing algorithm %s", csr.SigningAlgorithm) + } + generator = &crypto.SigningRequestGeneratorRSA{Size: csr.KeySize} + case x509.ECDSA: + if _, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[csr.KeySize]; !ok { + return nil, fmt.Errorf("ecdsa invalid key size %d", csr.KeySize) } - // TODO: ECDSA + if _, ok := types.PublicKeyAlgorithms["ECDSA"].SigningAlgorithm[csr.SigningAlgorithm]; !ok { + return nil, fmt.Errorf("ecdsa invalid signing algorithm %s", csr.SigningAlgorithm) + } + generator = &crypto.SigningRequestGeneratorECDSA{Curve: csr.KeySize} + default: + return nil, fmt.Errorf("unsupported public key algorithm") + } + + pk, err := generator.Generate() + if err != nil { + return nil, fmt.Errorf("error generating private key [%s]: %w", generator.KeyType(), err) } subject := pkix.Name{ @@ -40,99 +57,45 @@ func GenerateCSR(csr CertificateRequest) (*types.SigningRequest, error) { DNSNames: csr.SubjectAlternateNames, } - switch csr.SigningAlgorithm { - case x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: - pk, err := rsa.GenerateKey(rand.Reader, csr.KeySize) - if err != nil { - return nil, errors.New("error generating rsa key pair") - } - - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, pk) - if err != nil { - return nil, err - } - - certificatePem := new(bytes.Buffer) - err = pem.Encode(certificatePem, &pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: csrBytes, - }) - - if err != nil { - return nil, errors.New("error encoding certificate request (csr)") - } - - if len(csr.Output.CertificateSigningRequest) != 0 { - if err := os.WriteFile(csr.Output.CertificateSigningRequest, certificatePem.Bytes(), os.ModePerm); err != nil { - return nil, fmt.Errorf("error writing certificate signing request (csr) to [%s]", csr.Output.CertificateSigningRequest) - } - } - - pkBlock := &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(pk), - } - - if len(csr.Output.PrivateKey) != 0 { - if err := os.WriteFile(csr.Output.PrivateKey, pem.EncodeToMemory(pkBlock), os.ModePerm); err != nil { - return nil, fmt.Errorf("error writing private key to [%s]", csr.Output.PrivateKey) - } - } - - return &types.SigningRequest{ - CSR: certificatePem, - PrivateKey: pkBlock, - }, nil - - case x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512: - pk, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - return nil, errors.New("error generating ECDSA key pair") - } - - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, pk) - if err != nil { - return nil, err - } + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, pk) + if err != nil { + return nil, fmt.Errorf("error creating certificate request: %w", err) + } - certificatePem := new(bytes.Buffer) - err = pem.Encode(certificatePem, &pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: csrBytes, - }) + certificatePem := new(bytes.Buffer) + err = pem.Encode(certificatePem, &pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) - if err != nil { - return nil, errors.New("error encoding certificate request (csr)") - } + if err != nil { + return nil, fmt.Errorf("error encoding certificate request (csr): %w", err) + } - if len(csr.Output.CertificateSigningRequest) != 0 { - if err := os.WriteFile(csr.Output.CertificateSigningRequest, certificatePem.Bytes(), os.ModePerm); err != nil { - return nil, fmt.Errorf("error writing certificate signing request (csr) to [%s]", csr.Output.CertificateSigningRequest) - } + if len(csr.Output.CertificateSigningRequest) != 0 { + if err := os.WriteFile(csr.Output.CertificateSigningRequest, certificatePem.Bytes(), os.ModePerm); err != nil { + return nil, fmt.Errorf("error writing certificate signing request (csr) to [%s]", csr.Output.CertificateSigningRequest) } + } - ecPrivateKeyBytes, err := x509.MarshalECPrivateKey(pk) - if err != nil { - return nil, errors.New("error marshaling ECDSA private key") - } + pkBytes, err := generator.MarshalPrivateKey(pk) + if err != nil { + return nil, fmt.Errorf("error marshaling private key: %w", err) + } - pkBlock := &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: ecPrivateKeyBytes, - } + pkBlock := &pem.Block{ + Type: generator.KeyType(), + Bytes: pkBytes, + } - if len(csr.Output.PrivateKey) != 0 { - if err := os.WriteFile(csr.Output.PrivateKey, pem.EncodeToMemory(pkBlock), os.ModePerm); err != nil { - return nil, fmt.Errorf("error writing private key to [%s]", csr.Output.PrivateKey) - } + if len(csr.Output.PrivateKey) != 0 { + if err := os.WriteFile(csr.Output.PrivateKey, pem.EncodeToMemory(pkBlock), os.ModePerm); err != nil { + return nil, fmt.Errorf("error writing private key to [%s]", csr.Output.PrivateKey) } - - return &types.SigningRequest{ - CSR: certificatePem, - PrivateKey: pkBlock, - }, nil - - default: - return nil, errors.New("unsupported signing algorithm") } + + return &types.SigningRequest{ + CSR: certificatePem, + PrivateKey: pkBlock, + }, nil } diff --git a/pkg/client/provisioner.go b/pkg/client/provisioner.go index d1edf5b..dc12940 100644 --- a/pkg/client/provisioner.go +++ b/pkg/client/provisioner.go @@ -5,9 +5,10 @@ import ( apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" "github.com/coinbase/baseca/pkg/types" + "github.com/coinbase/baseca/pkg/util" ) -func (c *client) ProvisionIssueCertificate(certificateRequest CertificateRequest, ca *apiv1.CertificateAuthorityParameter, service, environment, extendedKey string) (*apiv1.SignedCertificate, error) { +func (c *Client) ProvisionIssueCertificate(certificateRequest CertificateRequest, ca *apiv1.CertificateAuthorityParameter, service, environment, extendedKey string) (*apiv1.SignedCertificate, error) { signingRequest, err := GenerateCSR(certificateRequest) if err != nil { return nil, err @@ -26,7 +27,7 @@ func (c *client) ProvisionIssueCertificate(certificateRequest CertificateRequest return nil, err } - err = parseCertificateFormat(signedCertificate, types.SignedCertificate{ + err = util.ParseCertificateFormat(signedCertificate, types.SignedCertificate{ CertificatePath: certificateRequest.Output.Certificate, IntermediateCertificateChainPath: certificateRequest.Output.IntermediateCertificateChain, RootCertificateChainPath: certificateRequest.Output.RootCertificateChain, @@ -38,11 +39,3 @@ func (c *client) ProvisionIssueCertificate(certificateRequest CertificateRequest return signedCertificate, nil } - -func (c *client) ProvisionServiceAccount(req *apiv1.ProvisionServiceAccountRequest) (*apiv1.ProvisionServiceAccountResponse, error) { - return c.Service.ProvisionServiceAccount(context.Background(), req) -} - -func (c *client) GetServiceAccountByMetadata(req *apiv1.GetServiceAccountMetadataRequest) (*apiv1.ServiceAccounts, error) { - return c.Service.GetServiceAccountMetadata(context.Background(), req) -} diff --git a/pkg/client/sign.go b/pkg/client/sign.go index 150cfae..b35c6a0 100644 --- a/pkg/client/sign.go +++ b/pkg/client/sign.go @@ -15,9 +15,10 @@ import ( apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" "github.com/coinbase/baseca/pkg/types" + "github.com/coinbase/baseca/pkg/util" ) -func (c *client) GenerateSignature(csr CertificateRequest, element []byte) (*[]byte, []*x509.Certificate, error) { +func (c *Client) GenerateSignature(csr CertificateRequest, element []byte) (*[]byte, []*x509.Certificate, error) { var certificatePem []*pem.Block var certificateChain []*x509.Certificate @@ -35,7 +36,7 @@ func (c *client) GenerateSignature(csr CertificateRequest, element []byte) (*[]b return nil, nil, err } - err = parseCertificateFormat(signedCertificate, types.SignedCertificate{ + err = util.ParseCertificateFormat(signedCertificate, types.SignedCertificate{ CertificatePath: csr.Output.Certificate, IntermediateCertificateChainPath: csr.Output.IntermediateCertificateChain, RootCertificateChainPath: csr.Output.RootCertificateChain, @@ -83,7 +84,7 @@ func (c *client) GenerateSignature(csr CertificateRequest, element []byte) (*[]b return &signature, certificateChain, nil } -func (c *client) ValidateSignature(tc types.TrustChain, manifest types.Manifest) error { +func (c *Client) ValidateSignature(tc types.TrustChain, manifest types.Manifest) error { err := manifest.CertificateChain[0].CheckSignature(manifest.SigningAlgorithm, manifest.Data, manifest.Signature) if err != nil { return fmt.Errorf("signature verification failed: %s", err) @@ -114,7 +115,7 @@ func (c *client) ValidateSignature(tc types.TrustChain, manifest types.Manifest) return fmt.Errorf("invalid subject alternative name (san) from code signing certificate") } - rootCertificatePool, err := c.generateCertificatePool(tc) + rootCertificatePool, err := util.GenerateCertificatePool(tc) if err != nil { return err } @@ -129,47 +130,3 @@ func (c *client) ValidateSignature(tc types.TrustChain, manifest types.Manifest) } return nil } - -func (c *client) generateCertificatePool(tc types.TrustChain) (*x509.CertPool, error) { - certPool := x509.NewCertPool() - - for _, dir := range tc.CertificateAuthorityDirectory { - files, err := os.ReadDir(dir) - if err != nil { - return nil, errors.New("invalid certificate authority directory") - } - - for _, certFile := range files { // #nosec G304 User Only Has Predefined Environment Parameters - data, err := os.ReadFile(filepath.Join(dir, certFile.Name())) - if err != nil { - return nil, errors.New("invalid certificate file") - } - pemBlock, _ := pem.Decode(data) - if pemBlock == nil || pemBlock.Type != "CERTIFICATE" { - return nil, errors.New("invalid input file") - } - cert, err := x509.ParseCertificate(pemBlock.Bytes) - if err != nil { - return nil, errors.New("error parsing x.509 certificate") - } - certPool.AddCert(cert) - } - } - - for _, ca := range tc.CertificateAuthorityFiles { - data, err := os.ReadFile(filepath.Clean(ca)) - if err != nil { - return nil, errors.New("invalid certificate authority file") - } - pemBlock, _ := pem.Decode(data) - if pemBlock == nil || pemBlock.Type != "CERTIFICATE" { - return nil, errors.New("invalid input file") - } - cert, err := x509.ParseCertificate(pemBlock.Bytes) - if err != nil { - return nil, errors.New("error parsing x.509 certificate") - } - certPool.AddCert(cert) - } - return certPool, nil -} diff --git a/pkg/client/types.go b/pkg/client/types.go new file mode 100644 index 0000000..fee1b8e --- /dev/null +++ b/pkg/client/types.go @@ -0,0 +1,68 @@ +package baseca + +import "crypto/x509" + +var Attestation Provider = Provider{ + Local: "NONE", + AWS: "AWS", +} + +var Env = Environment{ + Local: "Local", + Sandbox: "Sandbox", + Development: "Development", + Staging: "Staging", + PreProduction: "PreProduction", + Production: "Production", +} + +type Environment struct { + Local string + Sandbox string + Development string + Staging string + PreProduction string + Production string +} + +type Configuration struct { + URL string + Environment string +} + +type Provider struct { + Local string + AWS string +} + +type Authentication struct { + ClientId string + ClientToken string + AuthToken string +} + +type CertificateRequest struct { + CommonName string + SubjectAlternateNames []string + DistinguishedName DistinguishedName + SigningAlgorithm x509.SignatureAlgorithm + PublicKeyAlgorithm x509.PublicKeyAlgorithm + KeySize int + Output Output +} + +type DistinguishedName struct { + Country []string + Province []string + Locality []string + Organization []string + OrganizationalUnit []string +} + +type Output struct { + CertificateSigningRequest string + Certificate string + IntermediateCertificateChain string + RootCertificateChain string + PrivateKey string +} diff --git a/pkg/crypto/generate.go b/pkg/crypto/generate.go new file mode 100644 index 0000000..2134e7d --- /dev/null +++ b/pkg/crypto/generate.go @@ -0,0 +1,95 @@ +package crypto + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "fmt" + + "github.com/coinbase/baseca/pkg/types" +) + +type CSRGenerator interface { + Generate() (crypto.PrivateKey, error) + KeyType() string + MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) + SupportsPublicKeyAlgorithm(algorithm x509.PublicKeyAlgorithm) bool + SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool + SupportsKeySize(size int) bool +} + +type SigningRequestGeneratorRSA struct { + Size int +} + +type SigningRequestGeneratorECDSA struct { + Curve int +} + +// RSA Interface +func (r *SigningRequestGeneratorRSA) Generate() (crypto.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, r.Size) +} + +func (r *SigningRequestGeneratorRSA) KeyType() string { + return "RSA PRIVATE KEY" +} + +func (r *SigningRequestGeneratorRSA) MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) { + return x509.MarshalPKCS1PrivateKey(key.(*rsa.PrivateKey)), nil +} + +func (r *SigningRequestGeneratorRSA) SupportsPublicKeyAlgorithm(algorithm x509.PublicKeyAlgorithm) bool { + return algorithm == x509.RSA +} + +func (r *SigningRequestGeneratorRSA) SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool { + _, ok := types.PublicKeyAlgorithms["RSA"].SigningAlgorithm[algorithm] + return ok +} + +func (r *SigningRequestGeneratorRSA) SupportsKeySize(size int) bool { + _, ok := types.PublicKeyAlgorithms["RSA"].KeySize[size] + return ok +} + +// ECDSA Interface +func (e *SigningRequestGeneratorECDSA) Generate() (crypto.PrivateKey, error) { + c, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[e.Curve] + + if !ok { + return nil, fmt.Errorf("ecdsa curve [%d] not supported", e.Curve) + } + + curve, ok := c.(elliptic.Curve) + if !ok { + return nil, fmt.Errorf("invalid elliptic.Curve type") + } + + return ecdsa.GenerateKey(curve, rand.Reader) +} + +func (e *SigningRequestGeneratorECDSA) KeyType() string { + return "EC PRIVATE KEY" +} + +func (e *SigningRequestGeneratorECDSA) MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) { + return x509.MarshalECPrivateKey(key.(*ecdsa.PrivateKey)) +} + +func (e *SigningRequestGeneratorECDSA) SupportsPublicKeyAlgorithm(algorithm x509.PublicKeyAlgorithm) bool { + return algorithm == x509.ECDSA +} + +func (e *SigningRequestGeneratorECDSA) SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool { + _, ok := types.PublicKeyAlgorithms["ECDSA"].SigningAlgorithm[algorithm] + return ok +} + +func (e *SigningRequestGeneratorECDSA) SupportsKeySize(size int) bool { + _, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[size] + return ok +} diff --git a/pkg/crypto/generate_test.go b/pkg/crypto/generate_test.go new file mode 100644 index 0000000..d89bfd0 --- /dev/null +++ b/pkg/crypto/generate_test.go @@ -0,0 +1,71 @@ +package crypto + +import ( + "crypto/x509" + "testing" +) + +func TestSigningRequestGeneratorRSA(t *testing.T) { + r := &SigningRequestGeneratorRSA{ + Size: 2048, + } + + key, err := r.Generate() + if err != nil { + t.Fatalf("error generating rsa private key: %v", err) + } + + if keyType := r.KeyType(); keyType != "RSA PRIVATE KEY" { + t.Errorf("RSA PRIVATE KEY does not exist within private key") + + } + + if !r.SupportsPublicKeyAlgorithm(x509.RSA) { + t.Errorf("rsa public key algorithm not supported") + } + + if !r.SupportsSigningAlgorithm(x509.SHA256WithRSA) { + t.Errorf("SHA256WithRSA signing algorithm not supported") + } + + if !r.SupportsKeySize(2048) { + t.Errorf("rsa key size not supported") + } + + _, err = r.MarshalPrivateKey(key) + if err != nil { + t.Errorf("error marshaling rsa private key: %v", err) + } +} + +func TestSigningRequestGeneratorECDSA(t *testing.T) { + e := &SigningRequestGeneratorECDSA{ + Curve: 256, + } + + key, err := e.Generate() + if err != nil { + t.Fatalf("error generating ecdsa private key: %v", err) + } + + if keyType := e.KeyType(); keyType != "EC PRIVATE KEY" { + t.Errorf("EC PRIVATE KEY does not exist within private key") + } + + if !e.SupportsPublicKeyAlgorithm(x509.ECDSA) { + t.Errorf("ecdsa public key algorithm not supported") + } + + if !e.SupportsSigningAlgorithm(x509.ECDSAWithSHA256) { + t.Errorf("ECDSAWithSHA256 signing algorithm not supported") + } + + if !e.SupportsKeySize(256) { + t.Errorf("ecdsa curve size not supported") + } + + _, err = e.MarshalPrivateKey(key) + if err != nil { + t.Errorf("error marshaling ecdsa private key: %v", err) + } +} diff --git a/pkg/certificate/pk.go b/pkg/crypto/pk.go similarity index 94% rename from pkg/certificate/pk.go rename to pkg/crypto/pk.go index 8779734..3c01a74 100644 --- a/pkg/certificate/pk.go +++ b/pkg/crypto/pk.go @@ -1,4 +1,4 @@ -package certificate +package crypto import ( "crypto" @@ -59,6 +59,10 @@ func (key *ECDSA) Sign(data []byte) ([]byte, error) { } func ReturnPrivateKey(key AsymmetricKey) (interface{}, error) { + if key == nil { + return nil, fmt.Errorf("asymmetric key is nil") + } + switch k := key.KeyPair().(type) { case *RSA: return k.PrivateKey, nil diff --git a/pkg/crypto/pk_test.go b/pkg/crypto/pk_test.go new file mode 100644 index 0000000..2badf9e --- /dev/null +++ b/pkg/crypto/pk_test.go @@ -0,0 +1,88 @@ +package crypto + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "reflect" + "testing" +) + +func TestRSASign(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate private key: %v", err) + } + rsaKey := &RSA{ + PublicKey: &privateKey.PublicKey, + PrivateKey: privateKey, + } + data := []byte("_example") + signature, err := rsaKey.Sign(data) + if err != nil { + t.Fatalf("failed to sign data: %v", err) + } + if len(signature) == 0 { + t.Fatalf("expected non-empty signature") + } +} + +func TestECDSASign(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate private key: %v", err) + } + ecdsaKey := &ECDSA{ + PublicKey: &privateKey.PublicKey, + PrivateKey: privateKey, + } + data := []byte("_example") + signature, err := ecdsaKey.Sign(data) + if err != nil { + t.Fatalf("failed to sign data: %v", err) + } + if len(signature) == 0 { + t.Fatalf("expected non-empty signature") + } +} + +func TestReturnPrivateKey(t *testing.T) { + rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate rsa private key: %v", err) + } + + ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate ecdsa private key: %v", err) + } + + tests := []struct { + key AsymmetricKey + expected interface{} + }{ + {&RSA{PrivateKey: rsaPrivateKey}, rsaPrivateKey}, + {&ECDSA{PrivateKey: ecdsaPrivateKey}, ecdsaPrivateKey}, + {nil, nil}, + } + + for _, test := range tests { + got, err := ReturnPrivateKey(test.key) + if err != nil && test.key != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(got, test.expected) { + t.Errorf("expected %v, but got %v", test.expected, got) + } + } +} + +func TestCertificateAuthorityInitialization(t *testing.T) { + ca := &CertificateAuthority{ + SerialNumber: "0000000000", + } + if ca.SerialNumber != "0000000000" { + t.Errorf("expected serial number to be '0000000000', but got '%s'", ca.SerialNumber) + } +} diff --git a/pkg/types/certificate.go b/pkg/types/certificate.go index 7f298d4..ff5a5ee 100644 --- a/pkg/types/certificate.go +++ b/pkg/types/certificate.go @@ -2,8 +2,6 @@ package types import ( "bytes" - "crypto/elliptic" - "crypto/x509" "encoding/pem" ) @@ -17,55 +15,3 @@ type SignedCertificate struct { IntermediateCertificateChainPath string RootCertificateChainPath string } - -type PublicKeyAlgorithm struct { - Algorithm x509.PublicKeyAlgorithm - KeySize map[int]any - Signature map[string]bool - SigningAlgorithm map[x509.SignatureAlgorithm]bool -} - -var PublicKeyAlgorithms = map[string]PublicKeyAlgorithm{ - "RSA": { - Algorithm: x509.RSA, - KeySize: map[int]interface{}{ - 2048: true, - 4096: true, - }, - Signature: map[string]bool{ - "SHA256WITHRSA": true, - "SHA384WITHRSA": true, - "SHA512WITHRSA": true, - }, - SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ - x509.SHA256WithRSA: true, - x509.SHA384WithRSA: true, - x509.SHA512WithRSA: true, - }, - }, - "ECDSA": { - Algorithm: x509.ECDSA, - KeySize: map[int]interface{}{ - 256: elliptic.P256(), - 384: elliptic.P384(), - 521: elliptic.P521(), - }, - Signature: map[string]bool{ - "SHA256WITHECDSA": true, - "SHA384WITHECDSA": true, - "SHA512WITHECDSA": true, - }, - SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ - x509.ECDSAWithSHA256: true, - x509.ECDSAWithSHA384: true, - x509.ECDSAWithSHA512: true, - }, - }, - // TODO: Support Ed25519 - "Ed25519": { - Algorithm: x509.Ed25519, - KeySize: map[int]interface{}{ - 256: true, - }, - }, -} diff --git a/pkg/types/pk.go b/pkg/types/pk.go new file mode 100644 index 0000000..b0252d3 --- /dev/null +++ b/pkg/types/pk.go @@ -0,0 +1,58 @@ +package types + +import ( + "crypto/elliptic" + "crypto/x509" +) + +type PublicKeyAlgorithm struct { + Algorithm x509.PublicKeyAlgorithm + KeySize map[int]any + Signature map[string]bool + SigningAlgorithm map[x509.SignatureAlgorithm]bool +} + +var PublicKeyAlgorithms = map[string]PublicKeyAlgorithm{ + "RSA": { + Algorithm: x509.RSA, + KeySize: map[int]interface{}{ + 2048: true, + 4096: true, + }, + Signature: map[string]bool{ + "SHA256WITHRSA": true, + "SHA384WITHRSA": true, + "SHA512WITHRSA": true, + }, + SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ + x509.SHA256WithRSA: true, + x509.SHA384WithRSA: true, + x509.SHA512WithRSA: true, + }, + }, + "ECDSA": { + Algorithm: x509.ECDSA, + KeySize: map[int]interface{}{ + 256: elliptic.P256(), + 384: elliptic.P384(), + 521: elliptic.P521(), + }, + Signature: map[string]bool{ + "SHA256WITHECDSA": true, + "SHA384WITHECDSA": true, + "SHA512WITHECDSA": true, + }, + SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ + x509.ECDSAWithSHA256: true, + x509.ECDSAWithSHA384: true, + x509.ECDSAWithSHA512: true, + }, + }, + // TODO: Support Ed25519 + "Ed25519": { + Algorithm: x509.Ed25519, + KeySize: map[int]interface{}{ + 256: true, + }, + }, +} diff --git a/pkg/util/x509.go b/pkg/util/x509.go new file mode 100644 index 0000000..210e531 --- /dev/null +++ b/pkg/util/x509.go @@ -0,0 +1,84 @@ +package util + +import ( + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "os" + "path/filepath" + + apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" + "github.com/coinbase/baseca/pkg/types" +) + +func ParseCertificateFormat(certificate *apiv1.SignedCertificate, parameter types.SignedCertificate) error { + // Leaf Certificate Path + if len(parameter.CertificatePath) != 0 { + certificate := []byte(certificate.Certificate) + if err := os.WriteFile(parameter.CertificatePath, certificate, os.ModePerm); err != nil { + return fmt.Errorf("error writing certificate to [%s]", parameter.CertificatePath) + } + } + + // Intermediate Certificate Chain Path + if len(parameter.IntermediateCertificateChainPath) != 0 { + certificate := []byte(certificate.IntermediateCertificateChain) + if err := os.WriteFile(parameter.IntermediateCertificateChainPath, certificate, os.ModePerm); err != nil { + return fmt.Errorf("error writing certificate to [%s]", parameter.IntermediateCertificateChainPath) + } + } + + // Root Certificate Chain Path + if len(parameter.RootCertificateChainPath) != 0 { + certificate := []byte(certificate.CertificateChain) + if err := os.WriteFile(parameter.RootCertificateChainPath, certificate, os.ModePerm); err != nil { + return fmt.Errorf("error writing certificate chain to [%s]", parameter.RootCertificateChainPath) + } + } + return nil +} + +func GenerateCertificatePool(tc types.TrustChain) (*x509.CertPool, error) { + certPool := x509.NewCertPool() + + for _, dir := range tc.CertificateAuthorityDirectory { + files, err := os.ReadDir(dir) + if err != nil { + return nil, errors.New("invalid certificate authority directory") + } + + for _, certFile := range files { // #nosec G304 User Only Has Predefined Environment Parameters + data, err := os.ReadFile(filepath.Join(dir, certFile.Name())) + if err != nil { + return nil, errors.New("invalid certificate file") + } + pemBlock, _ := pem.Decode(data) + if pemBlock == nil || pemBlock.Type != "CERTIFICATE" { + return nil, errors.New("invalid input file") + } + cert, err := x509.ParseCertificate(pemBlock.Bytes) + if err != nil { + return nil, errors.New("error parsing x.509 certificate") + } + certPool.AddCert(cert) + } + } + + for _, ca := range tc.CertificateAuthorityFiles { + data, err := os.ReadFile(filepath.Clean(ca)) + if err != nil { + return nil, errors.New("invalid certificate authority file") + } + pemBlock, _ := pem.Decode(data) + if pemBlock == nil || pemBlock.Type != "CERTIFICATE" { + return nil, errors.New("invalid input file") + } + cert, err := x509.ParseCertificate(pemBlock.Bytes) + if err != nil { + return nil, errors.New("error parsing x.509 certificate") + } + certPool.AddCert(cert) + } + return certPool, nil +} diff --git a/pkg/util/x509_test.go b/pkg/util/x509_test.go new file mode 100644 index 0000000..6f169e4 --- /dev/null +++ b/pkg/util/x509_test.go @@ -0,0 +1,45 @@ +package util + +import ( + "os" + "path/filepath" + "testing" + + apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" + "github.com/coinbase/baseca/pkg/types" +) + +func TestParseCertificateFormat(t *testing.T) { + tempDir, err := os.MkdirTemp("", "certificate") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + certificatePath := filepath.Join(tempDir, "certificate.pem") + intermediateCertificateChainPath := filepath.Join(tempDir, "intermediate.pem") + rootCertificateChainPath := filepath.Join(tempDir, "root.pem") + + certificate := &apiv1.SignedCertificate{ + Certificate: "-----BEGIN CERTIFICATE-----", + IntermediateCertificateChain: "-----BEGIN CERTIFICATE-----", + CertificateChain: "-----BEGIN CERTIFICATE-----", + } + parameters := types.SignedCertificate{ + CertificatePath: certificatePath, + IntermediateCertificateChainPath: intermediateCertificateChainPath, + RootCertificateChainPath: rootCertificateChainPath, + } + + err = ParseCertificateFormat(certificate, parameters) + if err != nil { + t.Fatalf("failed to parse certificate format: %v", err) + } + + for _, path := range []string{certificatePath, intermediateCertificateChainPath, rootCertificateChainPath} { + _, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read file %s: %v", path, err) + } + } +} diff --git a/test/config.go b/test/config.go deleted file mode 100644 index e767b22..0000000 --- a/test/config.go +++ /dev/null @@ -1,62 +0,0 @@ -package test - -import ( - "fmt" - "os" - "path/filepath" - "runtime" - - "github.com/coinbase/baseca/internal/config" - "github.com/coinbase/baseca/internal/logger" -) - -const ( - configuration = "config.test.local.sandbox.yml" -) - -func GetTestConfigurationPath() (*config.Config, error) { - _, filename, _, ok := runtime.Caller(0) - if !ok { - fmt.Println("Error: Unable to get current file path") - } - - baseDir := filepath.Dir(filename) - for { - if _, err := os.Stat(filepath.Join(baseDir, "go.mod")); err == nil { - break - } - - parentDir := filepath.Dir(baseDir) - if parentDir == baseDir { - fmt.Println("Error: Unable to find base directory") - break - } - - baseDir = parentDir - } - - path := fmt.Sprintf("%s/test/config/%s", baseDir, configuration) - config, err := provideConfig(path) - if err != nil { - return nil, err - } - return config, nil -} - -func provideConfig(path string) (*config.Config, error) { - ctxLogger := logger.ContextLogger{Logger: logger.DefaultLogger} - - v, err := config.BuildViper(path) - if err != nil { - ctxLogger.Error(err.Error()) - } - - config, err := config.LoadConfig(v) - if err != nil { - return nil, err - } - - return config, err -} - -//