diff --git a/.dockerignore b/.dockerignore index 1b668dc..55811fa 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,12 +1,11 @@ config/*.yml docs/ examples/ -pkg/ test/ terraform/ -.github/ *.md .dockerignore .git .gitignore +.github/ diff --git a/.github/workflows/release_build.yml b/.github/workflows/release_build.yml index 50e06bd..2090490 100644 --- a/.github/workflows/release_build.yml +++ b/.github/workflows/release_build.yml @@ -8,7 +8,6 @@ env: REGISTRY: ghcr.io IMAGE_NAME: ${{ github.repository }} - jobs: build: @@ -21,40 +20,41 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 - # Workaround: https://github.com/docker/build-push-action/issues/461 - - name: Setup Docker buildx - uses: docker/setup-buildx-action@79abd3f86f79a9d68a23c75a09a9a85889262adf + # QEMU + - name: QEMU + uses: docker/setup-qemu-action@v1 + + # Setup Docker BuildX + - name: Setup Docker BuildX + uses: docker/setup-buildx-action@v1 - # Login against a Docker registry except on PR - # https://github.com/docker/login-action - - name: Log into registry ${{ env.REGISTRY }} + # Login Docker Registry + - name: Log Registry ${{ env.REGISTRY }} if: github.event_name != 'pull_request' - uses: docker/login-action@28218f9b04b4f3f62068d7b6ce6ca5b26e35336c + uses: docker/login-action@v1 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - # Extract metadata (tags, labels) for Docker - # https://github.com/docker/metadata-action + # Extract Metadata for Docker - name: Extract Docker Metadata id: meta - uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 + uses: docker/metadata-action@v3 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} tags: | type=ref,event=tag type=sha - # Build and push Docker image with Buildx (don't push on PR) - # https://github.com/docker/build-push-action + # Build and Push Docker Image with BuildX - name: Build and Push Docker Image - id: build-and-push - uses: docker/build-push-action@ac9327eae2b366085ac7f6a2d02df8aa8ead720a + uses: docker/build-push-action@v2 with: context: . push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} + platforms: linux/amd64,linux/arm64 cache-from: type=gha - cache-to: type=gha,mode=max + cache-to: type=gha,mode=max \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 782d916..d95c4bf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,24 +1,41 @@ -# Build Image -FROM golang:1.20 as builder +FROM golang:1.21 as builder + +# Docker BuildX Target Architecture +ARG TARGETARCH ENV CGO_ENABLED=0 WORKDIR /baseca COPY . /baseca -RUN apt update && apt clean && make build + +# Build ARM64 or AMD64 Binary +RUN apt update && apt clean && \ + if [ "$TARGETARCH" = "amd64" ]; then \ + make build_amd64; \ + elif [ "$TARGETARCH" = "arm64" ]; then \ + make build_arm64; \ + else \ + echo "Unsupported Architecture [$TARGETARCH]"; \ + exit 1; \ + fi # Deploy Image FROM alpine:3.17 +# Non-Root User RUN adduser --home /home/baseca baseca --gecos "baseca" --disabled-password && \ apk --no-cache add ca-certificates && \ rm -rf /var/cache/apk/* +# Copy Binary and Configuration from Build Image COPY --from=builder /baseca/target/bin/linux/baseca /home/baseca/baseca COPY --from=builder /baseca/config /home/baseca/config +# Permissions for Non-Root User RUN chown -R baseca:baseca /home/baseca +# Switch to Non-Root User USER baseca WORKDIR /home/baseca -CMD ["/home/baseca/baseca"] \ No newline at end of file +# Execute coinbase/baseca +CMD ["/home/baseca/baseca"] diff --git a/Makefile b/Makefile index d6c7316..be79d03 100644 --- a/Makefile +++ b/Makefile @@ -29,9 +29,17 @@ test: info clean dependencies .PHONY: build build: info clean - @ GOOS=darwin GOARCH=amd64 go build $(LDFLAGS) -o $(BIN)/darwin/$(SERVICE) cmd/baseca/server.go + @ GOOS=darwin GOARCH=amd64 go build $(LDFLAGS) -o $(BIN)/amd64/$(SERVICE) cmd/baseca/server.go + @ GOOS=darwin GOARCH=arm64 go build $(LDFLAGS) -o $(BIN)/arm64/$(SERVICE) cmd/baseca/server.go + +.PHONY: build_amd64 +build_amd64: info clean @ GOOS=linux GOARCH=amd64 go build $(LDFLAGS) -o $(BIN)/linux/$(SERVICE) cmd/baseca/server.go +.PHONY: build_arm64 +build_arm64: info clean + @ GOOS=linux GOARCH=arm64 go build $(LDFLAGS) -o $(BIN)/linux/$(SERVICE) cmd/baseca/server.go + .PHONY: sqlc sqlc: @ sqlc generate -f db/sqlc.yaml diff --git a/config/aws/ec2.amazonaws.com.crt b/config/aws/ec2.amazonaws.com.crt new file mode 100644 index 0000000..7e3885b --- /dev/null +++ b/config/aws/ec2.amazonaws.com.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw +FgYDVQQKEw9BbWF6b24uY29tIEluYy4xGjAYBgNVBAMTEWVjMi5hbWF6b25hd3Mu +Y29tMB4XDTE0MDYwNTE0MjgwMloXDTI0MDYwNTE0MjgwMlowajELMAkGA1UEBhMC +VVMxEzARBgNVBAgTCldhc2hpbmd0b24xEDAOBgNVBAcTB1NlYXR0bGUxGDAWBgNV +BAoTD0FtYXpvbi5jb20gSW5jLjEaMBgGA1UEAxMRZWMyLmFtYXpvbmF3cy5jb20w +gZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAIe9GN//SRK2knbjySG0ho3yqQM3 +e2TDhWO8D2e8+XZqck754gFSo99AbT2RmXClambI7xsYHZFapbELC4H91ycihvrD +jbST1ZjkLQgga0NE1q43eS68ZeTDccScXQSNivSlzJZS8HJZjgqzBlXjZftjtdJL +XeE4hwvo0sD4f3j9AgMBAAGjgc8wgcwwHQYDVR0OBBYEFCXWzAgVyrbwnFncFFIs +77VBdlE4MIGcBgNVHSMEgZQwgZGAFCXWzAgVyrbwnFncFFIs77VBdlE4oW6kbDBq +MQswCQYDVQQGEwJVUzETMBEGA1UECBMKV2FzaGluZ3RvbjEQMA4GA1UEBxMHU2Vh +dHRsZTEYMBYGA1UEChMPQW1hem9uLmNvbSBJbmMuMRowGAYDVQQDExFlYzIuYW1h +em9uYXdzLmNvbYIJAKnL4UEDMN/FMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF +BQADgYEAFYcz1OgEhQBXIwIdsgCOS8vEtiJYF+j9uO6jz7VOmJqO+pRlAbRlvY8T +C1haGgSI/A1uZUKs/Zfnph0oEI0/hu1IIJ/SKBDtN5lvmZ/IzbOPIJWirlsllQIQ +7zvWbGd9c9+Rm3p04oTvhup99la7kZqevJK0QRdD/6NpCKsqP/0= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/config/certificate_authority/rds.global.bundle.pem b/config/aws/rds.global.bundle.pem similarity index 100% rename from config/certificate_authority/rds.global.bundle.pem rename to config/aws/rds.global.bundle.pem diff --git a/db/sqlc/common.go b/db/sqlc/common.go index 36541dd..cf07fb9 100644 --- a/db/sqlc/common.go +++ b/db/sqlc/common.go @@ -1,25 +1,16 @@ package db -import "github.com/coinbase/baseca/internal/types" - -type CertificateResponseData struct { - Certificate string `json:"certificate"` - IntermediateCertificateChain string `json:"intermediate_certificate_chain,omitempty"` - RootCertificateChain string `json:"root_certificate_chain,omitempty"` - Metadata types.CertificateMetadata `json:"metadata"` -} - type DatabaseEndpoints struct { Writer Store Reader Store } -type CachedServiceAccount struct { +type ServiceAccountAttestation struct { ServiceAccount Account `json:"service_account"` AwsIid AwsAttestation `json:"aws_iid"` } -type CachedProvisionerAccount struct { +type ProvisionerAccountAttestation struct { ProvisionerAccount Provisioner `json:"provisioner_account"` AwsIid AwsAttestation `json:"aws_iid"` } diff --git a/db/sqlc/tx_provisioner_account.go b/db/sqlc/tx_provisioner_account.go index 8bcc741..b9e1713 100644 --- a/db/sqlc/tx_provisioner_account.go +++ b/db/sqlc/tx_provisioner_account.go @@ -3,6 +3,7 @@ package db import ( "context" + "github.com/coinbase/baseca/internal/types" "github.com/google/uuid" ) @@ -19,7 +20,7 @@ func (store *SQLStore) TxCreateProvisionerAccount(ctx context.Context, arg Creat for _, node_attestation := range arg.NodeAttestation { switch node_attestation { - case "AWS_IID": + case types.AWS_IID.String(): // Add to AWS_IID Database _, err = store.StoreInstanceIdentityDocument(ctx, iid) if err != nil { diff --git a/db/sqlc/tx_service_account.go b/db/sqlc/tx_service_account.go index e00515b..1726516 100644 --- a/db/sqlc/tx_service_account.go +++ b/db/sqlc/tx_service_account.go @@ -3,6 +3,7 @@ package db import ( "context" + "github.com/coinbase/baseca/internal/types" "github.com/google/uuid" ) @@ -19,7 +20,7 @@ func (store *SQLStore) TxCreateServiceAccount(ctx context.Context, arg CreateSer for _, node_attestation := range arg.NodeAttestation { switch node_attestation { - case "AWS_IID": + case types.AWS_IID.String(): // Add to AWS_IID Database _, err = store.StoreInstanceIdentityDocument(ctx, iid) if err != nil { diff --git a/db/sqlc/tx_update_account.go b/db/sqlc/tx_update_account.go index 63615fa..f6d2577 100644 --- a/db/sqlc/tx_update_account.go +++ b/db/sqlc/tx_update_account.go @@ -26,19 +26,19 @@ func (store *SQLStore) TxUpdateServiceAccount(ctx context.Context, arg Account, NodeAttestation: arg.NodeAttestation, } - raw_message, err := validator.MapToNullRawMessage(attestation.AWSInstanceIdentityDocument.InstanceTags) + raw_message, err := validator.MapToNullRawMessage(attestation.EC2NodeAttestation.InstanceTags) if err != nil { return nil, err } iid := StoreInstanceIdentityDocumentParams{ ClientID: arg.ClientID, - RoleArn: sql.NullString{String: attestation.AWSInstanceIdentityDocument.RoleArn, Valid: len(attestation.AWSInstanceIdentityDocument.RoleArn) != 0}, - AssumeRole: sql.NullString{String: attestation.AWSInstanceIdentityDocument.AssumeRole, Valid: len(attestation.AWSInstanceIdentityDocument.AssumeRole) != 0}, - SecurityGroupID: attestation.AWSInstanceIdentityDocument.SecurityGroups, - Region: sql.NullString{String: attestation.AWSInstanceIdentityDocument.Region, Valid: len(attestation.AWSInstanceIdentityDocument.Region) != 0}, - InstanceID: sql.NullString{String: attestation.AWSInstanceIdentityDocument.InstanceID, Valid: len(attestation.AWSInstanceIdentityDocument.InstanceID) != 0}, - ImageID: sql.NullString{String: attestation.AWSInstanceIdentityDocument.ImageID, Valid: len(attestation.AWSInstanceIdentityDocument.ImageID) != 0}, + RoleArn: sql.NullString{String: attestation.EC2NodeAttestation.RoleArn, Valid: len(attestation.EC2NodeAttestation.RoleArn) != 0}, + AssumeRole: sql.NullString{String: attestation.EC2NodeAttestation.AssumeRole, Valid: len(attestation.EC2NodeAttestation.AssumeRole) != 0}, + SecurityGroupID: attestation.EC2NodeAttestation.SecurityGroups, + Region: sql.NullString{String: attestation.EC2NodeAttestation.Region, Valid: len(attestation.EC2NodeAttestation.Region) != 0}, + InstanceID: sql.NullString{String: attestation.EC2NodeAttestation.InstanceID, Valid: len(attestation.EC2NodeAttestation.InstanceID) != 0}, + ImageID: sql.NullString{String: attestation.EC2NodeAttestation.ImageID, Valid: len(attestation.EC2NodeAttestation.ImageID) != 0}, InstanceTags: raw_message, } @@ -52,7 +52,7 @@ func (store *SQLStore) TxUpdateServiceAccount(ctx context.Context, arg Account, for _, node_attestation := range arg.NodeAttestation { switch node_attestation { - case types.Attestation.AWS_IID: + case types.AWS_IID.String(): // Add to AWS_IID Database _, err = store.StoreInstanceIdentityDocument(ctx, iid) if err != nil { diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index 3b11446..3159161 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -221,13 +221,11 @@ ssl_mode: disable Compile the Golang Binary `baseca` ```sh -# Darwin AMD64 -GOOS=darwin GOARCH=amd64 go build -o target/bin/darwin/baseca cmd/baseca/server.go -database_credentials=secret ./target/bin/darwin/baseca +cd /path/to/baseca +make build -# Linux AMD64 -GOOS=linux GOARCH=amd64 go build -o target/bin/linux/baseca cmd/baseca/server.go -database_credentials=secret ./target/bin/linux/baseca +# Update Path Based on AMD64 or ARM64 Architecture +database_credentials=secret ./target/bin/arm64/baseca ``` ## Signing x.509 Certificate diff --git a/examples/certificate/baseca.v1.Certificate/code_sign.go b/examples/baseca.v1.Certificate/code_sign.go similarity index 71% rename from examples/certificate/baseca.v1.Certificate/code_sign.go rename to examples/baseca.v1.Certificate/code_sign.go index 1b3956f..7454f73 100644 --- a/examples/certificate/baseca.v1.Certificate/code_sign.go +++ b/examples/baseca.v1.Certificate/code_sign.go @@ -2,7 +2,7 @@ package examples import ( "crypto/x509" - "fmt" + "log" "os" baseca "github.com/coinbase/baseca/pkg/client" @@ -22,12 +22,12 @@ func CodeSign() { client, err := baseca.LoadDefaultConfiguration(configuration, baseca.Attestation.Local, authentication) if err != nil { - fmt.Println(err) + log.Fatal(err) } metadata := baseca.CertificateRequest{ - CommonName: "sandbox.coinbase.com", - SubjectAlternateNames: []string{"sandbox.coinbase.com"}, + CommonName: "example.coinbase.com", + SubjectAlternateNames: []string{"example.coinbase.com"}, SigningAlgorithm: x509.ECDSAWithSHA384, PublicKeyAlgorithm: x509.ECDSA, KeySize: 256, @@ -45,28 +45,33 @@ func CodeSign() { } data, _ := os.ReadFile("/bin/chmod") - signature, chain, err := client.GenerateSignature(metadata, data) + signature, chain, err := client.GenerateSignature(metadata, &data) if err != nil { - panic(err) + log.Fatal(err) } // Validation Happens on Different Server manifest := types.Manifest{ CertificateChain: chain, - Signature: *signature, - Data: data, - SigningAlgorithm: x509.SHA256WithRSA, + Signature: signature, + SigningAlgorithm: x509.ECDSAWithSHA512, + Data: types.Data{ + Path: types.Path{ + File: "/bin/chmod", + Buffer: 4096, + }, + }, } tc := types.TrustChain{ CommonName: "sandbox.coinbase.com", - CertificateAuthorityFiles: []string{"/path/to/intermediate.pem"}, + CertificateAuthorityFiles: []string{"/path/to/intermediate_ca.crt"}, } - err = client.ValidateSignature(tc, manifest) + err = baseca.ValidateSignature(tc, manifest) if err != nil { - panic(err) + log.Fatal(err) } - fmt.Println("Signature Verified") + log.Print("Signature Verified") } diff --git a/examples/certificate/baseca.v1.Certificate/operations_sign_csr.go b/examples/baseca.v1.Certificate/operations_sign_csr.go similarity index 98% rename from examples/certificate/baseca.v1.Certificate/operations_sign_csr.go rename to examples/baseca.v1.Certificate/operations_sign_csr.go index b1d0e99..341c361 100644 --- a/examples/certificate/baseca.v1.Certificate/operations_sign_csr.go +++ b/examples/baseca.v1.Certificate/operations_sign_csr.go @@ -2,7 +2,6 @@ package examples import ( "crypto/x509" - "fmt" "log" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" @@ -22,7 +21,7 @@ func OperationsSignCSR() { client, err := baseca.LoadDefaultConfiguration(configuration, baseca.Attestation.Local, authentication) if err != nil { - fmt.Println(err) + log.Fatal(err) } certAuth := apiv1.CertificateAuthorityParameter{ diff --git a/examples/certificate/baseca.v1.Certificate/sign_csr.go b/examples/baseca.v1.Certificate/sign_csr.go similarity index 81% rename from examples/certificate/baseca.v1.Certificate/sign_csr.go rename to examples/baseca.v1.Certificate/sign_csr.go index 8597e1a..a8d9459 100644 --- a/examples/certificate/baseca.v1.Certificate/sign_csr.go +++ b/examples/baseca.v1.Certificate/sign_csr.go @@ -2,7 +2,6 @@ package examples import ( "crypto/x509" - "fmt" "log" baseca "github.com/coinbase/baseca/pkg/client" @@ -21,15 +20,15 @@ func SignCSR() { client, err := baseca.LoadDefaultConfiguration(configuration, baseca.Attestation.Local, authentication) if err != nil { - fmt.Println(err) + log.Fatal(err) } metadata := baseca.CertificateRequest{ - CommonName: "sandbox.coinbase.com", - SubjectAlternateNames: []string{"sandbox.coinbase.com"}, - SigningAlgorithm: x509.SHA384WithRSA, - PublicKeyAlgorithm: x509.RSA, - KeySize: 4096, + CommonName: "example.coinbase.com", + SubjectAlternateNames: []string{"example.coinbase.com"}, + SigningAlgorithm: x509.ECDSAWithSHA384, + PublicKeyAlgorithm: x509.ECDSA, + KeySize: 256, DistinguishedName: baseca.DistinguishedName{ Organization: []string{"Coinbase"}, // Additional Fields diff --git a/go.mod b/go.mod index 85a00d4..ef01d4e 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/lib/pq v1.10.9 github.com/mitchellh/mapstructure v1.5.0 github.com/rs/zerolog v1.30.0 + github.com/shirou/gopsutil v3.21.11+incompatible github.com/spf13/viper v1.16.0 github.com/sqlc-dev/pqtype v0.2.0 github.com/stretchr/testify v1.8.4 @@ -50,6 +51,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/gogo/googleapis v0.0.0-20180223154316-0cd9801be74a // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.3 // indirect @@ -69,9 +71,12 @@ require ( github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect + github.com/tklauser/go-sysconf v0.3.13 // indirect + github.com/tklauser/numcpus v0.7.0 // indirect + github.com/yusufpapurcu/wmi v1.2.3 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 2718356..ef74357 100644 --- a/go.sum +++ b/go.sum @@ -116,6 +116,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -247,6 +249,8 @@ github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/f github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= +github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM= github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ= @@ -283,6 +287,10 @@ github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4= +github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= +github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4= +github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I= github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -290,6 +298,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= +github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -424,6 +434,7 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -457,8 +468,8 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/internal/attestation/aws_iid/iid.go b/internal/attestation/aws_iid/iid.go index 8ef18db..69e5360 100644 --- a/internal/attestation/aws_iid/iid.go +++ b/internal/attestation/aws_iid/iid.go @@ -3,212 +3,99 @@ package aws_iid import ( "context" "crypto/sha256" - "crypto/x509" - "encoding/base64" "encoding/hex" "encoding/json" - "encoding/pem" "fmt" - "os" - "path/filepath" - "regexp" "github.com/allegro/bigcache/v3" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/service/ec2" - ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go-v2/service/sts" - db "github.com/coinbase/baseca/db/sqlc" - apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" + "github.com/coinbase/baseca/internal/client/ec2" "github.com/coinbase/baseca/internal/types" - "github.com/gogo/status" - "github.com/google/uuid" - "google.golang.org/grpc/codes" ) const ( // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/verify-signature.html (Other AWS Regions) - aws_certificate_path = "internal/attestor/aws_iid/certificate/ec2.amazonaws.com.crt" + aws_certificate_path = "config/aws/ec2.amazonaws.com.crt" ) -type InstanceIdentityDocument struct { - AccountId string `json:"accountId"` - Architecture string `json:"architecture"` - AvailabilityZone string `json:"availabilityZone"` - ImageId string `json:"imageId"` - InstanceId string `json:"instanceId"` - InstanceType string `json:"instanceType"` - PrivateIp string `json:"privateIp"` - Region string `json:"region"` - Version string `json:"version"` -} - -var ( - instanceFilters = []ec2types.Filter{ - { - Name: aws.String("instance-state-name"), - Values: []string{ - "pending", - "running", - }, - }, - } -) - -func buildEC2Client(region string, roleARN string) (*ec2.Client, error) { - cfg, err := config.LoadDefaultConfig( - context.TODO(), - config.WithRegion(region), - ) +func AWSIidNodeAttestation(node types.NodeIIDAttestation, cache *bigcache.BigCache) error { + err := validateMetadataSignature(node.EC2InstanceMetadata) if err != nil { - return nil, err - } - - if isValidRoleArn(roleARN) { - stsSvc := sts.NewFromConfig(cfg) - cfg.Credentials = stscreds.NewAssumeRoleProvider(stsSvc, roleARN) + return err } - svc := ec2.NewFromConfig(cfg) - return svc, nil -} - -func isValidRoleArn(arn string) bool { - pattern := `^arn:aws:iam::[0-9]{12}:role\/[a-zA-Z0-9+=,.@_-]{1,64}$` - re := regexp.MustCompile(pattern) - return re.MatchString(arn) -} - -func validateMetadataSignature(iid types.EC2InstanceMetadata) error { - certificate, err := os.ReadFile(filepath.Clean(aws_certificate_path)) + instance_identity_document := types.InstanceIdentityDocument{} + err = json.Unmarshal(node.EC2InstanceMetadata.InstanceIdentityDocument, &instance_identity_document) if err != nil { - return fmt.Errorf("error reading aws certificate for signature validation") + return fmt.Errorf("error unmarshal aws_iid metadata") } - rsa_certificate_pem, _ := pem.Decode([]byte(certificate)) - rsa_certificate, _ := x509.ParseCertificate(rsa_certificate_pem.Bytes) - signature, _ := base64.StdEncoding.DecodeString(string(iid.InstanceIdentitySignature)) - - err = rsa_certificate.CheckSignature(x509.SHA256WithRSA, iid.InstanceIdentityDocument, signature) + err = searchIidCache(node, cache) if err != nil { - return fmt.Errorf("invalid aws_iid signature") + return err } - return nil } -func GetInstanceIdentityDocument(ctx context.Context, db_reader db.Store, client_id uuid.UUID) (*db.AwsAttestation, error) { - node_attestation, err := db_reader.GetInstanceIdentityDocument(ctx, client_id) - if err != nil { - return nil, fmt.Errorf("error retrieving aws_attestation from db, %s", err) - } - return node_attestation, nil -} - -func AWSIidNodeAttestation(client_uuid uuid.UUID, header_metadata string, iid db.AwsAttestation, cache *bigcache.BigCache) error { - var client *ec2.Client - var instance ec2types.Instance - var err error - - request_metadata_byte := []byte(header_metadata) - instance_metadata := types.EC2InstanceMetadata{} - instance_identity_document := InstanceIdentityDocument{} - - err = json.Unmarshal(request_metadata_byte, &instance_metadata) - if err != nil { - return fmt.Errorf("error unmarshal aws_instance metadata") - } - - err = validateMetadataSignature(instance_metadata) - if err != nil { - return err - } - - err = json.Unmarshal(instance_metadata.InstanceIdentityDocument, &instance_identity_document) - if err != nil { - return fmt.Errorf("error unmarshal aws_iid metadata") - } - - // Query Instance Metadata in Cache - hash := sha256.Sum256(instance_metadata.InstanceIdentityDocument) +// Query Instance Metadata in Cache +func searchIidCache(node types.NodeIIDAttestation, cache *bigcache.BigCache) error { + hash := sha256.Sum256(node.EC2InstanceMetadata.InstanceIdentityDocument) hash_key := hex.EncodeToString(hash[:]) - if value, cached := cache.Get(hash_key); cached != nil { - client, err = buildEC2Client(instance_identity_document.Region, iid.AssumeRole.String) - if err != nil { - return fmt.Errorf("error building ec2 client, %s", err) - } - - instancesDesc, err := client.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ - InstanceIds: []string{instance_identity_document.InstanceId}, - Filters: instanceFilters, - }) - if err != nil { - return fmt.Errorf("ec2 describe instances failed, %s", err) - } - - instance, err = getEC2Instance(instancesDesc) - if err != nil { - return fmt.Errorf("error querying ec2 instance, %s", err) - } - - // IAM Role Arn Attestation - if iid.RoleArn.Valid { - if *instance.IamInstanceProfile.Arn != iid.RoleArn.String { - return fmt.Errorf("aws_iid role arn attestation error [client_id %s] [instance_identity_document %s]", client_uuid, instance_identity_document) - } - } - - data, err := json.Marshal(iid) + if value, err := cache.Get(hash_key); err != nil { + // Cache Missed + err := setIidCache(node, cache) if err != nil { - return fmt.Errorf("error marshalling cached_service_account, %s", err) - } - err = cache.Set(hash_key, data) - if err != nil { - return fmt.Errorf("error setting hashed aws_iid in cache, %s", err) + return fmt.Errorf("error setting iid cache, %s", err) } } else { - // SHA-256 Instance Identity Document [Key], Client ID [Value]. Multiple Instances Map to Single Client ID. - err = json.Unmarshal(value, &iid) + // SHA-256 Instance Identity Document [Key] + // []byte Instance Identity Document [Value] + document := types.InstanceIdentityDocument{} + err := json.Unmarshal(value, &document) if err != nil { return fmt.Errorf("error unmarshal hashed iid in cached, %s", err) } - attested_client_id := iid.ClientID - if attested_client_id != client_uuid { - return fmt.Errorf("request client id does not match attested node in cache") - } } return nil } -func getEC2Instance(instancesDesc *ec2.DescribeInstancesOutput) (ec2types.Instance, error) { - if len(instancesDesc.Reservations) < 1 { - return ec2types.Instance{}, status.Error(codes.Internal, "failed to query AWS via describe-instances: returned no reservations") +func setIidCache(node types.NodeIIDAttestation, cache *bigcache.BigCache) error { + hash := sha256.Sum256(node.EC2InstanceMetadata.InstanceIdentityDocument) + hash_key := hex.EncodeToString(hash[:]) + + client, err := ec2.NewEC2Client(node.Attestation.Region, node.Attestation.AssumeRole) + if err != nil { + return fmt.Errorf("error building ec2 client, %s", err) } - if len(instancesDesc.Reservations[0].Instances) < 1 { - return ec2types.Instance{}, status.Error(codes.Internal, "failed to query AWS via describe-instances: returned no instances") + // Instance ID from EC2 IID + iid := types.InstanceIdentityDocument{} + err = json.Unmarshal(node.EC2InstanceMetadata.InstanceIdentityDocument, &iid) + if err != nil { + return fmt.Errorf("error unmarshal aws_iid metadata") } - return instancesDesc.Reservations[0].Instances[0], nil -} + instance, err := ec2.QueryInstanceMetadata(context.Background(), client, []string{iid.InstanceId}) + if err != nil { + return err + } -func GetNodeAttestation(node_attestation *apiv1.NodeAttestation) []string { - var valid_attestation []string - var iid = node_attestation.AwsIid + // IAM Role Arn Attestation + if len(node.Attestation.RoleArn) > 0 { + if *instance.IamInstanceProfile.Arn != node.Attestation.RoleArn { + return fmt.Errorf("aws_iid role arn attestation error [client_id %s] %s", node.Uuid, string(node.EC2InstanceMetadata.InstanceIdentityDocument)) + } + } - // AWS Node Attestation - if iid != nil { - attestation := iid.RoleArn == "" && iid.AssumeRole == "" && len(iid.SecurityGroups) == 0 && - iid.Region == "" && iid.InstanceId == "" && iid.ImageId == "" && - len(iid.InstanceTags) == 0 + data, err := json.Marshal(iid) + if err != nil { + return fmt.Errorf("error marshalling cached_service_account, %s", err) + } - if !attestation { - valid_attestation = append(valid_attestation, types.Attestation.AWS_IID) - } + err = cache.Set(hash_key, data) + if err != nil { + return fmt.Errorf("error setting hashed aws_iid in cache, %s", err) } - return valid_attestation + return nil } diff --git a/internal/attestation/aws_iid/iid_test.go b/internal/attestation/aws_iid/iid_test.go deleted file mode 100644 index 262c914..0000000 --- a/internal/attestation/aws_iid/iid_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package aws_iid - -import ( - "testing" -) - -func TestIsValidRoleArn(t *testing.T) { - tests := []struct { - name string - arn string - want bool - }{ - { - name: "Valid ARN", - arn: "arn:aws:iam::123456789012:role/Example", - want: true, - }, - { - name: "Invalid ARN", - arn: "invalid:arn:format", - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := isValidRoleArn(tt.arn) - if got != tt.want { - t.Errorf("isValidRoleArn() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/attestation/aws_iid/node.go b/internal/attestation/aws_iid/node.go new file mode 100644 index 0000000..975c403 --- /dev/null +++ b/internal/attestation/aws_iid/node.go @@ -0,0 +1,53 @@ +package aws_iid + +import ( + "context" + "fmt" + + db "github.com/coinbase/baseca/db/sqlc" + apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" + "github.com/coinbase/baseca/internal/lib/util/validator" + "github.com/coinbase/baseca/internal/types" + "github.com/google/uuid" +) + +func GetInstanceIdentityDocument(ctx context.Context, db_reader db.Store, client_id uuid.UUID) (*types.EC2NodeAttestation, error) { + node_attestation, err := db_reader.GetInstanceIdentityDocument(ctx, client_id) + if err != nil { + return nil, fmt.Errorf("error retrieving aws_attestation from db, %s", err) + } + + instance_tag_map, err := validator.ConvertNullRawMessageToMap(node_attestation.InstanceTags) + if err != nil { + return nil, err + } + + return &types.EC2NodeAttestation{ + ClientID: node_attestation.ClientID, + RoleArn: node_attestation.RoleArn.String, + AssumeRole: node_attestation.AssumeRole.String, + SecurityGroups: node_attestation.SecurityGroupID, + Region: node_attestation.Region.String, + InstanceID: node_attestation.InstanceID.String, + ImageID: node_attestation.ImageID.String, + InstanceTags: instance_tag_map, + }, nil +} + +func GetNodeAttestation(node_attestation *apiv1.NodeAttestation) []string { + var valid_attestation []string + var iid = node_attestation.AwsIid + + // AWS Node Attestation + if iid != nil { + attestation := iid.RoleArn == "" && iid.AssumeRole == "" && len(iid.SecurityGroups) == 0 && + iid.Region == "" && iid.InstanceId == "" && iid.ImageId == "" && + len(iid.InstanceTags) == 0 + + if !attestation { + valid_attestation = append(valid_attestation, types.AWS_IID.String()) + } + } + + return valid_attestation +} diff --git a/internal/attestation/aws_iid/signature.go b/internal/attestation/aws_iid/signature.go new file mode 100644 index 0000000..fe7550d --- /dev/null +++ b/internal/attestation/aws_iid/signature.go @@ -0,0 +1,30 @@ +package aws_iid + +import ( + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "os" + "path/filepath" + + "github.com/coinbase/baseca/pkg/attestor/aws_iid" +) + +func validateMetadataSignature(iid aws_iid.EC2InstanceMetadata) error { + certificate, err := os.ReadFile(filepath.Clean(aws_certificate_path)) + if err != nil { + return fmt.Errorf("error reading aws certificate for signature validation") + } + + rsa_certificate_pem, _ := pem.Decode([]byte(certificate)) + rsa_certificate, _ := x509.ParseCertificate(rsa_certificate_pem.Bytes) + signature, _ := base64.StdEncoding.DecodeString(string(iid.InstanceIdentitySignature)) + + err = rsa_certificate.CheckSignature(x509.SHA256WithRSA, iid.InstanceIdentityDocument, signature) + if err != nil { + return fmt.Errorf("invalid aws_iid signature") + } + + return nil +} diff --git a/internal/client/acmpca/issue.go b/internal/client/acmpca/issue.go index 34512f9..279ae39 100644 --- a/internal/client/acmpca/issue.go +++ b/internal/client/acmpca/issue.go @@ -14,6 +14,7 @@ import ( pca_types "github.com/aws/aws-sdk-go-v2/service/acmpca/types" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" "github.com/coinbase/baseca/internal/types" + lib "github.com/coinbase/baseca/pkg/types" ) const ( @@ -29,7 +30,7 @@ func (c *PrivateCaClient) IssueCertificateFromTemplate(parameters *apiv1.Certifi return nil, err } - signingAlgorithm, ok := types.ValidSignatures[parameters.SignAlgorithm] + signingAlgorithm, ok := lib.ValidSignatures[parameters.SignAlgorithm] if !ok { return nil, fmt.Errorf("signature algorithm %s invalid", parameters.SignAlgorithm) } @@ -70,7 +71,7 @@ func (c *PrivateCaClient) IssueSubordinateCertificate(parameters types.Certifica return nil, err } - signingAlgorithm, ok := types.ValidSignatures[algorithm] + signingAlgorithm, ok := lib.ValidSignatures[algorithm] if !ok { return nil, fmt.Errorf("signature algorithm %s invalid", algorithm) } diff --git a/internal/client/ec2/client.go b/internal/client/ec2/client.go new file mode 100644 index 0000000..f454b1a --- /dev/null +++ b/internal/client/ec2/client.go @@ -0,0 +1,37 @@ +package ec2 + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +type EC2ClientIface interface { + DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) +} + +type EC2Client struct { + Service EC2ClientIface +} + +func NewEC2Client(region string, roleArn string) (*ec2.Client, error) { + cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(region)) + if err != nil { + return nil, fmt.Errorf("unable to load ec2 sdk config, %v", err) + } + + stsClient := sts.NewFromConfig(cfg) + credentials := stscreds.NewAssumeRoleProvider(stsClient, roleArn) + + assumedRoleConfig := aws.Config{ + Credentials: aws.NewCredentialsCache(credentials), + Region: cfg.Region, + } + + return ec2.NewFromConfig(assumedRoleConfig), nil +} diff --git a/internal/client/ec2/metadata.go b/internal/client/ec2/metadata.go new file mode 100644 index 0000000..d906a60 --- /dev/null +++ b/internal/client/ec2/metadata.go @@ -0,0 +1,44 @@ +package ec2 + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go/aws" + "github.com/gogo/status" + "google.golang.org/grpc/codes" +) + +var ( + instanceFilters = []types.Filter{ + { + Name: aws.String("instance-state-name"), + Values: []string{ + "pending", + "running", + }, + }, + } +) + +func QueryInstanceMetadata(ctx context.Context, c *ec2.Client, instanceIds []string) (*types.Instance, error) { + instancesDesc, err := c.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ + InstanceIds: instanceIds, + Filters: instanceFilters, + }) + if err != nil { + return nil, fmt.Errorf("ec2 describe instances failed, %s", err) + } + + if len(instancesDesc.Reservations) < 1 { + return &types.Instance{}, status.Error(codes.Internal, "failed to query AWS via describe-instances: returned no reservations") + } + + if len(instancesDesc.Reservations[0].Instances) < 1 { + return &types.Instance{}, status.Error(codes.Internal, "failed to query AWS via describe-instances: returned no instances") + } + + return &instancesDesc.Reservations[0].Instances[0], nil +} diff --git a/internal/gateway/fx.go b/internal/gateway/fx.go index a063b0c..deed5df 100644 --- a/internal/gateway/fx.go +++ b/internal/gateway/fx.go @@ -15,6 +15,7 @@ import ( "github.com/coinbase/baseca/internal/client/secretsmanager" "github.com/coinbase/baseca/internal/config" lib "github.com/coinbase/baseca/internal/lib/authentication" + "github.com/coinbase/baseca/internal/lib/util" "github.com/coinbase/baseca/internal/logger" "github.com/coinbase/baseca/internal/v1/accounts" "github.com/coinbase/baseca/internal/v1/certificate" @@ -105,12 +106,15 @@ func StartRPC(lc fx.Lifecycle, cfg *config.Config) error { term := make(chan error) var grpcServer *grpc.Server + // Monitor CPU Load + go util.UpdateCPULoad() + lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { // RPC Middleware Logger logInterceptor := logger.RpcLogger(extractor) - interceptors := grpc_middleware.ChainUnaryServer(server.Middleware.ServerAuthenticationInterceptor, logInterceptor) + interceptors := grpc_middleware.ChainUnaryServer(server.Middleware.SetAuthenticationContext, logInterceptor, server.Middleware.ServerAuthenticationInterceptor) grpcServer = grpc.NewServer(grpc.UnaryInterceptor(interceptors)) // Service Registration @@ -196,7 +200,7 @@ func GetPgConn(conf config.DatabaseConfig, endpoint, credentials string) (*sql.D if conf.SSLMode == "disable" { dataSource = fmt.Sprintf("%s sslmode=disable", dataSource) } else { - dataSource = fmt.Sprintf("%s sslmode=verify-full sslrootcert=config/certificate_authority/rds.global.bundle.pem", dataSource) + dataSource = fmt.Sprintf("%s sslmode=verify-full sslrootcert=config/aws/rds.global.bundle.pem", dataSource) } // Open Database Connection diff --git a/internal/lib/authentication/credentials.go b/internal/lib/authentication/credentials.go index 4d80d5c..bc2c972 100644 --- a/internal/lib/authentication/credentials.go +++ b/internal/lib/authentication/credentials.go @@ -9,7 +9,7 @@ import ( ) func HashPassword(password string) (string, error) { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 15) + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 10) if err != nil { return "", fmt.Errorf("failed to hash password %w", err) } diff --git a/internal/lib/crypto/csr.go b/internal/lib/crypto/csr.go deleted file mode 100644 index a81fe61..0000000 --- a/internal/lib/crypto/csr.go +++ /dev/null @@ -1,99 +0,0 @@ -package crypto - -import ( - "bytes" - "crypto/rand" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - - "github.com/coinbase/baseca/internal/types" -) - -func GenerateCSR(csr types.CertificateRequest) (*types.SigningRequest, error) { - var generator CSRGenerator - - switch csr.PublicKeyAlgorithm { - case x509.RSA: - if _, ok := types.PublicKeyAlgorithms["RSA"].KeySize[csr.KeySize]; !ok { - return nil, fmt.Errorf("rsa invalid key size %d", csr.KeySize) - } - if _, ok := types.PublicKeyAlgorithms["RSA"].SigningAlgorithm[csr.SigningAlgorithm]; !ok { - return nil, fmt.Errorf("rsa invalid signing algorithm %s", csr.SigningAlgorithm) - } - generator = &SigningRequestGeneratorRSA{Size: csr.KeySize} - case x509.ECDSA: - if _, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[csr.KeySize]; !ok { - return nil, fmt.Errorf("ecdsa invalid key size %d", csr.KeySize) - } - if _, ok := types.PublicKeyAlgorithms["ECDSA"].SigningAlgorithm[csr.SigningAlgorithm]; !ok { - return nil, fmt.Errorf("ecdsa invalid signing algorithm %s", csr.SigningAlgorithm) - } - generator = &SigningRequestGeneratorECDSA{Curve: csr.KeySize} - default: - return nil, fmt.Errorf("unsupported public key algorithm") - } - - pk, err := generator.Generate() - if err != nil { - return nil, fmt.Errorf("error generating private key [%s]: %w", generator.KeyType(), err) - } - - subject := pkix.Name{ - CommonName: csr.CommonName, - Country: csr.DistinguishedName.Country, - Province: csr.DistinguishedName.Province, - Locality: csr.DistinguishedName.Locality, - Organization: csr.DistinguishedName.Organization, - OrganizationalUnit: csr.DistinguishedName.OrganizationalUnit, - } - - template := x509.CertificateRequest{ - Subject: subject, - SignatureAlgorithm: csr.SigningAlgorithm, - DNSNames: csr.SubjectAlternateNames, - } - - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, pk) - if err != nil { - return nil, fmt.Errorf("error creating certificate request: %w", err) - } - - certificatePem := new(bytes.Buffer) - err = pem.Encode(certificatePem, &pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: csrBytes, - }) - - if err != nil { - return nil, fmt.Errorf("error encoding certificate request (csr): %w", err) - } - - if len(csr.Output.CertificateSigningRequest) != 0 { - if err := writeFileToSystem(csr.Output.CertificateSigningRequest, certificatePem.Bytes()); err != nil { - return nil, fmt.Errorf("error writing certificate signing request (csr) to [%s]: %w", csr.Output.CertificateSigningRequest, err) - } - } - - pkBytes, err := generator.MarshalPrivateKey(pk) - if err != nil { - return nil, fmt.Errorf("error marshaling private key: %w", err) - } - - pkBlock := &pem.Block{ - Type: generator.KeyType(), - Bytes: pkBytes, - } - - if len(csr.Output.PrivateKey) != 0 { - if err := writeFileToSystem(csr.Output.PrivateKey, pem.EncodeToMemory(pkBlock)); err != nil { - return nil, fmt.Errorf("error writing private key to [%s]: %w", csr.Output.PrivateKey, err) - } - } - - return &types.SigningRequest{ - CSR: certificatePem, - PrivateKey: pkBlock, - }, nil -} diff --git a/internal/lib/crypto/generate.go b/internal/lib/crypto/generate.go deleted file mode 100644 index f6c4d79..0000000 --- a/internal/lib/crypto/generate.go +++ /dev/null @@ -1,95 +0,0 @@ -package crypto - -import ( - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "fmt" - - "github.com/coinbase/baseca/internal/types" -) - -type CSRGenerator interface { - Generate() (crypto.PrivateKey, error) - KeyType() string - MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) - SupportsPublicKeyAlgorithm(algorithm x509.PublicKeyAlgorithm) bool - SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool - SupportsKeySize(size int) bool -} - -type SigningRequestGeneratorRSA struct { - Size int -} - -type SigningRequestGeneratorECDSA struct { - Curve int -} - -// RSA Interface -func (r *SigningRequestGeneratorRSA) Generate() (crypto.PrivateKey, error) { - return rsa.GenerateKey(rand.Reader, r.Size) -} - -func (r *SigningRequestGeneratorRSA) KeyType() string { - return "RSA PRIVATE KEY" -} - -func (r *SigningRequestGeneratorRSA) MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) { - return x509.MarshalPKCS1PrivateKey(key.(*rsa.PrivateKey)), nil -} - -func (r *SigningRequestGeneratorRSA) SupportsPublicKeyAlgorithm(algorithm x509.PublicKeyAlgorithm) bool { - return algorithm == x509.RSA -} - -func (r *SigningRequestGeneratorRSA) SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool { - _, ok := types.PublicKeyAlgorithms["RSA"].SigningAlgorithm[algorithm] - return ok -} - -func (r *SigningRequestGeneratorRSA) SupportsKeySize(size int) bool { - _, ok := types.PublicKeyAlgorithms["RSA"].KeySize[size] - return ok -} - -// ECDSA Interface -func (e *SigningRequestGeneratorECDSA) Generate() (crypto.PrivateKey, error) { - c, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[e.Curve] - - if !ok { - return nil, fmt.Errorf("ecdsa curve [%d] not supported", e.Curve) - } - - curve, ok := c.(elliptic.Curve) - if !ok { - return nil, fmt.Errorf("invalid elliptic.Curve type") - } - - return ecdsa.GenerateKey(curve, rand.Reader) -} - -func (e *SigningRequestGeneratorECDSA) KeyType() string { - return "EC PRIVATE KEY" -} - -func (e *SigningRequestGeneratorECDSA) MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) { - return x509.MarshalECPrivateKey(key.(*ecdsa.PrivateKey)) -} - -func (e *SigningRequestGeneratorECDSA) SupportsPublicKeyAlgorithm(algorithm x509.PublicKeyAlgorithm) bool { - return algorithm == x509.ECDSA -} - -func (e *SigningRequestGeneratorECDSA) SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool { - _, ok := types.PublicKeyAlgorithms["ECDSA"].SigningAlgorithm[algorithm] - return ok -} - -func (e *SigningRequestGeneratorECDSA) SupportsKeySize(size int) bool { - _, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[size] - return ok -} diff --git a/internal/lib/crypto/generate_test.go b/internal/lib/crypto/generate_test.go deleted file mode 100644 index d89bfd0..0000000 --- a/internal/lib/crypto/generate_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package crypto - -import ( - "crypto/x509" - "testing" -) - -func TestSigningRequestGeneratorRSA(t *testing.T) { - r := &SigningRequestGeneratorRSA{ - Size: 2048, - } - - key, err := r.Generate() - if err != nil { - t.Fatalf("error generating rsa private key: %v", err) - } - - if keyType := r.KeyType(); keyType != "RSA PRIVATE KEY" { - t.Errorf("RSA PRIVATE KEY does not exist within private key") - - } - - if !r.SupportsPublicKeyAlgorithm(x509.RSA) { - t.Errorf("rsa public key algorithm not supported") - } - - if !r.SupportsSigningAlgorithm(x509.SHA256WithRSA) { - t.Errorf("SHA256WithRSA signing algorithm not supported") - } - - if !r.SupportsKeySize(2048) { - t.Errorf("rsa key size not supported") - } - - _, err = r.MarshalPrivateKey(key) - if err != nil { - t.Errorf("error marshaling rsa private key: %v", err) - } -} - -func TestSigningRequestGeneratorECDSA(t *testing.T) { - e := &SigningRequestGeneratorECDSA{ - Curve: 256, - } - - key, err := e.Generate() - if err != nil { - t.Fatalf("error generating ecdsa private key: %v", err) - } - - if keyType := e.KeyType(); keyType != "EC PRIVATE KEY" { - t.Errorf("EC PRIVATE KEY does not exist within private key") - } - - if !e.SupportsPublicKeyAlgorithm(x509.ECDSA) { - t.Errorf("ecdsa public key algorithm not supported") - } - - if !e.SupportsSigningAlgorithm(x509.ECDSAWithSHA256) { - t.Errorf("ECDSAWithSHA256 signing algorithm not supported") - } - - if !e.SupportsKeySize(256) { - t.Errorf("ecdsa curve size not supported") - } - - _, err = e.MarshalPrivateKey(key) - if err != nil { - t.Errorf("error marshaling ecdsa private key: %v", err) - } -} diff --git a/internal/lib/crypto/pk.go b/internal/lib/crypto/pk.go index 5c2f4f9..afadf0a 100644 --- a/internal/lib/crypto/pk.go +++ b/internal/lib/crypto/pk.go @@ -1,13 +1,8 @@ package crypto import ( - "crypto" - "crypto/ecdsa" - "crypto/rand" - "crypto/rsa" "crypto/x509" "encoding/pem" - "errors" "fmt" "os" "path/filepath" @@ -16,54 +11,6 @@ import ( "github.com/coinbase/baseca/internal/types" ) -type RSA struct { - PublicKey *rsa.PublicKey - PrivateKey *rsa.PrivateKey -} - -type ECDSA struct { - PublicKey *ecdsa.PublicKey - PrivateKey *ecdsa.PrivateKey -} - -func (key *RSA) KeyPair() any { - return key -} - -func (key *RSA) Sign(data []byte) ([]byte, error) { - h := crypto.SHA256.New() - h.Write(data) - hashed := h.Sum(nil) - return rsa.SignPKCS1v15(rand.Reader, key.PrivateKey, crypto.SHA256, hashed) -} - -func (key *ECDSA) KeyPair() any { - return key -} - -func (key *ECDSA) Sign(data []byte) ([]byte, error) { - h := crypto.SHA256.New() - h.Write(data) - hashed := h.Sum(nil) - r, s, err := ecdsa.Sign(rand.Reader, key.PrivateKey, hashed) - if err != nil { - return nil, err - } - signature := append(r.Bytes(), s.Bytes()...) - return signature, nil -} - -func ReturnPrivateKey(key types.AsymmetricKey) (any, error) { - switch k := key.KeyPair().(type) { - case *RSA: - return k.PrivateKey, nil - case *ECDSA: - return k.PrivateKey, nil - default: - return nil, fmt.Errorf("unsupported key type") - } -} - func GetSubordinateCaParameters(service string) (*types.CertificateAuthority, error) { subordinatePath := filepath.Join(types.SubordinatePath, service+_subordinateCertificate) subordinate, err := readFileFromSystem(subordinatePath) @@ -92,11 +39,6 @@ func GetSubordinateCaParameters(service string) (*types.CertificateAuthority, er return nil, fmt.Errorf("error decoding private key") } - subordinatePrivateKey, err := formatAsymmetricKey(pkPem) - if err != nil { - return nil, fmt.Errorf("error formatting private key: %w", err) - } - serialNumberPath := filepath.Join(types.SubordinatePath, service+_subordinateSerialNumber) caSerialNumber, err := readFileFromSystem(serialNumberPath) if err != nil { @@ -111,7 +53,7 @@ func GetSubordinateCaParameters(service string) (*types.CertificateAuthority, er return &types.CertificateAuthority{ Certificate: subordinateCertificate, - AsymmetricKey: &subordinatePrivateKey, + PrivateKey: pkPem, SerialNumber: string(*caSerialNumber), CertificateAuthorityArn: string(*caArn), }, nil @@ -139,46 +81,3 @@ func writeFileToSystem(path string, data []byte) error { } return nil } - -func formatAsymmetricKey(block *pem.Block) (types.AsymmetricKey, error) { - switch block.Type { - case "RSA PRIVATE KEY": - rsaKey, err := parseRSAPrivateKey(block.Bytes) - if err != nil { - return nil, err - } - return rsaKey, nil - case "EC PRIVATE KEY": - ecdsaKey, err := parseECDSAPrivateKey(block.Bytes) - if err != nil { - return nil, err - } - return ecdsaKey, nil - default: - return nil, errors.New("unsupported key type") - } -} - -func parseRSAPrivateKey(keyBytes []byte) (*RSA, error) { - key, err := x509.ParsePKCS1PrivateKey(keyBytes) - if err != nil { - return nil, err - } - rsaPrivateKey := &RSA{ - PublicKey: &key.PublicKey, - PrivateKey: key, - } - return rsaPrivateKey, nil -} - -func parseECDSAPrivateKey(keyBytes []byte) (*ECDSA, error) { - key, err := x509.ParseECPrivateKey(keyBytes) - if err != nil { - return nil, err - } - ecdsaPrivateKey := &ECDSA{ - PublicKey: &key.PublicKey, - PrivateKey: key, - } - return ecdsaPrivateKey, nil -} diff --git a/internal/lib/crypto/store.go b/internal/lib/crypto/store.go index 6881a69..5801549 100644 --- a/internal/lib/crypto/store.go +++ b/internal/lib/crypto/store.go @@ -10,10 +10,10 @@ import ( "github.com/aws/aws-sdk-go-v2/service/acmpca" "github.com/coinbase/baseca/internal/types" + lib "github.com/coinbase/baseca/pkg/types" ) -func WriteKeyToFile(service string, privateKey types.AsymmetricKey) error { - var pemBlock *pem.Block +func WriteKeyToFile(service string, privateKey *pem.Block) error { directoryPath := filepath.Join(types.SubordinatePath, service) filePath := filepath.Join(directoryPath, _subordinatePrivateKey) @@ -21,30 +21,9 @@ func WriteKeyToFile(service string, privateKey types.AsymmetricKey) error { return fmt.Errorf("unsafe file input, write private key") } - switch k := privateKey.KeyPair().(type) { - case *RSA: - pkBytes := x509.MarshalPKCS1PrivateKey(k.PrivateKey) - pemBlock = &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: pkBytes, - } - case *ECDSA: - pkBytes, err := x509.MarshalECPrivateKey(k.PrivateKey) - if err != nil { - return err - } - pemBlock = &pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: pkBytes, - } - default: - return fmt.Errorf("private key format not supported") - } - - if err := os.WriteFile(filePath, pem.EncodeToMemory(pemBlock), os.ModePerm); err != nil { + if err := os.WriteFile(filePath, pem.EncodeToMemory(privateKey), os.ModePerm); err != nil { return err } - return nil } @@ -115,7 +94,7 @@ func encodeCertificateFromString(certificate *string) (*[]byte, error) { } pemBlock := pem.EncodeToMemory( &pem.Block{ - Type: "CERTIFICATE", + Type: lib.CERTIFICATE.String(), Bytes: x509Certificate.Raw, }, ) @@ -125,7 +104,7 @@ func encodeCertificateFromString(certificate *string) (*[]byte, error) { func encodeCertificateFromx509(certificate *x509.Certificate) *[]byte { pemBlock := pem.EncodeToMemory( &pem.Block{ - Type: "CERTIFICATE", + Type: lib.CERTIFICATE.String(), Bytes: certificate.Raw, }, ) diff --git a/internal/lib/util/cpu.go b/internal/lib/util/cpu.go new file mode 100644 index 0000000..a123a39 --- /dev/null +++ b/internal/lib/util/cpu.go @@ -0,0 +1,49 @@ +package util + +import ( + "time" + + "github.com/coinbase/baseca/internal/logger" + "github.com/gogo/status" + "github.com/shirou/gopsutil/cpu" + "go.uber.org/zap" + "google.golang.org/grpc/codes" +) + +var ( + _default_cpu_interval = time.Second * 5 + _default_cpu_threshold = 70.0 + + _backoff_duration = time.Second * 5 + _backoff_timeout = time.Minute * 1 + + CPU_HIGH = false +) + +func UpdateCPULoad() { + ticker := time.NewTicker(_default_cpu_interval) + defer ticker.Stop() + + for range ticker.C { + cpu, err := cpu.Percent(0, false) + if err != nil { + logger.DefaultLogger.Error("error retrieving cpu utilization", zap.Error(err)) + continue + } + CPU_HIGH = cpu[0] > _default_cpu_threshold + } +} + +// Backoff Authentication [middleware/authentication.go] +var ProcessBackoff = func() error { + timeout := time.NewTimer(_backoff_timeout) + defer timeout.Stop() + + select { + case <-time.After(_backoff_duration): + logger.DefaultLogger.Warn("cpu load high") // TODO: Additional Context + return nil + case <-timeout.C: + return logger.RpcError(status.Error(codes.Internal, "queue processinging signing requests at capacity"), nil) + } +} diff --git a/internal/lib/util/random.go b/internal/lib/util/random.go deleted file mode 100644 index 2f4509c..0000000 --- a/internal/lib/util/random.go +++ /dev/null @@ -1,98 +0,0 @@ -package util - -import ( - "crypto/rand" - "encoding/base64" - "encoding/hex" - "fmt" - "testing" - "time" - - db "github.com/coinbase/baseca/db/sqlc" - lib "github.com/coinbase/baseca/internal/lib/authentication" - "github.com/google/uuid" -) - -func GenerateTestUser(t *testing.T, permissions string, length int) (db.User, string) { - client_id, _ := uuid.NewRandom() - credentials := generateRandomCredentials(length) - hashed_credentials, _ := lib.HashPassword(credentials) - email := generateRandomEmail() - username := generateRandomUsername() - full_name := generateRandomName() - - return db.User{ - Uuid: client_id, - Username: username, - HashedCredential: hashed_credentials, - FullName: full_name, - Email: email, - Permissions: permissions, - CredentialChangedAt: time.Now().UTC(), - CreatedAt: time.Now().UTC(), - }, credentials -} - -func generateRandomEmail() string { - randBytes := make([]byte, 8) - _, err := rand.Read(randBytes) - if err != nil { - panic(err) - } - - // Encode the random bytes using base64 encoding to get an ASCII string - randStr := base64.URLEncoding.EncodeToString(randBytes) - - // Use the first 10 characters of the base64-encoded string as the email username - return fmt.Sprintf("%s@coinbase.com", randStr[:10]) -} - -func generateRandomName() string { - // Generate random bytes for the first and last name - firstNameBytes := make([]byte, 6) - _, err := rand.Read(firstNameBytes) - if err != nil { - panic(err) - } - lastNameBytes := make([]byte, 6) - _, err = rand.Read(lastNameBytes) - if err != nil { - panic(err) - } - - // Convert the random bytes to hexadecimal strings - firstNameHex := hex.EncodeToString(firstNameBytes)[:10] - lastNameHex := hex.EncodeToString(lastNameBytes)[:10] - - return fmt.Sprintf("%s %s", firstNameHex, lastNameHex) -} - -func generateRandomUsername() string { - // Generate random bytes for the username - usernameBytes := make([]byte, 8) - _, err := rand.Read(usernameBytes) - if err != nil { - panic(err) - } - - // Encode the random bytes using base64 encoding to get an ASCII string - usernameStr := base64.URLEncoding.EncodeToString(usernameBytes) - - // Use the first 10 characters of the base64-encoded string as the username - return usernameStr[:10] -} - -func generateRandomCredentials(length int) string { - // Generate random bytes for the credentials - credentialsBytes := make([]byte, length) - _, err := rand.Read(credentialsBytes) - if err != nil { - panic(err) - } - - // Encode the random bytes using base64 encoding to get an ASCII string - credentialsStr := base64.URLEncoding.EncodeToString(credentialsBytes) - - // Return the first `length` characters of the base64-encoded string - return credentialsStr[:length] -} diff --git a/internal/lib/util/validator/environment.go b/internal/lib/util/validator/environment.go index 8972890..33fe475 100644 --- a/internal/lib/util/validator/environment.go +++ b/internal/lib/util/validator/environment.go @@ -9,17 +9,28 @@ const ( BaseDirectory = "/tmp/baseca/ssl" ) -var CertificateAuthorityEnvironments map[string][]string +var CertificateAuthorityEnvironments map[types.EnvironmentKey][]string +var CertificateAuthorityEnvironmentsString map[string][]string func SupportedEnvironments(cfg *config.Config) { - CertificateAuthorityEnvironments = map[string][]string{ - "local": cfg.Environment.Local, - "sandbox": cfg.Environment.Sandbox, - "development": cfg.Environment.Development, - "staging": cfg.Environment.Staging, - "pre_production": cfg.Environment.PreProduction, - "production": cfg.Environment.Production, - "corporate": cfg.Environment.Corporate, + CertificateAuthorityEnvironments = map[types.EnvironmentKey][]string{ + types.Local: cfg.Environment.Local, + types.Sandbox: cfg.Environment.Sandbox, + types.Development: cfg.Environment.Development, + types.Staging: cfg.Environment.Staging, + types.PreProduction: cfg.Environment.PreProduction, + types.Production: cfg.Environment.Production, + types.Corporate: cfg.Environment.Corporate, + } + + CertificateAuthorityEnvironmentsString = map[string][]string{ + types.Local.String(): cfg.Environment.Local, + types.Sandbox.String(): cfg.Environment.Sandbox, + types.Development.String(): cfg.Environment.Development, + types.Staging.String(): cfg.Environment.Staging, + types.PreProduction.String(): cfg.Environment.PreProduction, + types.Production.String(): cfg.Environment.Production, + types.Corporate.String(): cfg.Environment.Corporate, } } diff --git a/internal/lib/util/validator/environment_test.go b/internal/lib/util/validator/environment_test.go index f04f93f..2ae3cc8 100644 --- a/internal/lib/util/validator/environment_test.go +++ b/internal/lib/util/validator/environment_test.go @@ -16,13 +16,13 @@ func TestSupportedEnvironments(t *testing.T) { SupportedEnvironments(cfg) - if len(CertificateAuthorityEnvironments["local"]) == 0 { + if len(CertificateAuthorityEnvironments[types.Local]) == 0 { t.Errorf("Expected non-empty local environments, got none") } } func TestSetBaseDirectory(t *testing.T) { - // When BaseDirectory is provided + // BaseDirectory Provided cfg := &config.Config{ SubordinateMetadata: config.SubordinateCertificateAuthority{ BaseDirectory: "/some/dir", diff --git a/internal/lib/util/validator/permission.go b/internal/lib/util/validator/permission.go index 7f6e3bc..2845458 100644 --- a/internal/lib/util/validator/permission.go +++ b/internal/lib/util/validator/permission.go @@ -4,7 +4,7 @@ import "github.com/coinbase/baseca/internal/types" func IsSupportedPermission(permission string) bool { switch permission { - case types.ADMIN, types.PRIVILEGED, types.READ: + case types.ADMIN.String(), types.PRIVILEGED.String(), types.READ.String(): return true } diff --git a/internal/lib/util/validator/validate.go b/internal/lib/util/validator/validate.go index 1d5d18a..ffc160a 100644 --- a/internal/lib/util/validator/validate.go +++ b/internal/lib/util/validator/validate.go @@ -20,7 +20,7 @@ func ValidateCertificateAuthorityEnvironment(config config.Stage, environment st } for _, certificate_authority := range certificate_authorities { - if output := Contains(CertificateAuthorityEnvironments[environment], certificate_authority); !output { + if output := Contains(CertificateAuthorityEnvironmentsString[environment], certificate_authority); !output { return false } } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 4faca0d..47a2de3 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/coinbase/baseca/internal/types" "github.com/gogo/status" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -12,11 +13,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" - - "github.com/coinbase/baseca/internal/types" ) -type Extractor func(resp any, err error, code codes.Code) string +type Extractor func(resp interface{}, err error, code codes.Code) string type Error struct { UserError error @@ -35,7 +34,7 @@ func RpcError(user, internal error) *Error { } func RpcLogger(extractor Extractor) grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { currentTime := time.Now().UTC() result, err := handler(ctx, req) duration := time.Since(currentTime) @@ -56,14 +55,14 @@ func RpcLogger(extractor Extractor) grpc.UnaryServerInterceptor { Str("ip_address", clientIP). Dur("duration", duration) - provisioner, ok := ctx.Value(types.ProvisionerAuthenticationContextKey).(*types.ProvisionerAccountPayload) + provisioner, ok := ctx.Value(types.ProvisionerAuthenticationContextKey).(string) if ok { - event.Str("provisioner_account_uuid", provisioner.ClientId.String()) + event.Str("provisioner_account_uuid", provisioner) } - service, ok := ctx.Value(types.ServiceAuthenticationContextKey).(*types.ServiceAccountPayload) + service, ok := ctx.Value(types.ServiceAuthenticationContextKey).(string) if ok { - event.Str("service_account_uuid", service.ServiceID.String()) + event.Str("service_account_uuid", service) } event.Msg(extractor(result, err, statusCode)) @@ -170,7 +169,7 @@ func (ctxLogger *ContextLogger) fields(fields []zap.Field) []zap.Field { } func (ctxLogger *ContextLogger) stackFields(fields []zap.Field) []zap.Field { - return ctxLogger.fields(fields) + return append(ctxLogger.fields(fields)) } func (ctxLogger *ContextLogger) Panic(msg string, fields ...zap.Field) { diff --git a/internal/types/attestation.go b/internal/types/attestation.go index 93b9cc1..8b11393 100644 --- a/internal/types/attestation.go +++ b/internal/types/attestation.go @@ -1,10 +1,28 @@ package types +import ( + "github.com/coinbase/baseca/pkg/attestor/aws_iid" + "github.com/google/uuid" +) + +type Attestation uint + +const ( + AWS_IID Attestation = iota +) + +func (a Attestation) String() string { + return [...]string{ + "AWS_IID"}[a] +} + type NodeAttestation struct { - AWSInstanceIdentityDocument AWSInstanceIdentityDocument `json:"aws_iid"` + EC2NodeAttestation EC2NodeAttestation `json:"aws_iid"` } -type AWSInstanceIdentityDocument struct { +// Node Attestation Configured in Database +type EC2NodeAttestation struct { + ClientID uuid.UUID `json:"client_id"` RoleArn string `json:"instance_profile_arn,omitempty"` AssumeRole string `json:"assume_role,omitempty"` SecurityGroups []string `json:"security_groups,omitempty"` @@ -14,10 +32,20 @@ type AWSInstanceIdentityDocument struct { InstanceTags map[string]string `json:"instance_tags,omitempty"` } -type Node struct { - AWS_IID string +type NodeIIDAttestation struct { + Uuid uuid.UUID + EC2InstanceMetadata aws_iid.EC2InstanceMetadata + Attestation EC2NodeAttestation } -var Attestation = Node{ - AWS_IID: "AWS_IID", +type InstanceIdentityDocument struct { + AccountId string `json:"accountId"` + Architecture string `json:"architecture"` + AvailabilityZone string `json:"availabilityZone"` + ImageId string `json:"imageId"` + InstanceId string `json:"instanceId"` + InstanceType string `json:"instanceType"` + PrivateIp string `json:"privateIp"` + Region string `json:"region"` + Version string `json:"version"` } diff --git a/internal/types/authentication.go b/internal/types/authentication.go new file mode 100644 index 0000000..3b2f35b --- /dev/null +++ b/internal/types/authentication.go @@ -0,0 +1,47 @@ +package types + +type ContextKey uint + +const ( + // Context Metadata + ServiceAuthenticationContextKey ContextKey = iota + ProvisionerAuthenticationContextKey + UserAuthenticationContextKey +) + +type UserKey uint + +const ( + // User Permissions + ADMIN UserKey = iota + PRIVILEGED + READ +) + +func (u UserKey) String() string { + return [...]string{ + "ADMIN", + "PRIVILEGED", + "READ", + }[u] +} + +type AuthenticationKey uint + +const ( + PassAuthentication AuthenticationKey = iota + ServiceAuthentication + ProvisionerAuthentication +) + +var Methods = map[string]AuthenticationKey{ + "/grpc.health.v1.Health/Check": PassAuthentication, + "/baseca.v1.Account/LoginUser": PassAuthentication, + "/baseca.v1.Account/UpdateUserCredentials": PassAuthentication, + "/baseca.v1.Certificate/SignCSR": ServiceAuthentication, + "/baseca.v1.Certificate/OperationsSignCSR": ProvisionerAuthentication, + "/baseca.v1.Certificate/QueryCertificateMetadata": ProvisionerAuthentication, + "/baseca.v1.Service/ProvisionServiceAccount": ProvisionerAuthentication, + "/baseca.v1.Service/GetServiceAccountByMetadata": ProvisionerAuthentication, + "/baseca.v1.Service/DeleteProvisionedServiceAccount": ProvisionerAuthentication, +} diff --git a/internal/types/certificate.go b/internal/types/certificate.go index 64431d3..91306ba 100644 --- a/internal/types/certificate.go +++ b/internal/types/certificate.go @@ -1,13 +1,9 @@ package types import ( - "bytes" - "crypto/elliptic" "crypto/x509" "encoding/pem" "time" - - "github.com/aws/aws-sdk-go-v2/service/acmpca/types" ) var SubordinatePath string @@ -29,25 +25,16 @@ type Extensions struct { type Algorithm struct { Algorithm x509.PublicKeyAlgorithm - KeySize map[int]any + KeySize map[int]interface{} Signature map[string]bool SigningAlgorithm map[x509.SignatureAlgorithm]bool } -type SignatureAlgorithm struct { - Common x509.SignatureAlgorithm - PCA types.SigningAlgorithm -} - -type SigningRequest struct { - CSR *bytes.Buffer - PrivateKey *pem.Block -} - -type SignedCertificate struct { - CertificatePath string - IntermediateCertificateChainPath string - RootCertificateChainPath string +type CertificateResponseData struct { + Certificate string `json:"certificate"` + IntermediateCertificateChain string `json:"intermediate_certificate_chain,omitempty"` + RootCertificateChain string `json:"root_certificate_chain,omitempty"` + Metadata CertificateMetadata `json:"metadata"` } type CertificateMetadata struct { @@ -63,31 +50,6 @@ type CertificateMetadata struct { RevokeDate time.Time } -type CertificateRequest struct { - CommonName string - SubjectAlternateNames []string - DistinguishedName DistinguishedName - SigningAlgorithm x509.SignatureAlgorithm - PublicKeyAlgorithm x509.PublicKeyAlgorithm - KeySize int - Output Output -} - -type Output struct { - CertificateSigningRequest string - Certificate string - CertificateChain string - PrivateKey string -} - -type DistinguishedName struct { - Country []string - Province []string - Locality []string - Organization []string - OrganizationalUnit []string -} - type EC2InstanceMetadata struct { InstanceIdentityDocument []byte `json:"instance_identity_document"` InstanceIdentitySignature []byte `json:"instance_identity_signature"` @@ -95,89 +57,11 @@ type EC2InstanceMetadata struct { type CertificateAuthority struct { Certificate *x509.Certificate - AsymmetricKey *AsymmetricKey + PrivateKey *pem.Block SerialNumber string CertificateAuthorityArn string } -type AsymmetricKey interface { - KeyPair() any - Sign(data []byte) ([]byte, error) -} - -var ValidSignatures = map[string]SignatureAlgorithm{ - "SHA256WITHECDSA": { - Common: x509.ECDSAWithSHA256, - PCA: types.SigningAlgorithmSha256withecdsa, - }, - "SHA384WITHECDSA": { - Common: x509.ECDSAWithSHA384, - PCA: types.SigningAlgorithmSha384withecdsa, - }, - "SHA512WITHECDSA": { - Common: x509.ECDSAWithSHA512, - PCA: types.SigningAlgorithmSha512withecdsa, - }, - "SHA256WITHRSA": { - Common: x509.SHA256WithRSA, - PCA: types.SigningAlgorithmSha256withrsa, - }, - "SHA384WITHRSA": { - Common: x509.SHA384WithRSA, - PCA: types.SigningAlgorithmSha384withrsa, - }, - "SHA512WITHRSA": { - Common: x509.SHA512WithRSA, - PCA: types.SigningAlgorithmSha512withrsa, - }, - // TODO: Support Probabilistic Element to the Signature Scheme [SHA256WithRSAPSS] -} - -var PublicKeyAlgorithms = map[string]Algorithm{ - "RSA": { - Algorithm: x509.RSA, - KeySize: map[int]any{ - 2048: true, - 4096: true, - }, - Signature: map[string]bool{ - "SHA256WITHRSA": true, - "SHA384WITHRSA": true, - "SHA512WITHRSA": true, - }, - SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ - x509.SHA256WithRSA: true, - x509.SHA384WithRSA: true, - x509.SHA512WithRSA: true, - }, - }, - "ECDSA": { - Algorithm: x509.ECDSA, - KeySize: map[int]any{ - 256: elliptic.P256(), - 384: elliptic.P384(), - 521: elliptic.P521(), - }, - Signature: map[string]bool{ - "SHA256WITHECDSA": true, - "SHA384WITHECDSA": true, - "SHA512WITHECDSA": true, - }, - SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ - x509.ECDSAWithSHA256: true, - x509.ECDSAWithSHA384: true, - x509.ECDSAWithSHA512: true, - }, - }, - // TODO: Support Ed25519 - "Ed25519": { - Algorithm: x509.Ed25519, - KeySize: map[int]any{ - 256: true, - }, - }, -} - var CertificateRequestExtension = map[string]Extensions{ "EndEntityClientAuthCertificate": { KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, @@ -197,6 +81,6 @@ var CertificateRequestExtension = map[string]Extensions{ } var ValidNodeAttestation = map[string]bool{ - "None": false, - "AWS": true, + "Local": false, + "AWS": true, } diff --git a/internal/types/constants.go b/internal/types/constants.go deleted file mode 100644 index 3372b65..0000000 --- a/internal/types/constants.go +++ /dev/null @@ -1,16 +0,0 @@ -package types - -type ContextKey int - -const ( - // Context Metadata - ServiceAuthenticationContextKey ContextKey = iota - ProvisionerAuthenticationContextKey ContextKey = iota - UserAuthenticationContextKey ContextKey = iota - EnrollmentAuthenticationContextKey ContextKey = iota - - // User Permissions - ADMIN = "ADMIN" - PRIVILEGED = "PRIVILEGED" - READ = "READ" -) diff --git a/internal/types/enrollment.go b/internal/types/enrollment.go deleted file mode 100644 index 6d1eec2..0000000 --- a/internal/types/enrollment.go +++ /dev/null @@ -1,14 +0,0 @@ -package types - -type DeviceEnrollmentRequest struct { - SerialNumber string `json:"serial_number" binding:"required"` - Environment string `json:"environment" binding:"required,ca_environment"` -} - -type DeviceEnrollmentResponse struct { - SerialNumber string `json:"serial_number"` - Credentials string `json:"credentials"` -} - -type EndpointCertificateIssueRequest struct { -} diff --git a/internal/types/environment.go b/internal/types/environment.go new file mode 100644 index 0000000..0956b12 --- /dev/null +++ b/internal/types/environment.go @@ -0,0 +1,26 @@ +package types + +type EnvironmentKey uint + +const ( + // Environments + Production EnvironmentKey = iota + PreProduction + Staging + Development + Sandbox + Local + Corporate +) + +func (u EnvironmentKey) String() string { + return [...]string{ + "production", + "pre_production", + "staging", + "development", + "sandbox", + "local", + "corporate", + }[u] +} diff --git a/internal/v1/accounts/provision.go b/internal/v1/accounts/provision.go index fa09d21..244a185 100644 --- a/internal/v1/accounts/provision.go +++ b/internal/v1/accounts/provision.go @@ -21,10 +21,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -var ( - _production = "production" -) - func (s *Service) CreateProvisionerAccount(ctx context.Context, req *apiv1.CreateProvisionerAccountRequest) (*apiv1.CreateProvisionerAccountResponse, error) { var service *db.Provisioner nodeAttestation := []string{} @@ -45,7 +41,7 @@ func (s *Service) CreateProvisionerAccount(ctx context.Context, req *apiv1.Creat } for _, environment := range req.Environments { - if _, ok := validator.CertificateAuthorityEnvironments[environment]; !ok { + if _, ok := validator.CertificateAuthorityEnvironmentsString[environment]; !ok { return nil, logger.RpcError(status.Error(codes.InvalidArgument, "invalid certificate authority environment"), fmt.Errorf("invalid certificate authority environment [%s]", environment)) } } @@ -88,8 +84,8 @@ func (s *Service) CreateProvisionerAccount(ctx context.Context, req *apiv1.Creat return nil, status.Error(codes.InvalidArgument, "service auth context missing") } - if validator.Contains(req.Environments, _production) || req.NodeAttestation != nil { - if err = validateNodeAttestation(req.NodeAttestation); err != nil { + if validator.Contains(req.Environments, types.Production.String()) || req.NodeAttestation != nil { + if err = verifyNodeAttestationParameters(req.NodeAttestation); err != nil { return nil, logger.RpcError(status.Error(codes.InvalidArgument, err.Error()), err) } } @@ -239,7 +235,7 @@ func (s *Service) ProvisionServiceAccount(ctx context.Context, req *apiv1.Provis if len(certificate_authorities) == 0 { environment := req.Environment - certificate_authorities := validator.CertificateAuthorityEnvironments[environment] + certificate_authorities := validator.CertificateAuthorityEnvironmentsString[environment] // Include Default Certificate Authorities for _, ca := range certificate_authorities { @@ -286,8 +282,8 @@ func (s *Service) ProvisionServiceAccount(ctx context.Context, req *apiv1.Provis return nil, logger.RpcError(status.Error(codes.InvalidArgument, "invalid team parameter"), fmt.Errorf("invalid team [%s]", req.Team)) } - if req.Environment == _production || req.NodeAttestation != nil { - if err = validateNodeAttestation(req.NodeAttestation); err != nil { + if req.Environment == types.Production.String() || req.NodeAttestation != nil { + if err = verifyNodeAttestationParameters(req.NodeAttestation); err != nil { return nil, logger.RpcError(status.Error(codes.InvalidArgument, err.Error()), err) } } diff --git a/internal/v1/accounts/query.go b/internal/v1/accounts/query.go index f2b7408..4b8d3d7 100644 --- a/internal/v1/accounts/query.go +++ b/internal/v1/accounts/query.go @@ -118,7 +118,7 @@ func (s *Service) GetProvisionerAccount(ctx context.Context, req *apiv1.AccountI func (s *Service) transformServiceAccount(ctx context.Context, account *db.Account) (*apiv1.ServiceAccount, error) { var attestation types.NodeAttestation - if validator.Contains(account.NodeAttestation, types.Attestation.AWS_IID) { + if validator.Contains(account.NodeAttestation, types.AWS_IID.String()) { iid, err := s.store.Reader.GetInstanceIdentityDocument(ctx, account.ClientID) if err != nil { return nil, err @@ -131,7 +131,7 @@ func (s *Service) transformServiceAccount(ctx context.Context, account *db.Accou // TODO: Update awsIid {} Response attestation = types.NodeAttestation{ - AWSInstanceIdentityDocument: types.AWSInstanceIdentityDocument{ + EC2NodeAttestation: types.EC2NodeAttestation{ RoleArn: iid.RoleArn.String, AssumeRole: iid.AssumeRole.String, SecurityGroups: iid.SecurityGroupID, @@ -160,12 +160,12 @@ func (s *Service) transformServiceAccount(ctx context.Context, account *db.Accou CreatedBy: account.CreatedBy.String(), NodeAttestation: &apiv1.NodeAttestation{ AwsIid: &apiv1.AWSInstanceIdentityDocument{ - RoleArn: attestation.AWSInstanceIdentityDocument.RoleArn, - AssumeRole: attestation.AWSInstanceIdentityDocument.AssumeRole, - SecurityGroups: attestation.AWSInstanceIdentityDocument.SecurityGroups, - Region: attestation.AWSInstanceIdentityDocument.Region, - InstanceId: attestation.AWSInstanceIdentityDocument.InstanceID, - InstanceTags: attestation.AWSInstanceIdentityDocument.InstanceTags, + RoleArn: attestation.EC2NodeAttestation.RoleArn, + AssumeRole: attestation.EC2NodeAttestation.AssumeRole, + SecurityGroups: attestation.EC2NodeAttestation.SecurityGroups, + Region: attestation.EC2NodeAttestation.Region, + InstanceId: attestation.EC2NodeAttestation.InstanceID, + InstanceTags: attestation.EC2NodeAttestation.InstanceTags, }, }, }, nil @@ -174,7 +174,7 @@ func (s *Service) transformServiceAccount(ctx context.Context, account *db.Accou func (s *Service) transformProvisionerAccount(ctx context.Context, account *db.Provisioner) (*apiv1.ProvisionerAccount, error) { var attestation types.NodeAttestation - if validator.Contains(account.NodeAttestation, types.Attestation.AWS_IID) { + if validator.Contains(account.NodeAttestation, types.AWS_IID.String()) { iid, err := s.store.Reader.GetInstanceIdentityDocument(ctx, account.ClientID) if err != nil { return nil, err @@ -187,7 +187,7 @@ func (s *Service) transformProvisionerAccount(ctx context.Context, account *db.P // TODO: Update awsIid {} Response attestation = types.NodeAttestation{ - AWSInstanceIdentityDocument: types.AWSInstanceIdentityDocument{ + EC2NodeAttestation: types.EC2NodeAttestation{ RoleArn: iid.RoleArn.String, AssumeRole: iid.AssumeRole.String, SecurityGroups: iid.SecurityGroupID, @@ -213,12 +213,12 @@ func (s *Service) transformProvisionerAccount(ctx context.Context, account *db.P CreatedBy: account.CreatedBy.String(), NodeAttestation: &apiv1.NodeAttestation{ AwsIid: &apiv1.AWSInstanceIdentityDocument{ - RoleArn: attestation.AWSInstanceIdentityDocument.RoleArn, - AssumeRole: attestation.AWSInstanceIdentityDocument.AssumeRole, - SecurityGroups: attestation.AWSInstanceIdentityDocument.SecurityGroups, - Region: attestation.AWSInstanceIdentityDocument.Region, - InstanceId: attestation.AWSInstanceIdentityDocument.InstanceID, - InstanceTags: attestation.AWSInstanceIdentityDocument.InstanceTags, + RoleArn: attestation.EC2NodeAttestation.RoleArn, + AssumeRole: attestation.EC2NodeAttestation.AssumeRole, + SecurityGroups: attestation.EC2NodeAttestation.SecurityGroups, + Region: attestation.EC2NodeAttestation.Region, + InstanceId: attestation.EC2NodeAttestation.InstanceID, + InstanceTags: attestation.EC2NodeAttestation.InstanceTags, }, }, }, nil diff --git a/internal/v1/accounts/service.go b/internal/v1/accounts/service.go index ab782d0..a2d8db7 100644 --- a/internal/v1/accounts/service.go +++ b/internal/v1/accounts/service.go @@ -70,8 +70,8 @@ func (s *Service) CreateServiceAccount(ctx context.Context, req *apiv1.CreateSer } // Production Service Accounts Require Attestation - if req.Environment == _production || req.NodeAttestation != nil { - if err = validateNodeAttestation(req.NodeAttestation); err != nil { + if req.Environment == types.Production.String() || req.NodeAttestation != nil { + if err = verifyNodeAttestationParameters(req.NodeAttestation); err != nil { return nil, logger.RpcError(status.Error(codes.InvalidArgument, err.Error()), err) } } diff --git a/internal/v1/accounts/validate.go b/internal/v1/accounts/validate.go index 1b20f44..2a17e6a 100644 --- a/internal/v1/accounts/validate.go +++ b/internal/v1/accounts/validate.go @@ -136,7 +136,7 @@ func (s *Service) validateSanInputProvisionerAccount(ctx context.Context, provis } func (s *Service) validateCertificateParameters(certificateAuthorities []string, environment string, certificateValidity int16, subordinateCa string) error { - if _, ok := validator.CertificateAuthorityEnvironments[environment]; !ok { + if _, ok := validator.CertificateAuthorityEnvironmentsString[environment]; !ok { return fmt.Errorf("invalid environment [%s]", environment) } @@ -165,7 +165,7 @@ func (s *Service) validateCertificateParameters(certificateAuthorities []string, return nil } -func validateNodeAttestation(attestation *apiv1.NodeAttestation) error { +func verifyNodeAttestationParameters(attestation *apiv1.NodeAttestation) error { if attestation == nil { return fmt.Errorf("node_attestation cannot be empty") } diff --git a/internal/v1/certificate/operations_test.go b/internal/v1/certificate/operations_test.go index 56a37fa..86153ce 100644 --- a/internal/v1/certificate/operations_test.go +++ b/internal/v1/certificate/operations_test.go @@ -10,12 +10,10 @@ import ( "github.com/coinbase/baseca/db/mock" db "github.com/coinbase/baseca/db/sqlc" - "github.com/coinbase/baseca/internal/lib/crypto" - "github.com/coinbase/baseca/internal/types" + apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" + c "github.com/coinbase/baseca/pkg/client" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - - apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" ) func TestGetCertificate(t *testing.T) { @@ -82,14 +80,14 @@ func TestOperationsSignCSR(t *testing.T) { { name: "OK_NO_CERTIFICATE_AUTHORITY_INPUT", req: func() *apiv1.OperationsSignRequest { - req := types.CertificateRequest{ + req := c.CertificateRequest{ CommonName: "development.example.com", SubjectAlternateNames: []string{"development.example.com"}, SigningAlgorithm: x509.SHA512WithRSA, PublicKeyAlgorithm: x509.RSA, KeySize: 2048, } - csr, _ := crypto.GenerateCSR(req) + csr, _ := c.GenerateCSR(req) return &apiv1.OperationsSignRequest{ CertificateSigningRequest: csr.CSR.String(), @@ -140,14 +138,14 @@ func TestOperationsSignCSR(t *testing.T) { { name: "OK_CERTIFICATE_AUTHORITY_INPUT", req: func() *apiv1.OperationsSignRequest { - req := types.CertificateRequest{ + req := c.CertificateRequest{ CommonName: "development.example.com", SubjectAlternateNames: []string{"development.example.com"}, SigningAlgorithm: x509.SHA512WithRSA, PublicKeyAlgorithm: x509.RSA, KeySize: 2048, } - csr, _ := crypto.GenerateCSR(req) + csr, _ := c.GenerateCSR(req) caParameter := &apiv1.CertificateAuthorityParameter{ Region: "us-east-1", diff --git a/internal/v1/certificate/pca.go b/internal/v1/certificate/pca.go index 99f5647..6f764e9 100644 --- a/internal/v1/certificate/pca.go +++ b/internal/v1/certificate/pca.go @@ -8,10 +8,10 @@ import ( "math/big" "time" - db "github.com/coinbase/baseca/db/sqlc" "github.com/coinbase/baseca/internal/client/firehose" "github.com/coinbase/baseca/internal/lib/crypto" "github.com/coinbase/baseca/internal/types" + lib "github.com/coinbase/baseca/pkg/crypto" ) func (c *Certificate) buildCertificateAuthorityParameters(certificate_authority string) types.CertificateParameters { @@ -25,7 +25,7 @@ func (c *Certificate) buildCertificateAuthorityParameters(certificate_authority } } -func (c *Certificate) issueEndEntityCertificate(auth *types.ServiceAccountPayload, ca_certificate *types.CertificateAuthority, request_csr *x509.CertificateRequest) (*db.CertificateResponseData, error) { +func (c *Certificate) issueEndEntityCertificate(auth *types.ServiceAccountPayload, ca_certificate *types.CertificateAuthority, request_csr *x509.CertificateRequest) (*types.CertificateResponseData, error) { block := make([]byte, 20) _, err := rand.Read(block[:]) if err != nil { @@ -60,11 +60,11 @@ func (c *Certificate) issueEndEntityCertificate(auth *types.ServiceAccountPayloa } certificateAuthorityRaw := ca_certificate.Certificate.Raw - pk, err := crypto.ReturnPrivateKey(*ca_certificate.AsymmetricKey) + signer, err := lib.ReturnSignerInterface(ca_certificate.PrivateKey) if err != nil { return nil, err } - certificateRaw, err := x509.CreateCertificate(rand.Reader, &certificateTemplate, ca_certificate.Certificate, request_csr.PublicKey, pk) + certificateRaw, err := x509.CreateCertificate(rand.Reader, &certificateTemplate, ca_certificate.Certificate, request_csr.PublicKey, signer) if err != nil { return nil, err } @@ -108,7 +108,7 @@ func (c *Certificate) issueEndEntityCertificate(auth *types.ServiceAccountPayloa return nil, err } - return &db.CertificateResponseData{ + return &types.CertificateResponseData{ Certificate: certificate.String(), IntermediateCertificateChain: intermediate_chain.String(), RootCertificateChain: root_chain.String(), diff --git a/internal/v1/certificate/query.go b/internal/v1/certificate/query.go index 0f94528..6ef1845 100644 --- a/internal/v1/certificate/query.go +++ b/internal/v1/certificate/query.go @@ -98,7 +98,7 @@ func (c *Certificate) QueryCertificateMetadata(ctx context.Context, req *apiv1.Q } if len(req.Environment) != 0 { - if _, ok := validator.CertificateAuthorityEnvironments[req.Environment]; !ok { + if _, ok := validator.CertificateAuthorityEnvironmentsString[req.Environment]; !ok { return nil, logger.RpcError(status.Error(codes.InvalidArgument, "invalid environment"), fmt.Errorf("invalid environment: %s", req.Environment)) } diff --git a/internal/v1/certificate/sign.go b/internal/v1/certificate/sign.go index aedde1d..723bffe 100644 --- a/internal/v1/certificate/sign.go +++ b/internal/v1/certificate/sign.go @@ -5,10 +5,8 @@ import ( "context" "crypto/x509" "encoding/pem" - "errors" "fmt" - db "github.com/coinbase/baseca/db/sqlc" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" "github.com/coinbase/baseca/internal/client/acmpca" "github.com/coinbase/baseca/internal/config" @@ -17,6 +15,8 @@ import ( "github.com/coinbase/baseca/internal/lib/util/validator" "github.com/coinbase/baseca/internal/logger" "github.com/coinbase/baseca/internal/types" + baseca "github.com/coinbase/baseca/pkg/client" + lib "github.com/coinbase/baseca/pkg/types" "github.com/gogo/status" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/timestamppb" @@ -78,9 +78,9 @@ func (c *Certificate) SignCSR(ctx context.Context, req *apiv1.CertificateSigning } -func (c *Certificate) requestCertificate(ctx context.Context, authPayload *types.ServiceAccountPayload, certificateRequest *x509.CertificateRequest) (*db.CertificateResponseData, error) { +func (c *Certificate) requestCertificate(ctx context.Context, authPayload *types.ServiceAccountPayload, certificateRequest *x509.CertificateRequest) (*types.CertificateResponseData, error) { var subordinate *types.CertificateAuthority - var parameters types.CertificateRequest + var parameters baseca.CertificateRequest var csr *bytes.Buffer var err error @@ -103,18 +103,18 @@ func (c *Certificate) requestCertificate(ctx context.Context, authPayload *types return nil, err } - signingAlgorithm, ok := types.ValidSignatures[c.ca.SigningAlgorithm] + signingAlgorithm, ok := lib.ValidSignatures[c.ca.SigningAlgorithm] if !ok { return nil, fmt.Errorf("invalid signing algorithm: %s", c.ca.SigningAlgorithm) } - parameters = types.CertificateRequest{ + parameters = baseca.CertificateRequest{ CommonName: intermediateCa, SubjectAlternateNames: []string{intermediateCa}, SigningAlgorithm: signingAlgorithm.Common, - PublicKeyAlgorithm: types.PublicKeyAlgorithms[c.ca.KeyAlgorithm].Algorithm, + PublicKeyAlgorithm: lib.PublicKeyAlgorithmStrings[c.ca.KeyAlgorithm].Algorithm, KeySize: c.ca.KeySize, - DistinguishedName: types.DistinguishedName{ + DistinguishedName: baseca.DistinguishedName{ Country: []string{c.ca.Country}, Province: []string{c.ca.Province}, Locality: []string{c.ca.Locality}, @@ -123,18 +123,12 @@ func (c *Certificate) requestCertificate(ctx context.Context, authPayload *types }, } - signingRequest, err := crypto.GenerateCSR(parameters) + signingRequest, err := baseca.GenerateCSR(parameters) if err != nil { return nil, err } - key, err := x509.ParsePKCS1PrivateKey(signingRequest.PrivateKey.Bytes) - if err != nil { - return nil, errors.New("error parsing pkcs1 rsa private key") - } - - pk := crypto.RSA{PrivateKey: key, PublicKey: &key.PublicKey} - err = crypto.WriteKeyToFile(intermediateCa, &pk) + err = crypto.WriteKeyToFile(intermediateCa, signingRequest.PrivateKey) if err != nil { return nil, err } diff --git a/internal/v1/certificate/validate.go b/internal/v1/certificate/validate.go index 4da86cf..eeee822 100644 --- a/internal/v1/certificate/validate.go +++ b/internal/v1/certificate/validate.go @@ -17,6 +17,7 @@ import ( "github.com/coinbase/baseca/internal/lib/util" "github.com/coinbase/baseca/internal/lib/util/validator" "github.com/coinbase/baseca/internal/types" + lib "github.com/coinbase/baseca/pkg/types" ) const ( @@ -120,17 +121,17 @@ func convertX509toString(certificate []byte) (*bytes.Buffer, error) { func ValidateSubordinateParameters(parameter config.SubordinateCertificateAuthority) error { switch parameter.KeyAlgorithm { case "RSA": - if _, ok := types.PublicKeyAlgorithms[parameter.KeyAlgorithm].KeySize[parameter.KeySize]; !ok { + if _, ok := lib.PublicKeyAlgorithmStrings[parameter.KeyAlgorithm].KeySize[parameter.KeySize]; !ok { return fmt.Errorf("invalid rsa key size: %d", parameter.KeySize) } - if _, ok := types.PublicKeyAlgorithms[parameter.KeyAlgorithm].Signature[parameter.SigningAlgorithm]; !ok { + if _, ok := lib.PublicKeyAlgorithmStrings[parameter.KeyAlgorithm].Signature[parameter.SigningAlgorithm]; !ok { return fmt.Errorf("invalid rsa signing algorithm: %s", parameter.SigningAlgorithm) } case "ECDSA": - if _, ok := types.PublicKeyAlgorithms[parameter.KeyAlgorithm].KeySize[parameter.KeySize]; !ok { + if _, ok := lib.PublicKeyAlgorithmStrings[parameter.KeyAlgorithm].KeySize[parameter.KeySize]; !ok { return fmt.Errorf("invalid ecdsa key size: %d", parameter.KeySize) } - if _, ok := types.PublicKeyAlgorithms[parameter.KeyAlgorithm].Signature[parameter.SigningAlgorithm]; !ok { + if _, ok := lib.PublicKeyAlgorithmStrings[parameter.KeyAlgorithm].Signature[parameter.SigningAlgorithm]; !ok { return fmt.Errorf("invalid ecdsa signing algorithm: %s", parameter.SigningAlgorithm) } default: diff --git a/internal/v1/middleware/accounts.go b/internal/v1/middleware/accounts.go new file mode 100644 index 0000000..5f1a202 --- /dev/null +++ b/internal/v1/middleware/accounts.go @@ -0,0 +1,188 @@ +package middleware + +import ( + "context" + "fmt" + "strings" + + lib "github.com/coinbase/baseca/internal/lib/authentication" + "github.com/coinbase/baseca/internal/lib/util/validator" + "github.com/coinbase/baseca/internal/logger" + "github.com/coinbase/baseca/internal/types" + "github.com/gogo/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +type AuthenticationChannel chan<- AuthenticationMetadata + +type AuthenticationMetadata struct { + Account *types.ServiceAccountPayload + Error error +} + +type ServiceAccount struct { + middleware *Middleware +} + +func (s *ServiceAccount) Authenticate(ch <-chan context.Context, auth AuthenticationChannel) { + for ctx := range ch { + credentials, err := extractRequestMetadata(ctx) + if err != nil { + auth <- AuthenticationMetadata{Error: logger.RpcError(status.Error(codes.Internal, "authentication failed"), err)} + return + } + + attestation, err := s.middleware.searchServiceAccountMetadata(ctx, credentials.ClientId) + if err != nil { + auth <- AuthenticationMetadata{Error: logger.RpcError(status.Error(codes.Internal, "authentication failed"), err)} + return + } + + if err := lib.CheckPassword(credentials.ClientToken, attestation.ServiceAccount.ApiToken); err != nil { + auth <- AuthenticationMetadata{Error: logger.RpcError(status.Error(codes.Internal, "authentication failed"), err)} + return + } + + instance_tags, err := validator.ConvertNullRawMessageToMap(attestation.AwsIid.InstanceTags) + if err != nil { + auth <- AuthenticationMetadata{Error: logger.RpcError(status.Error(codes.Internal, "authentication failed"), err)} + return + } + + iid := types.NodeIIDAttestation{ + Uuid: attestation.ServiceAccount.ClientID, + Attestation: types.EC2NodeAttestation{ + ClientID: attestation.ServiceAccount.ClientID, + RoleArn: attestation.AwsIid.RoleArn.String, + AssumeRole: attestation.AwsIid.AssumeRole.String, + SecurityGroups: attestation.AwsIid.SecurityGroupID, + Region: attestation.AwsIid.Region.String, + InstanceID: attestation.AwsIid.InstanceID.String, + ImageID: attestation.AwsIid.ImageID.String, + InstanceTags: instance_tags, + }, + } + + err = s.middleware.attestNode(ctx, iid, attestation.ServiceAccount.NodeAttestation) + if err != nil { + auth <- AuthenticationMetadata{Error: logger.RpcError(status.Error(codes.Internal, "attestation failed"), err)} + return + } + + account := &types.ServiceAccountPayload{ + ServiceID: attestation.ServiceAccount.ClientID, + ServiceAccount: attestation.ServiceAccount.ServiceAccount, + Environment: attestation.ServiceAccount.Environment, + ValidSubjectAlternateName: attestation.ServiceAccount.ValidSubjectAlternateName, + ValidCertificateAuthorities: attestation.ServiceAccount.ValidCertificateAuthorities, + CertificateValidity: attestation.ServiceAccount.CertificateValidity, + SubordinateCa: attestation.ServiceAccount.SubordinateCa, + ExtendedKey: attestation.ServiceAccount.ExtendedKey, + SANRegularExpression: validator.NullStringToString(&attestation.ServiceAccount.RegularExpression), + } + auth <- AuthenticationMetadata{Account: account} + } +} + +type ProvisionerAccount struct { + middleware *Middleware +} + +func (p *ProvisionerAccount) Authenticate(ctx context.Context) (interface{}, error) { + credentials, err := extractRequestMetadata(ctx) + if err != nil { + return nil, logger.RpcError(status.Error(codes.Internal, "authentication failed"), err) + } + + attestation, err := p.middleware.serachProvisionerAccountAttestation(ctx, credentials.ClientId) + if err != nil { + return nil, logger.RpcError(status.Error(codes.Internal, "authentication failed"), err) + } + + if err := lib.CheckPassword(credentials.ClientToken, attestation.ProvisionerAccount.ApiToken); err != nil { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), err) + } + + instance_tags, err := validator.ConvertNullRawMessageToMap(attestation.AwsIid.InstanceTags) + if err != nil { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "error querying attestation tags"), err) + } + + iid := types.NodeIIDAttestation{ + Uuid: attestation.ProvisionerAccount.ClientID, + Attestation: types.EC2NodeAttestation{ + ClientID: attestation.ProvisionerAccount.ClientID, + RoleArn: attestation.AwsIid.RoleArn.String, + AssumeRole: attestation.AwsIid.AssumeRole.String, + SecurityGroups: attestation.AwsIid.SecurityGroupID, + Region: attestation.AwsIid.Region.String, + InstanceID: attestation.AwsIid.InstanceID.String, + ImageID: attestation.AwsIid.ImageID.String, + InstanceTags: instance_tags, + }, + } + + err = p.middleware.attestNode(ctx, iid, attestation.ProvisionerAccount.NodeAttestation) + if err != nil { + return nil, err + } + + account := &types.ProvisionerAccountPayload{ + ClientId: attestation.ProvisionerAccount.ClientID, + ProvisionerAccount: attestation.ProvisionerAccount.ProvisionerAccount, + Environments: attestation.ProvisionerAccount.Environments, + ValidSubjectAlternateNames: attestation.ProvisionerAccount.ValidSubjectAlternateNames, + MaxCertificateValidity: uint32(attestation.ProvisionerAccount.MaxCertificateValidity), + ExtendedKeys: attestation.ProvisionerAccount.ExtendedKeys, + RegularExpression: validator.NullStringToString(&attestation.ProvisionerAccount.RegularExpression), + } + return account, nil +} + +type UserAccount struct { + middleware *Middleware + info *grpc.UnaryServerInfo +} + +func (u *UserAccount) Authenticate(ctx context.Context) (interface{}, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Internal, "failed to retrieve metadata from context") + } + + authorizationHeader, ok := md[authorizationHeaderKey] + if !ok || len(authorizationHeader) == 0 { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("request header empty: %s", authorizationHeaderKey)) + } + + if len(authorizationHeader) != 0 { + fields := strings.Fields(authorizationHeader[0]) + if len(fields) < 2 { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("authorization header not provided")) + } + + authorizationType := strings.ToLower(fields[0]) + if authorizationType != authorizationTypeBearer { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("authorization header not provided")) + } + + accessToken := fields[1] + payload, err := u.middleware.auth.Verify(ctx, accessToken) + if err != nil { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), err) + } + + userPermission := payload.Permission + ok, err := u.middleware.enforcer.Enforce(userPermission, u.info.FullMethod) + if err != nil { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), err) + } + if !ok { + return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("invalid permission error %s", userPermission)) + } + return payload, nil + } + return nil, nil +} diff --git a/internal/v1/middleware/attestation.go b/internal/v1/middleware/attestation.go new file mode 100644 index 0000000..b098a20 --- /dev/null +++ b/internal/v1/middleware/attestation.go @@ -0,0 +1,80 @@ +package middleware + +import ( + "context" + "encoding/json" + + "github.com/coinbase/baseca/internal/attestation/aws_iid" + "github.com/coinbase/baseca/internal/logger" + "github.com/coinbase/baseca/internal/types" + iid "github.com/coinbase/baseca/pkg/attestor/aws_iid" + "github.com/gogo/status" + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +type Credentials struct { + ClientId uuid.UUID + ClientToken string +} + +func (m *Middleware) attestNode(ctx context.Context, node types.NodeIIDAttestation, attestations []string) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Errorf(codes.Internal, "failed to retrieve metadata from context") + } + + for _, node_attestation := range attestations { + clientIdentityDocumentHeader, ok := md[clientIdentityDocumentHeaderKey] + if !ok { + return status.Errorf(codes.InvalidArgument, "authorization header not provided") + } + + aws_iid_metadata_bytes := []byte(clientIdentityDocumentHeader[0]) + instance_metadata := iid.EC2InstanceMetadata{} + + err := json.Unmarshal(aws_iid_metadata_bytes, &instance_metadata) + if err != nil { + return status.Errorf(codes.InvalidArgument, "error unmarshal aws_instance metadata") + } + + switch node_attestation { + // EC2 Instance Identity Document Attestation + case types.AWS_IID.String(): + node.EC2InstanceMetadata = instance_metadata + attestation_err := aws_iid.AWSIidNodeAttestation(node, m.cache) + if attestation_err != nil { + return logger.RpcError(status.Error(codes.Unauthenticated, "aws_iid attestation error"), attestation_err) + } + } + } + return nil +} + +func extractRequestMetadata(ctx context.Context) (*Credentials, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Internal, "failed to retrieve metadata from context") + } + + clientIdAuthorizationHeader, ok := md[clientIdAuthorizationHeaderKey] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") + } + + clientTokenAuthorizationHeader, ok := md[clientTokenAuthorizationHeaderKey] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") + } + + client_uuid, err := uuid.Parse(clientIdAuthorizationHeader[0]) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid authorization header") + } + + return &Credentials{ + ClientId: client_uuid, + ClientToken: clientTokenAuthorizationHeader[0], + }, nil +} diff --git a/internal/v1/middleware/authentication.go b/internal/v1/middleware/authentication.go index c201259..8323c66 100644 --- a/internal/v1/middleware/authentication.go +++ b/internal/v1/middleware/authentication.go @@ -2,305 +2,83 @@ package middleware import ( "context" - "encoding/json" - "fmt" - "strings" - db "github.com/coinbase/baseca/db/sqlc" - "github.com/coinbase/baseca/internal/attestation/aws_iid" - lib "github.com/coinbase/baseca/internal/lib/authentication" - "github.com/coinbase/baseca/internal/lib/util/validator" - "github.com/coinbase/baseca/internal/logger" + "github.com/coinbase/baseca/internal/lib/util" "github.com/coinbase/baseca/internal/types" - "github.com/gogo/status" - "github.com/google/uuid" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" ) -const ( - _pass_auth = "pass_authentication" - _service_auth = "service_authentication" - _provisioner_auth = "provisioner_authentication" +var ( + _default_queue = 1000 + ch = make(chan context.Context, _default_queue) + auth = make(chan AuthenticationMetadata, _default_queue) ) -func (m *Middleware) ServerAuthenticationInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - var auth string +func (m *Middleware) ServerAuthenticationInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + var _authenticated bool + var authentication types.AuthenticationKey var ok bool - methods := map[string]string{ - "/grpc.health.v1.Health/Check": _pass_auth, - "/baseca.v1.Account/LoginUser": _pass_auth, - "/baseca.v1.Account/UpdateUserCredentials": _pass_auth, - "/baseca.v1.Certificate/SignCSR": _service_auth, - "/baseca.v1.Certificate/OperationsSignCSR": _provisioner_auth, - "/baseca.v1.Certificate/QueryCertificateMetadata": _provisioner_auth, - "/baseca.v1.Service/ProvisionServiceAccount": _provisioner_auth, - "/baseca.v1.Service/GetServiceAccountByMetadata": _provisioner_auth, - "/baseca.v1.Service/DeleteProvisionedServiceAccount": _provisioner_auth, - } - - if auth, ok = methods[info.FullMethod]; !ok { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Errorf(codes.Internal, "failed to retrieve metadata from context") - } - - authorizationHeader, ok := md[authorizationHeaderKey] - if !ok || len(authorizationHeader) == 0 { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("request header empty: %s", authorizationHeaderKey)) - } - - if len(authorizationHeader) != 0 { - fields := strings.Fields(authorizationHeader[0]) - if len(fields) < 2 { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("authorization header not provided")) - } - - authorizationType := strings.ToLower(fields[0]) - if authorizationType != authorizationTypeBearer { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("authorization header not provided")) - } - - accessToken := fields[1] - payload, err := m.auth.Verify(ctx, accessToken) - if err != nil { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), err) - } - - userPermission := payload.Permission - ok, err := m.enforcer.Enforce(userPermission, info.FullMethod) - if err != nil { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), err) - } - if !ok { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication failed"), fmt.Errorf("invalid permission error %s", userPermission)) - } - ctx = context.WithValue(ctx, types.UserAuthenticationContextKey, payload) + // User Authentication + if authentication, ok = types.Methods[info.FullMethod]; !ok { + userAccount := &UserAccount{ + middleware: m, + info: info, } - } else if auth == _service_auth { - service, err := m.AuthenticateServiceAccount(ctx) + payload, err := userAccount.Authenticate(ctx) if err != nil { return nil, err } - - ctx = context.WithValue(ctx, types.ServiceAuthenticationContextKey, service) - } else if auth == _provisioner_auth { - service, err := m.AuthenticateProvisionerAccount(ctx) - if err != nil { - return nil, err - } - - ctx = context.WithValue(ctx, types.ProvisionerAuthenticationContextKey, service) - } - return handler(ctx, req) -} - -func (m *Middleware) AuthenticateServiceAccount(ctx context.Context) (*types.ServiceAccountPayload, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Errorf(codes.Internal, "failed to retrieve metadata from context") - } - - clientIdAuthorizationHeader, ok := md[clientIdAuthorizationHeaderKey] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") - } - - clientTokenAuthorizationHeader, ok := md[clientTokenAuthorizationHeaderKey] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") - } - - client_uuid, err := uuid.Parse(clientIdAuthorizationHeader[0]) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid authorization header") - } - - cachedServiceAccount, err := m.authenticationCacheServiceAccount(ctx, client_uuid) - if err != nil { - return nil, logger.RpcError(status.Error(codes.Internal, "internal server error"), err) - } - - account := cachedServiceAccount.ServiceAccount - if err := lib.CheckPassword(clientTokenAuthorizationHeader[0], account.ApiToken); err != nil { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication error"), err) + ctx = context.WithValue(ctx, types.UserAuthenticationContextKey, payload) } - for _, node_attestation := range account.NodeAttestation { - clientIdentityDocumentHeader, ok := md[clientIdentityDocumentHeaderKey] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") + switch authentication { + // Service Account Authentication + case types.ServiceAuthentication: + serviceAccount := &ServiceAccount{ + middleware: m, } - switch node_attestation { - case "AWS_IID": - // Compare Signed Node Data with Attestation Table in Database - attestation_err := aws_iid.AWSIidNodeAttestation(client_uuid, clientIdentityDocumentHeader[0], cachedServiceAccount.AwsIid, m.cache) - if attestation_err != nil { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "aws_iid attestation error"), attestation_err) - } - } - } - - service := &types.ServiceAccountPayload{ - ServiceID: account.ClientID, - ServiceAccount: account.ServiceAccount, - Environment: account.Environment, - ValidSubjectAlternateName: account.ValidSubjectAlternateName, - ValidCertificateAuthorities: account.ValidCertificateAuthorities, - CertificateValidity: account.CertificateValidity, - SubordinateCa: account.SubordinateCa, - ExtendedKey: account.ExtendedKey, - SANRegularExpression: validator.NullStringToString(&account.RegularExpression), - } - - return service, nil -} - -func (m *Middleware) AuthenticateProvisionerAccount(ctx context.Context) (*types.ProvisionerAccountPayload, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Errorf(codes.Internal, "failed to retrieve metadata from context") - } - - clientIdAuthorizationHeader, ok := md[clientIdAuthorizationHeaderKey] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") - } - - clientTokenAuthorizationHeader, ok := md[clientTokenAuthorizationHeaderKey] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") - } - - client_uuid, err := uuid.Parse(clientIdAuthorizationHeader[0]) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid authorization header") - } - - cachedProvisionerAccount, err := m.authenticationCacheProvisionerAccount(ctx, client_uuid) - if err != nil { - return nil, logger.RpcError(status.Error(codes.Internal, "internal server error"), err) - } - - account := cachedProvisionerAccount.ProvisionerAccount - if err := lib.CheckPassword(clientTokenAuthorizationHeader[0], account.ApiToken); err != nil { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "authentication error"), err) - } - - for _, node_attestation := range account.NodeAttestation { - clientIdentityDocumentHeader, ok := md[clientIdentityDocumentHeaderKey] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") - } - - switch node_attestation { - case "AWS_IID": - // Compare Signed Node Data with Attestation Table in Database - attestation_err := aws_iid.AWSIidNodeAttestation(client_uuid, clientIdentityDocumentHeader[0], cachedProvisionerAccount.AwsIid, m.cache) - if attestation_err != nil { - return nil, logger.RpcError(status.Error(codes.Unauthenticated, "aws_iid attestation error"), attestation_err) + for !_authenticated { + // CPU Load High + for util.CPU_HIGH { + err := util.ProcessBackoff() + if err != nil { + return nil, err + } } - } - } - - service := &types.ProvisionerAccountPayload{ - ClientId: account.ClientID, - ProvisionerAccount: account.ProvisionerAccount, - Environments: account.Environments, - ValidSubjectAlternateNames: account.ValidSubjectAlternateNames, - MaxCertificateValidity: uint32(account.MaxCertificateValidity), - ExtendedKeys: account.ExtendedKeys, - RegularExpression: validator.NullStringToString(&account.RegularExpression), - } - return service, nil -} - -func (m *Middleware) authenticationCacheServiceAccount(ctx context.Context, client_uuid uuid.UUID) (*db.CachedServiceAccount, error) { - var service_account *db.Account - var instance_identity_document *db.AwsAttestation - var cached_service_account db.CachedServiceAccount - var err error - - db_reader := m.store.Reader - uuid := client_uuid.String() - if value, cached := m.cache.Get(uuid); cached == nil { - err = json.Unmarshal(value, &cached_service_account) - if err != nil { - return &cached_service_account, fmt.Errorf("error unmarshal cached service account account, %s", err) - } - } else { - service_account, err = db_reader.GetServiceUUID(ctx, client_uuid) - if err != nil { - return &cached_service_account, fmt.Errorf("service authentication failed: %s", err) - } - - cached_service_account.ServiceAccount = *service_account - for _, node_attestation := range service_account.NodeAttestation { - switch node_attestation { - case types.Attestation.AWS_IID: - instance_identity_document, err = aws_iid.GetInstanceIdentityDocument(ctx, db_reader, client_uuid) + select { + case ch <- ctx: + go serviceAccount.Authenticate(ch, auth) + _authenticated = true + default: + // Channel Full + err := util.ProcessBackoff() if err != nil { - return &cached_service_account, fmt.Errorf("aws_iid node attestation failed: %s", err) + return nil, err } - cached_service_account.AwsIid = *instance_identity_document } } - data, err := json.Marshal(cached_service_account) - if err != nil { - return &cached_service_account, fmt.Errorf("error marshalling cached_service_account, %s", err) - } - err = m.cache.Set(uuid, data) - if err != nil { - return &cached_service_account, fmt.Errorf("error setting middleware cache, %s", err) - } - } - return &cached_service_account, nil -} - -func (m *Middleware) authenticationCacheProvisionerAccount(ctx context.Context, client_uuid uuid.UUID) (*db.CachedProvisionerAccount, error) { - var provisioner_account *db.Provisioner - var instance_identity_document *db.AwsAttestation - var cached_provisioner_account db.CachedProvisionerAccount - var err error - - db_reader := m.store.Reader - uuid := client_uuid.String() - if value, cached := m.cache.Get(uuid); cached == nil { - err = json.Unmarshal(value, &cached_provisioner_account) + service := <-auth + err := service.Error if err != nil { - return &cached_provisioner_account, fmt.Errorf("error unmarshal cached service account account, %s", err) - } - } else { - provisioner_account, err = db_reader.GetProvisionerUUID(ctx, client_uuid) - if err != nil { - return &cached_provisioner_account, fmt.Errorf("service authentication failed: %s", err) + return nil, err } - cached_provisioner_account.ProvisionerAccount = *provisioner_account - for _, node_attestation := range provisioner_account.NodeAttestation { - switch node_attestation { - case types.Attestation.AWS_IID: - instance_identity_document, err = aws_iid.GetInstanceIdentityDocument(ctx, db_reader, client_uuid) - if err != nil { - return &cached_provisioner_account, fmt.Errorf("aws_iid node attestation failed: %s", err) - } - cached_provisioner_account.AwsIid = *instance_identity_document - } - } + ctx = context.WithValue(ctx, types.ServiceAuthenticationContextKey, service.Account) - data, err := json.Marshal(cached_provisioner_account) - if err != nil { - return &cached_provisioner_account, fmt.Errorf("error marshalling cached_service_account, %s", err) + // Provisioner Account Authentication + case types.ProvisionerAuthentication: + provisionerAccount := &ProvisionerAccount{ + middleware: m, } - err = m.cache.Set(uuid, data) + provisioner, err := provisionerAccount.Authenticate(ctx) if err != nil { - return &cached_provisioner_account, fmt.Errorf("error setting middleware cache, %s", err) + return nil, err } + ctx = context.WithValue(ctx, types.ProvisionerAuthenticationContextKey, provisioner) } - return &cached_provisioner_account, nil + return handler(ctx, req) } diff --git a/internal/v1/middleware/context.go b/internal/v1/middleware/context.go new file mode 100644 index 0000000..19b050a --- /dev/null +++ b/internal/v1/middleware/context.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "context" + + "github.com/coinbase/baseca/internal/types" + "github.com/gogo/status" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +func (m *Middleware) SetAuthenticationContext(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Internal, "failed to retrieve metadata from context") + } + + if auth, ok := types.Methods[info.FullMethod]; ok { + // Service Account UUID + if auth == types.ServiceAuthentication { + clientIdAuthorizationHeader, ok := md[clientIdAuthorizationHeaderKey] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") + } + ctx = context.WithValue(ctx, types.ServiceAuthenticationContextKey, clientIdAuthorizationHeader[0]) + } + + // Provisioner Account UUID + if auth == types.ProvisionerAuthentication { + clientIdAuthorizationHeader, ok := md[clientIdAuthorizationHeaderKey] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "authorization header not provided") + } + ctx = context.WithValue(ctx, types.ProvisionerAuthenticationContextKey, clientIdAuthorizationHeader[0]) + } + } + return handler(ctx, req) +} diff --git a/internal/v1/middleware/search.go b/internal/v1/middleware/search.go new file mode 100644 index 0000000..1225ae1 --- /dev/null +++ b/internal/v1/middleware/search.go @@ -0,0 +1,128 @@ +package middleware + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + db "github.com/coinbase/baseca/db/sqlc" + "github.com/coinbase/baseca/internal/attestation/aws_iid" + "github.com/coinbase/baseca/internal/lib/util/validator" + "github.com/coinbase/baseca/internal/types" + "github.com/google/uuid" +) + +func (m *Middleware) searchServiceAccountMetadata(ctx context.Context, client_uuid uuid.UUID) (*db.ServiceAccountAttestation, error) { + var service_account *db.Account + var cached_service_account db.ServiceAccountAttestation + var instance_identity_document *types.EC2NodeAttestation + var err error + + db_reader := m.store.Reader + uuid := client_uuid.String() + if value, cached := m.cache.Get(uuid); cached == nil { + err = json.Unmarshal(value, &cached_service_account) + if err != nil { + return &cached_service_account, fmt.Errorf("error unmarshal cached service account account, %s", err) + } + } else { + service_account, err = db_reader.GetServiceUUID(ctx, client_uuid) + if err != nil { + return &cached_service_account, fmt.Errorf("service authentication failed: %s", err) + } + cached_service_account.ServiceAccount = *service_account + + for _, node_attestation := range service_account.NodeAttestation { + switch node_attestation { + case types.AWS_IID.String(): + instance_identity_document, err = aws_iid.GetInstanceIdentityDocument(ctx, db_reader, client_uuid) + if err != nil { + return &cached_service_account, fmt.Errorf("aws_iid node attestation failed: %s", err) + } + + instance_tags, err := validator.MapToNullRawMessage(instance_identity_document.InstanceTags) + if err != nil { + return &cached_service_account, fmt.Errorf("error marshalling instance tags, %s", err) + } + cached_service_account.AwsIid = db.AwsAttestation{ + ClientID: instance_identity_document.ClientID, + RoleArn: sql.NullString{String: instance_identity_document.RoleArn, Valid: len(instance_identity_document.RoleArn) != 0}, + AssumeRole: sql.NullString{String: instance_identity_document.AssumeRole, Valid: len(instance_identity_document.AssumeRole) != 0}, + SecurityGroupID: instance_identity_document.SecurityGroups, + Region: sql.NullString{String: instance_identity_document.Region, Valid: len(instance_identity_document.Region) != 0}, + InstanceID: sql.NullString{String: instance_identity_document.InstanceID, Valid: len(instance_identity_document.InstanceID) != 0}, + ImageID: sql.NullString{String: instance_identity_document.ImageID, Valid: len(instance_identity_document.ImageID) != 0}, + InstanceTags: instance_tags, + } + } + } + + data, err := json.Marshal(cached_service_account) + if err != nil { + return &cached_service_account, fmt.Errorf("error marshalling cached_service_account, %s", err) + } + err = m.cache.Set(uuid, data) + if err != nil { + return &cached_service_account, fmt.Errorf("error setting middleware cache, %s", err) + } + } + return &cached_service_account, nil +} + +func (m *Middleware) serachProvisionerAccountAttestation(ctx context.Context, client_uuid uuid.UUID) (*db.ProvisionerAccountAttestation, error) { + var provisioner_account *db.Provisioner + var instance_identity_document *types.EC2NodeAttestation + var cached_provisioner_account db.ProvisionerAccountAttestation + var err error + + db_reader := m.store.Reader + uuid := client_uuid.String() + if value, cached := m.cache.Get(uuid); cached == nil { + err = json.Unmarshal(value, &cached_provisioner_account) + if err != nil { + return &cached_provisioner_account, fmt.Errorf("error unmarshal cached service account account, %s", err) + } + } else { + provisioner_account, err = db_reader.GetProvisionerUUID(ctx, client_uuid) + if err != nil { + return &cached_provisioner_account, fmt.Errorf("service authentication failed: %s", err) + } + cached_provisioner_account.ProvisionerAccount = *provisioner_account + + for _, node_attestation := range provisioner_account.NodeAttestation { + switch node_attestation { + case types.AWS_IID.String(): + instance_identity_document, err = aws_iid.GetInstanceIdentityDocument(ctx, db_reader, client_uuid) + if err != nil { + return &cached_provisioner_account, fmt.Errorf("aws_iid node attestation failed: %s", err) + } + + instance_tags, err := validator.MapToNullRawMessage(instance_identity_document.InstanceTags) + if err != nil { + return &cached_provisioner_account, fmt.Errorf("error marshalling instance tags, %s", err) + } + cached_provisioner_account.AwsIid = db.AwsAttestation{ + ClientID: instance_identity_document.ClientID, + RoleArn: sql.NullString{String: instance_identity_document.RoleArn, Valid: len(instance_identity_document.RoleArn) != 0}, + AssumeRole: sql.NullString{String: instance_identity_document.AssumeRole, Valid: len(instance_identity_document.AssumeRole) != 0}, + SecurityGroupID: instance_identity_document.SecurityGroups, + Region: sql.NullString{String: instance_identity_document.Region, Valid: len(instance_identity_document.Region) != 0}, + InstanceID: sql.NullString{String: instance_identity_document.InstanceID, Valid: len(instance_identity_document.InstanceID) != 0}, + ImageID: sql.NullString{String: instance_identity_document.ImageID, Valid: len(instance_identity_document.ImageID) != 0}, + InstanceTags: instance_tags, + } + } + } + + data, err := json.Marshal(cached_provisioner_account) + if err != nil { + return &cached_provisioner_account, fmt.Errorf("error marshalling cached_service_account, %s", err) + } + err = m.cache.Set(uuid, data) + if err != nil { + return &cached_provisioner_account, fmt.Errorf("error setting middleware cache, %s", err) + } + } + return &cached_provisioner_account, nil +} diff --git a/internal/v1/users/operations_test.go b/internal/v1/users/operations_test.go index 58f4a2c..f3f8f98 100644 --- a/internal/v1/users/operations_test.go +++ b/internal/v1/users/operations_test.go @@ -2,26 +2,27 @@ package users import ( "context" + "crypto/rand" "database/sql" + "encoding/base64" + "encoding/hex" "fmt" "reflect" "testing" + "time" "github.com/coinbase/baseca/db/mock" db "github.com/coinbase/baseca/db/sqlc" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" lib "github.com/coinbase/baseca/internal/lib/authentication" - "github.com/coinbase/baseca/internal/lib/util" + "github.com/coinbase/baseca/internal/types" + "github.com/google/uuid" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -const ( - _read = "READ" -) - func TestCreateUser(t *testing.T) { - user, user_credentials := util.GenerateTestUser(t, _read, 20) + user, user_credentials := GenerateTestUser(t, types.READ.String(), 20) cases := []struct { name string @@ -75,7 +76,7 @@ func TestCreateUser(t *testing.T) { } func TestLoginUser(t *testing.T) { - user, user_credentials := util.GenerateTestUser(t, _read, 20) + user, user_credentials := GenerateTestUser(t, types.READ.String(), 20) cases := []struct { name string @@ -145,3 +146,87 @@ func (e eqCreateUserParamsMatcher) String() string { func EqCreateUserParams(arg db.CreateUserParams, password string) gomock.Matcher { return eqCreateUserParamsMatcher{arg, password} } + +func GenerateTestUser(t *testing.T, permissions string, length int) (db.User, string) { + client_id, _ := uuid.NewRandom() + credentials := generateRandomCredentials(length) + hashed_credentials, _ := lib.HashPassword(credentials) + email := generateRandomEmail() + username := generateRandomUsername() + full_name := generateRandomName() + + return db.User{ + Uuid: client_id, + Username: username, + HashedCredential: hashed_credentials, + FullName: full_name, + Email: email, + Permissions: permissions, + CredentialChangedAt: time.Now().UTC(), + CreatedAt: time.Now().UTC(), + }, credentials +} + +func generateRandomEmail() string { + randBytes := make([]byte, 8) + _, err := rand.Read(randBytes) + if err != nil { + panic(err) + } + + // Encode the random bytes using base64 encoding to get an ASCII string + randStr := base64.URLEncoding.EncodeToString(randBytes) + + // Use the first 10 characters of the base64-encoded string as the email username + return fmt.Sprintf("%s@coinbase.com", randStr[:10]) +} + +func generateRandomName() string { + // Generate random bytes for the first and last name + firstNameBytes := make([]byte, 6) + _, err := rand.Read(firstNameBytes) + if err != nil { + panic(err) + } + lastNameBytes := make([]byte, 6) + _, err = rand.Read(lastNameBytes) + if err != nil { + panic(err) + } + + // Convert the random bytes to hexadecimal strings + firstNameHex := hex.EncodeToString(firstNameBytes)[:10] + lastNameHex := hex.EncodeToString(lastNameBytes)[:10] + + return fmt.Sprintf("%s %s", firstNameHex, lastNameHex) +} + +func generateRandomUsername() string { + // Generate random bytes for the username + usernameBytes := make([]byte, 8) + _, err := rand.Read(usernameBytes) + if err != nil { + panic(err) + } + + // Encode the random bytes using base64 encoding to get an ASCII string + usernameStr := base64.URLEncoding.EncodeToString(usernameBytes) + + // Use the first 10 characters of the base64-encoded string as the username + return usernameStr[:10] +} + +func generateRandomCredentials(length int) string { + // Generate random bytes for the credentials + credentialsBytes := make([]byte, length) + _, err := rand.Read(credentialsBytes) + if err != nil { + panic(err) + } + + // Encode the random bytes using base64 encoding to get an ASCII string + credentialsStr := base64.URLEncoding.EncodeToString(credentialsBytes) + + // Return the first `length` characters of the base64-encoded string + return credentialsStr[:length] +} diff --git a/pkg/client/certificate.go b/pkg/client/certificate.go index a5e6ea4..5a630e7 100644 --- a/pkg/client/certificate.go +++ b/pkg/client/certificate.go @@ -2,10 +2,11 @@ package baseca import ( "context" + "fmt" + "os" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" "github.com/coinbase/baseca/pkg/types" - "github.com/coinbase/baseca/pkg/util" ) func (c *Client) IssueCertificate(certificateRequest CertificateRequest) (*apiv1.SignedCertificate, error) { @@ -23,7 +24,7 @@ func (c *Client) IssueCertificate(certificateRequest CertificateRequest) (*apiv1 return nil, err } - err = util.ParseCertificateFormat(signedCertificate, types.SignedCertificate{ + err = ParseCertificateFormat(signedCertificate, types.SignedCertificate{ CertificatePath: certificateRequest.Output.Certificate, IntermediateCertificateChainPath: certificateRequest.Output.IntermediateCertificateChain, RootCertificateChainPath: certificateRequest.Output.RootCertificateChain, @@ -35,3 +36,30 @@ func (c *Client) IssueCertificate(certificateRequest CertificateRequest) (*apiv1 return signedCertificate, nil } + +func ParseCertificateFormat(certificate *apiv1.SignedCertificate, parameter types.SignedCertificate) error { + // Leaf Certificate Path + if len(parameter.CertificatePath) != 0 { + certificate := []byte(certificate.Certificate) + if err := os.WriteFile(parameter.CertificatePath, certificate, os.ModePerm); err != nil { + return fmt.Errorf("error writing certificate to [%s]", parameter.CertificatePath) + } + } + + // Intermediate Certificate Chain Path + if len(parameter.IntermediateCertificateChainPath) != 0 { + certificate := []byte(certificate.IntermediateCertificateChain) + if err := os.WriteFile(parameter.IntermediateCertificateChainPath, certificate, os.ModePerm); err != nil { + return fmt.Errorf("error writing certificate to [%s]", parameter.IntermediateCertificateChainPath) + } + } + + // Root Certificate Chain Path + if len(parameter.RootCertificateChainPath) != 0 { + certificate := []byte(certificate.CertificateChain) + if err := os.WriteFile(parameter.RootCertificateChainPath, certificate, os.ModePerm); err != nil { + return fmt.Errorf("error writing certificate chain to [%s]", parameter.RootCertificateChainPath) + } + } + return nil +} diff --git a/pkg/client/client.go b/pkg/client/client.go index a19e74f..0b77814 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "sync" + "time" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" "github.com/coinbase/baseca/pkg/attestor/aws_iid" @@ -29,6 +30,7 @@ type Client struct { Attestation string Certificate apiv1.CertificateClient Service apiv1.ServiceClient + *iidCache } type AccountClient interface { @@ -143,11 +145,11 @@ func (c *Client) clientAuthUnaryInterceptor(ctx context.Context, method string, ctx = metadata.AppendToOutgoingContext(ctx, _client_token_header, c.Authentication.ClientToken) if c.Attestation == Attestation.AWS { - instance_metadata, err := aws_iid.BuildInstanceMetadata() + instance_metadata, err := c.iidCache.Get() if err != nil { - return fmt.Errorf("error generating aws_iid node attestation") + return fmt.Errorf("error generating aws_iid node attestation: %w", err) } - ctx = metadata.AppendToOutgoingContext(ctx, _aws_iid_metadata, *instance_metadata) + ctx = metadata.AppendToOutgoingContext(ctx, _aws_iid_metadata, instance_metadata) } err := invoker(ctx, method, req, reply, cc, opts...) @@ -160,3 +162,27 @@ func (c *Client) accountAuthUnaryInterceptor(ctx context.Context, method string, err := invoker(ctx, method, req, reply, cc, opts...) return err } + +func (cache *iidCache) Get() (string, error) { + cache.lock.Lock() + defer cache.lock.Unlock() + + // If we have a cached value and it hasn't expired, use it. + if cache.value != "" && cache.expiration.After(time.Now()) { + return cache.value, nil + } + + // We have no cache or the cache has expired, so refresh it and return that. + instance_metadata, err := aws_iid.BuildInstanceMetadata() + if err != nil { + return "", fmt.Errorf("aws_iid.BuildInstanceMetadata failed: %w", err) + } + + if instance_metadata == nil { + return "", fmt.Errorf("got nil IID") + } + + cache.value = *instance_metadata + cache.expiration = time.Now().Add(iidCacheExpiration) + return cache.value, nil +} diff --git a/pkg/client/csr.go b/pkg/client/csr.go index f469571..48e5b0a 100644 --- a/pkg/client/csr.go +++ b/pkg/client/csr.go @@ -18,18 +18,18 @@ func GenerateCSR(csr CertificateRequest) (*types.SigningRequest, error) { switch csr.PublicKeyAlgorithm { case x509.RSA: - if _, ok := types.PublicKeyAlgorithms["RSA"].KeySize[csr.KeySize]; !ok { + if _, ok := types.PublicKeyAlgorithms[types.RSA].KeySize[csr.KeySize]; !ok { return nil, fmt.Errorf("rsa invalid key size %d", csr.KeySize) } - if _, ok := types.PublicKeyAlgorithms["RSA"].SigningAlgorithm[csr.SigningAlgorithm]; !ok { + if _, ok := types.PublicKeyAlgorithms[types.RSA].SigningAlgorithm[csr.SigningAlgorithm]; !ok { return nil, fmt.Errorf("rsa invalid signing algorithm %s", csr.SigningAlgorithm) } generator = &crypto.SigningRequestGeneratorRSA{Size: csr.KeySize} case x509.ECDSA: - if _, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[csr.KeySize]; !ok { - return nil, fmt.Errorf("ecdsa invalid key size %d", csr.KeySize) + if _, ok := types.PublicKeyAlgorithms[types.ECDSA].KeySize[csr.KeySize]; !ok { + return nil, fmt.Errorf("ecdsa invalid curve %d", csr.KeySize) } - if _, ok := types.PublicKeyAlgorithms["ECDSA"].SigningAlgorithm[csr.SigningAlgorithm]; !ok { + if _, ok := types.PublicKeyAlgorithms[types.ECDSA].SigningAlgorithm[csr.SigningAlgorithm]; !ok { return nil, fmt.Errorf("ecdsa invalid signing algorithm %s", csr.SigningAlgorithm) } generator = &crypto.SigningRequestGeneratorECDSA{Curve: csr.KeySize} @@ -49,6 +49,9 @@ func GenerateCSR(csr CertificateRequest) (*types.SigningRequest, error) { Locality: csr.DistinguishedName.Locality, Organization: csr.DistinguishedName.Organization, OrganizationalUnit: csr.DistinguishedName.OrganizationalUnit, + StreetAddress: csr.DistinguishedName.StreetAddress, + PostalCode: csr.DistinguishedName.PostalCode, + SerialNumber: csr.DistinguishedName.SerialNumber, } template := x509.CertificateRequest{ @@ -64,7 +67,7 @@ func GenerateCSR(csr CertificateRequest) (*types.SigningRequest, error) { certificatePem := new(bytes.Buffer) err = pem.Encode(certificatePem, &pem.Block{ - Type: "CERTIFICATE REQUEST", + Type: types.CERTIFICATE_REQUEST.String(), Bytes: csrBytes, }) @@ -88,6 +91,11 @@ func GenerateCSR(csr CertificateRequest) (*types.SigningRequest, error) { Bytes: pkBytes, } + pkcs8Encoding, err := crypto.EncodeToPKCS8(pkBlock) + if err != nil { + return nil, fmt.Errorf("error converting private key to pkcs8: %w", err) + } + if len(csr.Output.PrivateKey) != 0 { if err := os.WriteFile(csr.Output.PrivateKey, pem.EncodeToMemory(pkBlock), os.ModePerm); err != nil { return nil, fmt.Errorf("error writing private key to [%s]", csr.Output.PrivateKey) @@ -95,7 +103,8 @@ func GenerateCSR(csr CertificateRequest) (*types.SigningRequest, error) { } return &types.SigningRequest{ - CSR: certificatePem, - PrivateKey: pkBlock, + CSR: certificatePem, + PrivateKey: pkBlock, + EncodedPKCS8: pem.EncodeToMemory(pkcs8Encoding), }, nil } diff --git a/internal/lib/crypto/csr_test.go b/pkg/client/csr_test.go similarity index 53% rename from internal/lib/crypto/csr_test.go rename to pkg/client/csr_test.go index 9673887..34bfdf6 100644 --- a/internal/lib/crypto/csr_test.go +++ b/pkg/client/csr_test.go @@ -1,28 +1,28 @@ -package crypto +package baseca import ( "crypto/x509" "encoding/pem" "testing" - "github.com/coinbase/baseca/internal/types" + "github.com/coinbase/baseca/pkg/types" "github.com/stretchr/testify/assert" ) func TestGenerateCSR(t *testing.T) { - csr := types.CertificateRequest{ + csr := CertificateRequest{ PublicKeyAlgorithm: x509.RSA, KeySize: 2048, SigningAlgorithm: x509.SHA256WithRSA, CommonName: "example.com", - DistinguishedName: types.DistinguishedName{ + DistinguishedName: DistinguishedName{ Country: []string{"US"}, Province: []string{"CA"}, }, - SubjectAlternateNames: []string{"www.example.com", "sub.example.com"}, - Output: types.Output{ - CertificateSigningRequest: "/tmp/unit_test_csr.pem", - PrivateKey: "/tmp/unit_test_pk.pem", + SubjectAlternateNames: []string{"www.example.com", "subordinate.example.com"}, + Output: Output{ + CertificateSigningRequest: "/tmp/csr.pem", + PrivateKey: "/tmp/pk.pem", }, } @@ -31,23 +31,23 @@ func TestGenerateCSR(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, rsaSigningRequest) - assert.Contains(t, string(rsaSigningRequest.CSR.String()), "CERTIFICATE REQUEST") - assert.Contains(t, string(pem.EncodeToMemory(rsaSigningRequest.PrivateKey)), "RSA PRIVATE KEY") + assert.Contains(t, string(rsaSigningRequest.CSR.String()), types.CERTIFICATE_REQUEST.String()) + assert.Contains(t, string(pem.EncodeToMemory(rsaSigningRequest.PrivateKey)), types.RSA_PRIVATE_KEY.String()) // Create an ECDSA CertificateRequest - ecdsaCsr := types.CertificateRequest{ + ecdsaCsr := CertificateRequest{ PublicKeyAlgorithm: x509.ECDSA, KeySize: 256, SigningAlgorithm: x509.ECDSAWithSHA512, CommonName: "example.com", - DistinguishedName: types.DistinguishedName{ + DistinguishedName: DistinguishedName{ Country: []string{"US"}, Province: []string{"CA"}, }, - SubjectAlternateNames: []string{"www.example.com", "sub.example.com"}, - Output: types.Output{ - CertificateSigningRequest: "/tmp/unit_test_csr.pem", - PrivateKey: "/tmp/unit_test_pk.pem", + SubjectAlternateNames: []string{"www.example.com", "subordinate.example.com"}, + Output: Output{ + CertificateSigningRequest: "/tmp/csr.pem", + PrivateKey: "/tmp/pk.pem", }, } @@ -56,6 +56,6 @@ func TestGenerateCSR(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, ecdsaSigningRequest) - assert.Contains(t, string(ecdsaSigningRequest.CSR.String()), "CERTIFICATE REQUEST") - assert.Contains(t, string(pem.EncodeToMemory(ecdsaSigningRequest.PrivateKey)), "EC PRIVATE KEY") + assert.Contains(t, string(ecdsaSigningRequest.CSR.String()), types.CERTIFICATE_REQUEST.String()) + assert.Contains(t, string(pem.EncodeToMemory(ecdsaSigningRequest.PrivateKey)), types.ECDSA_PRIVATE_KEY.String()) } diff --git a/pkg/client/sign.go b/pkg/client/sign.go index b35c6a0..2e873bc 100644 --- a/pkg/client/sign.go +++ b/pkg/client/sign.go @@ -2,10 +2,8 @@ package baseca import ( "context" - "crypto" - "crypto/rand" + "crypto/ecdsa" "crypto/rsa" - "crypto/sha256" "crypto/x509" "encoding/pem" "errors" @@ -14,11 +12,17 @@ import ( "path/filepath" apiv1 "github.com/coinbase/baseca/gen/go/baseca/v1" + "github.com/coinbase/baseca/pkg/crypto" "github.com/coinbase/baseca/pkg/types" - "github.com/coinbase/baseca/pkg/util" ) -func (c *Client) GenerateSignature(csr CertificateRequest, element []byte) (*[]byte, []*x509.Certificate, error) { +var _buffer = 1024 + +type Signer interface { + Sign(data []byte) ([]byte, error) +} + +func (c *Client) GenerateSignature(csr CertificateRequest, data *[]byte) (*[]byte, []*x509.Certificate, error) { var certificatePem []*pem.Block var certificateChain []*x509.Certificate @@ -36,25 +40,23 @@ func (c *Client) GenerateSignature(csr CertificateRequest, element []byte) (*[]b return nil, nil, err } - err = util.ParseCertificateFormat(signedCertificate, types.SignedCertificate{ + err = ParseCertificateFormat(signedCertificate, types.SignedCertificate{ CertificatePath: csr.Output.Certificate, IntermediateCertificateChainPath: csr.Output.IntermediateCertificateChain, RootCertificateChainPath: csr.Output.RootCertificateChain, }) - if err != nil { return nil, nil, err } - hashedOutput := sha256.Sum256(element) - pk, err := x509.ParsePKCS1PrivateKey(signingRequest.PrivateKey.Bytes) + signer, err := parsePrivateKey(signingRequest.EncodedPKCS8, csr.SigningAlgorithm) if err != nil { - return nil, nil, errors.New("error parsing pkcs1 private key") + return nil, nil, err } - signature, err := rsa.SignPKCS1v15(rand.Reader, pk, crypto.SHA256, hashedOutput[:]) + signature, err := signer.Sign(*data) if err != nil { - return nil, nil, fmt.Errorf("error calculating signature of hash using pkcs1: %s", err) + return nil, nil, fmt.Errorf("error signing data: %w", err) } fullChain, err := os.ReadFile(filepath.Clean(csr.Output.RootCertificateChain)) @@ -84,49 +86,35 @@ func (c *Client) GenerateSignature(csr CertificateRequest, element []byte) (*[]b return &signature, certificateChain, nil } -func (c *Client) ValidateSignature(tc types.TrustChain, manifest types.Manifest) error { - err := manifest.CertificateChain[0].CheckSignature(manifest.SigningAlgorithm, manifest.Data, manifest.Signature) - if err != nil { - return fmt.Errorf("signature verification failed: %s", err) - } - - // Validate Entire Certificate Chain Does Not Break - for i := range manifest.CertificateChain[:len(manifest.CertificateChain)-1] { - err = manifest.CertificateChain[i].CheckSignatureFrom(manifest.CertificateChain[i+1]) - if err != nil { - return fmt.Errorf("certificate chain invalid: %s", err) - } +func parsePrivateKey(pk []byte, signatureAlgorithm x509.SignatureAlgorithm) (Signer, error) { + block, _ := pem.Decode(pk) + if block == nil { + return nil, errors.New("error parsing pem block from private key") } - if manifest.CertificateChain[0].Subject.CommonName != tc.CommonName { - return fmt.Errorf("invalid common name (cn) from code signing certificate") - } - - validSubjectAlternativeName := false - if len(manifest.CertificateChain[0].DNSNames) > 0 { - for _, san := range manifest.CertificateChain[0].DNSNames { - if san == tc.CommonName { - validSubjectAlternativeName = true - } - } - } - - if !validSubjectAlternativeName { - return fmt.Errorf("invalid subject alternative name (san) from code signing certificate") - } - - rootCertificatePool, err := util.GenerateCertificatePool(tc) + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("error parsing pkcs8 private key: %s", err) } - opts := x509.VerifyOptions{ - Roots: rootCertificatePool, - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}, + // Validate Signing Algorithm is Supported + algorithm, exist := types.SignatureAlgorithm[signatureAlgorithm] + if !exist { + return nil, fmt.Errorf("invalid signing algorithm: %s", signatureAlgorithm) } - _, err = manifest.CertificateChain[1].Verify(opts) - if err != nil { - return fmt.Errorf("error validating code signing certificate validity: %s", err) + + switch key := privateKey.(type) { + case *rsa.PrivateKey: + return &crypto.RSASigner{ + PrivateKey: key, + SignatureAlgorithm: signatureAlgorithm, + Hash: algorithm}, nil + case *ecdsa.PrivateKey: + return &crypto.ECDSASigner{ + PrivateKey: key, + SignatureAlgorithm: signatureAlgorithm, + Hash: algorithm}, nil + default: + return nil, errors.New("unsupported private key type") } - return nil } diff --git a/pkg/client/sign_test.go b/pkg/client/sign_test.go new file mode 100644 index 0000000..8b5b968 --- /dev/null +++ b/pkg/client/sign_test.go @@ -0,0 +1,140 @@ +package baseca + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "log" + "testing" + + "github.com/coinbase/baseca/pkg/types" + "github.com/stretchr/testify/require" +) + +func TestParsePrivateKey(t *testing.T) { + tests := []struct { + name string + pk func() []byte + algorithm x509.SignatureAlgorithm + check func(t *testing.T, err error) + }{ + { + name: "Valid RSA", + pk: func() []byte { + return GENERATE_PKCS8_RSA() + }, + algorithm: x509.SHA256WithRSA, + check: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "Valid ECDSA", + pk: func() []byte { + return GENERATE_PKCS8_ECDSA() + }, + algorithm: x509.ECDSAWithSHA256, + check: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "Invalid Signing Algorithm", + pk: func() []byte { + return GENERATE_PKCS8_RSA() + }, + algorithm: x509.DSAWithSHA1, + check: func(t *testing.T, err error) { + require.EqualError(t, err, "invalid signing algorithm: DSA-SHA1") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := parsePrivateKey(tc.pk(), tc.algorithm) + tc.check(t, err) + }) + } +} + +func TestSign(t *testing.T) { + tests := []struct { + name string + sign func() ([]byte, error) + check func(t *testing.T, err error) + }{ + { + name: "Valid RSA Signer", + sign: func() ([]byte, error) { + signer, err := parsePrivateKey(GENERATE_PKCS8_RSA(), x509.SHA256WithRSA) + if err != nil { + return nil, err + } + return signer.Sign([]byte("_value")) + }, + check: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "Valid ECDSA Signer", + sign: func() ([]byte, error) { + signer, err := parsePrivateKey(GENERATE_PKCS8_ECDSA(), x509.ECDSAWithSHA256) + if err != nil { + return nil, err + } + return signer.Sign([]byte("_value")) + }, + check: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.sign() + tc.check(t, err) + }) + } +} + +func GENERATE_PKCS8_RSA() []byte { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + log.Fatalf("Failed to marshal private key to PKCS#8: %v", err) + } + + pemBlock := &pem.Block{ + Type: types.PKCS8_PRIVATE_KEY.String(), // PKCS#8 Encoding + Bytes: pkcs8Bytes, + } + return pem.EncodeToMemory(pemBlock) +} + +func GENERATE_PKCS8_ECDSA() []byte { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + log.Fatalf("Failed to generate ECDSA private key: %v", err) + } + + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + log.Fatalf("Failed to marshal ECDSA private key to PKCS#8: %v", err) + } + + pemBlock := &pem.Block{ + Type: types.PKCS8_PRIVATE_KEY.String(), // PKCS#8 Encoding + Bytes: pkcs8Bytes, + } + return pem.EncodeToMemory(pemBlock) +} diff --git a/pkg/client/types.go b/pkg/client/types.go index fee1b8e..17e699b 100644 --- a/pkg/client/types.go +++ b/pkg/client/types.go @@ -1,9 +1,13 @@ package baseca -import "crypto/x509" +import ( + "crypto/x509" + "sync" + "time" +) var Attestation Provider = Provider{ - Local: "NONE", + Local: "Local", AWS: "AWS", } @@ -16,6 +20,8 @@ var Env = Environment{ Production: "Production", } +var iidCacheExpiration = 10 * time.Minute + type Environment struct { Local string Sandbox string @@ -57,6 +63,9 @@ type DistinguishedName struct { Locality []string Organization []string OrganizationalUnit []string + StreetAddress []string + PostalCode []string + SerialNumber string } type Output struct { @@ -66,3 +75,9 @@ type Output struct { RootCertificateChain string PrivateKey string } + +type iidCache struct { + expiration time.Time + lock sync.Mutex + value string +} diff --git a/pkg/client/validate.go b/pkg/client/validate.go new file mode 100644 index 0000000..f9eccfa --- /dev/null +++ b/pkg/client/validate.go @@ -0,0 +1,262 @@ +package baseca + +import ( + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/coinbase/baseca/pkg/types" +) + +// Signature Validation for Different Data Inputs +func ValidateSignature(tc types.TrustChain, manifest types.Manifest) error { + err := validateManifestParameters(manifest) + if err != nil { + return fmt.Errorf("[manifest] %w", err) + } + + // Priority of Data Inputs (Path, Reader, Raw) + switch { + case manifest.Data.Path != types.Path{}: + err := validateStreamedSignature(manifest) + if err != nil { + return fmt.Errorf("[data.path] %s", err) + } + case manifest.Data.Reader != types.Reader{}: + err := validateReaderSignature(manifest) + if err != nil { + return fmt.Errorf("[data.reader] %s", err) + } + case manifest.Data.Raw != nil: + err := manifest.CertificateChain[0].CheckSignature(manifest.SigningAlgorithm, *manifest.Data.Raw, *manifest.Signature) + if err != nil { + return fmt.Errorf("[data.raw] %s", err) + } + default: + return errors.New("data not present within manifest") + } + + err = validateCertificateChain(tc, manifest) + if err != nil { + return fmt.Errorf("[certificate chain] %s", err) + } + return nil +} + +func validateCertificateChain(tc types.TrustChain, manifest types.Manifest) error { + // Validate Entire Certificate Chain Does Not Break + for i := range manifest.CertificateChain[:len(manifest.CertificateChain)-1] { + err := manifest.CertificateChain[i].CheckSignatureFrom(manifest.CertificateChain[i+1]) + if err != nil { + return fmt.Errorf("certificate chain invalid: %s", err) + } + } + + if manifest.CertificateChain[0].Subject.CommonName != tc.CommonName { + return fmt.Errorf("invalid common name (cn) from code signing certificate") + } + + validSubjectAlternativeName := false + if len(manifest.CertificateChain[0].DNSNames) > 0 { + for _, san := range manifest.CertificateChain[0].DNSNames { + if san == tc.CommonName { + validSubjectAlternativeName = true + } + } + } + + if !validSubjectAlternativeName { + return fmt.Errorf("invalid subject alternative name (san) from code signing certificate") + } + + rootCertificatePool, err := generateCertificatePool(tc) + if err != nil { + return err + } + + opts := x509.VerifyOptions{ + Roots: rootCertificatePool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}, + } + + switch len(manifest.CertificateChain) { + // Single Root CA + case 1: + _, err = manifest.CertificateChain[0].Verify(opts) + if err != nil { + return fmt.Errorf("error validating code signing certificate validity: %w", err) + } + // Subordinate CA Validates Against AWS Intermediate CA Based on x509.VerifyOptions + default: + _, err = manifest.CertificateChain[1].Verify(opts) + if err != nil { + return fmt.Errorf("error validating code signing certificate validity: %w", err) + } + } + return nil +} + +func verifySignature(manifest types.Manifest) error { + algorithm, exist := types.SignatureAlgorithm[manifest.SigningAlgorithm] + if !exist { + return fmt.Errorf("invalid signing algorithm: %s", manifest.SigningAlgorithm) + } + _, cryptoAlgorithm := algorithm() + + switch publicKey := manifest.CertificateChain[0].PublicKey.(type) { + case *rsa.PublicKey: + return rsa.VerifyPKCS1v15(publicKey, cryptoAlgorithm, *manifest.Hash, *manifest.Signature) + case *ecdsa.PublicKey: + if ecdsa.VerifyASN1(publicKey, *manifest.Hash, *manifest.Signature) { + return nil + } + return errors.New("ecdsa signature verification failed") + default: + return errors.New("unsupported public key type") + } +} + +// Signature Validation for Large Files Passing in Filepath +func validateStreamedSignature(manifest types.Manifest) error { + algorithm, exist := types.SignatureAlgorithm[manifest.SigningAlgorithm] + if !exist { + return fmt.Errorf("invalid signing algorithm: %s", manifest.SigningAlgorithm) + } + hashedAlgorithm, _ := algorithm() + + file, err := os.Open(manifest.Data.Path.File) + if err != nil { + return fmt.Errorf("error opening file: %s", err) + } + defer file.Close() + + if manifest.Data.Reader.Buffer > 0 { + _buffer = manifest.Data.Reader.Buffer + } + + buffer := make([]byte, _buffer) + for { + n, err := file.Read(buffer) + if err != nil && err != io.EOF { + return fmt.Errorf("error reading file: %s", err) + } + if n == 0 { + break + } + hashedAlgorithm.Write(buffer[:n]) + } + + hashedArtifact := hashedAlgorithm.Sum(nil) + manifest.Hash = &hashedArtifact + + err = verifySignature(manifest) + if err != nil { + return fmt.Errorf("signature verification failed: %s", err) + } + return nil +} + +// Signature Validation for Large Files Passing in io.Reader +func validateReaderSignature(manifest types.Manifest) error { + algorithm, exist := types.SignatureAlgorithm[manifest.SigningAlgorithm] + if !exist { + return fmt.Errorf("invalid signing algorithm: %s", manifest.SigningAlgorithm) + } + hashedAlgorithm, _ := algorithm() + + if manifest.Data.Reader.Buffer > 0 { + _buffer = manifest.Data.Reader.Buffer + } + + buffer := make([]byte, _buffer) + for { + n, err := manifest.Data.Reader.Interface.Read(buffer) + if err != nil && err != io.EOF { + return fmt.Errorf("error reading file: %s", err) + } + if n == 0 { + break + } + hashedAlgorithm.Write(buffer[:n]) + } + + hashedArtifact := hashedAlgorithm.Sum(nil) + manifest.Hash = &hashedArtifact + + err := verifySignature(manifest) + if err != nil { + return fmt.Errorf("signature verification failed: %s", err) + } + return nil +} + +func generateCertificatePool(tc types.TrustChain) (*x509.CertPool, error) { + certPool := x509.NewCertPool() + + for _, dir := range tc.CertificateAuthorityDirectory { + files, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("invalid certificate authority directory %s", dir) + } + + for _, certFile := range files { // #nosec G304 User Only Has Predefined Environment Parameters + data, err := os.ReadFile(filepath.Join(dir, certFile.Name())) + if err != nil { + return nil, fmt.Errorf("invalid certificate file %s", filepath.Join(dir, certFile.Name())) + } + pemBlock, _ := pem.Decode(data) + if pemBlock == nil { + return nil, errors.New("invalid input file") + } + cert, err := x509.ParseCertificate(pemBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("error parsing x.509 certificate: %w", err) + } + certPool.AddCert(cert) + } + } + + for _, ca := range tc.CertificateAuthorityFiles { + data, err := os.ReadFile(filepath.Clean(ca)) + if err != nil { + return nil, fmt.Errorf("invalid certificate authority file %s", filepath.Clean(ca)) + } + pemBlock, _ := pem.Decode(data) + if pemBlock == nil { + return nil, errors.New("invalid input file") + } + cert, err := x509.ParseCertificate(pemBlock.Bytes) + if err != nil { + return nil, errors.New("error parsing x.509 certificate") + } + certPool.AddCert(cert) + } + return certPool, nil +} + +func validateManifestParameters(manifest types.Manifest) error { + if manifest.Signature == nil { + return errors.New("signature not found") + } + + if manifest.SigningAlgorithm == 0 { + return errors.New("signing algorithm not found") + } + + if len(manifest.CertificateChain) == 0 { + return errors.New("certificate chain not found") + } + + if (manifest.Data.Path == types.Path{} && + manifest.Data.Raw == nil && + manifest.Data.Reader == types.Reader{}) { + return errors.New("data not found") + } + return nil +} diff --git a/pkg/client/validate_test.go b/pkg/client/validate_test.go new file mode 100644 index 0000000..ba41835 --- /dev/null +++ b/pkg/client/validate_test.go @@ -0,0 +1,120 @@ +package baseca + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "testing" + "time" + + "github.com/coinbase/baseca/pkg/types" +) + +func TestValidateSignature(t *testing.T) { + tests := []struct { + name string + validate func() error + check func(t *testing.T, err error) + }{ + { + name: "Valid Signature RSA", + validate: func() error { + data := []byte("_value") + pk, certificate, path := generateSelfSignedCertificateAuthority() + signer, _ := parsePrivateKey(pk, x509.SHA256WithRSA) + + signature, err := signer.Sign(data) + if err != nil { + return fmt.Errorf("error signing data: %s", err) + } + + c, err := x509.ParseCertificate(certificate) + if err != nil { + return fmt.Errorf("error parsing code signing certificate: %s", err) + } + + tc := types.TrustChain{ + CommonName: "example.coinbase.com", + CertificateAuthorityFiles: []string{ + path.Name(), + }, + } + + manifest := types.Manifest{ + CertificateChain: []*x509.Certificate{c}, + Signature: &signature, + SigningAlgorithm: x509.SHA256WithRSA, + Data: types.Data{ + Raw: &data, + }, + } + return ValidateSignature(tc, manifest) + }, + check: func(t *testing.T, err error) { + if err != nil { + t.Errorf("expected no error, got %s", err) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.validate() + tc.check(t, err) + }) + } +} + +func generateSelfSignedCertificateAuthority() ([]byte, []byte, *os.File) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + + notBefore := time.Now().UTC() + notAfter := notBefore.Add(365 * 24 * time.Hour) + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "example.coinbase.com", + Organization: []string{"Coinbase"}, + }, + DNSNames: []string{"example.coinbase.com"}, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}, + BasicConstraintsValid: true, + } + + // Generate DER Encoded Self-Signed Certificate + certificate, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + panic(err) + } + + // PKCS#8 Encode Private Key + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + panic(err) + } + pk := pem.EncodeToMemory(&pem.Block{Type: types.PKCS8_PRIVATE_KEY.String(), Bytes: pkcs8Bytes}) + + // Use Self-Signed Certificate as Certificate Authority + path, err := os.CreateTemp("", "ca.crt") + if err != nil { + panic(err) + } + defer path.Close() + _ = pem.Encode(path, &pem.Block{Type: types.CERTIFICATE.String(), Bytes: certificate}) + + return pk, certificate, path +} diff --git a/pkg/crypto/generate.go b/pkg/crypto/generate.go index 2134e7d..315e88c 100644 --- a/pkg/crypto/generate.go +++ b/pkg/crypto/generate.go @@ -35,7 +35,7 @@ func (r *SigningRequestGeneratorRSA) Generate() (crypto.PrivateKey, error) { } func (r *SigningRequestGeneratorRSA) KeyType() string { - return "RSA PRIVATE KEY" + return types.RSA_PRIVATE_KEY.String() } func (r *SigningRequestGeneratorRSA) MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) { @@ -47,18 +47,18 @@ func (r *SigningRequestGeneratorRSA) SupportsPublicKeyAlgorithm(algorithm x509.P } func (r *SigningRequestGeneratorRSA) SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool { - _, ok := types.PublicKeyAlgorithms["RSA"].SigningAlgorithm[algorithm] + _, ok := types.PublicKeyAlgorithms[types.RSA].SigningAlgorithm[algorithm] return ok } func (r *SigningRequestGeneratorRSA) SupportsKeySize(size int) bool { - _, ok := types.PublicKeyAlgorithms["RSA"].KeySize[size] + _, ok := types.PublicKeyAlgorithms[types.RSA].KeySize[size] return ok } // ECDSA Interface func (e *SigningRequestGeneratorECDSA) Generate() (crypto.PrivateKey, error) { - c, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[e.Curve] + c, ok := types.PublicKeyAlgorithms[types.ECDSA].KeySize[e.Curve] if !ok { return nil, fmt.Errorf("ecdsa curve [%d] not supported", e.Curve) @@ -73,7 +73,7 @@ func (e *SigningRequestGeneratorECDSA) Generate() (crypto.PrivateKey, error) { } func (e *SigningRequestGeneratorECDSA) KeyType() string { - return "EC PRIVATE KEY" + return types.ECDSA_PRIVATE_KEY.String() } func (e *SigningRequestGeneratorECDSA) MarshalPrivateKey(key crypto.PrivateKey) ([]byte, error) { @@ -85,11 +85,11 @@ func (e *SigningRequestGeneratorECDSA) SupportsPublicKeyAlgorithm(algorithm x509 } func (e *SigningRequestGeneratorECDSA) SupportsSigningAlgorithm(algorithm x509.SignatureAlgorithm) bool { - _, ok := types.PublicKeyAlgorithms["ECDSA"].SigningAlgorithm[algorithm] + _, ok := types.PublicKeyAlgorithms[types.ECDSA].SigningAlgorithm[algorithm] return ok } func (e *SigningRequestGeneratorECDSA) SupportsKeySize(size int) bool { - _, ok := types.PublicKeyAlgorithms["ECDSA"].KeySize[size] + _, ok := types.PublicKeyAlgorithms[types.ECDSA].KeySize[size] return ok } diff --git a/pkg/crypto/pk.go b/pkg/crypto/pk.go index 3c01a74..45ab23d 100644 --- a/pkg/crypto/pk.go +++ b/pkg/crypto/pk.go @@ -6,69 +6,94 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "encoding/pem" "fmt" + "hash" + + "github.com/coinbase/baseca/pkg/types" ) -type CertificateAuthority struct { - Certificate *x509.Certificate - AsymmetricKey *AsymmetricKey - SerialNumber string - CertificateAuthorityArn string +type RSASigner struct { + PrivateKey *rsa.PrivateKey + SignatureAlgorithm x509.SignatureAlgorithm + Hash func() (hash.Hash, crypto.Hash) } -type AsymmetricKey interface { - KeyPair() interface{} - Sign(data []byte) ([]byte, error) +type ECDSASigner struct { + PrivateKey *ecdsa.PrivateKey + SignatureAlgorithm x509.SignatureAlgorithm + Hash func() (hash.Hash, crypto.Hash) } -type RSA struct { - PublicKey *rsa.PublicKey - PrivateKey *rsa.PrivateKey +func (r *RSASigner) Sign(data []byte) ([]byte, error) { + hash, cryptoHash := r.Hash() + hash.Write(data) + hashedValue := hash.Sum(nil) + switch r.SignatureAlgorithm { + case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS: + return rsa.SignPSS(rand.Reader, r.PrivateKey, cryptoHash, hashedValue, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + }) + default: + return rsa.SignPKCS1v15(rand.Reader, r.PrivateKey, cryptoHash, hashedValue) + } } -type ECDSA struct { - PublicKey *ecdsa.PublicKey - PrivateKey *ecdsa.PrivateKey +func (e *ECDSASigner) Sign(data []byte) ([]byte, error) { + hash, _ := e.Hash() + hash.Write(data) + hashedValue := hash.Sum(nil) + return ecdsa.SignASN1(rand.Reader, e.PrivateKey, hashedValue) } -func (key *RSA) KeyPair() interface{} { - return key -} +func EncodeToPKCS8(pkBlock *pem.Block) (*pem.Block, error) { + var key interface{} + var err error -func (key *RSA) Sign(data []byte) ([]byte, error) { - h := crypto.SHA256.New() - h.Write(data) - hashed := h.Sum(nil) - return rsa.SignPKCS1v15(rand.Reader, key.PrivateKey, crypto.SHA256, hashed) -} + switch pkBlock.Type { + case types.RSA_PRIVATE_KEY.String(): + key, err = x509.ParsePKCS1PrivateKey(pkBlock.Bytes) + case types.ECDSA_PRIVATE_KEY.String(): + key, err = x509.ParseECPrivateKey(pkBlock.Bytes) + default: + return nil, fmt.Errorf("unsupported key type %s", pkBlock.Type) + } -func (key *ECDSA) KeyPair() interface{} { - return key -} + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } -func (key *ECDSA) Sign(data []byte) ([]byte, error) { - h := crypto.SHA256.New() - h.Write(data) - hashed := h.Sum(nil) - r, s, err := ecdsa.Sign(rand.Reader, key.PrivateKey, hashed) + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(key) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to marshal to PKCS#8: %v", err) } - signature := append(r.Bytes(), s.Bytes()...) - return signature, nil -} -func ReturnPrivateKey(key AsymmetricKey) (interface{}, error) { - if key == nil { - return nil, fmt.Errorf("asymmetric key is nil") + pkcs8Encoding := &pem.Block{ + Type: types.PKCS8_PRIVATE_KEY.String(), // PKCS#8 Encoding + Bytes: pkcs8Bytes, } - switch k := key.KeyPair().(type) { - case *RSA: - return k.PrivateKey, nil - case *ECDSA: - return k.PrivateKey, nil + return pkcs8Encoding, nil +} + +func ReturnSignerInterface(pkBlock *pem.Block) (crypto.Signer, error) { + var key interface{} + var err error + + switch pkBlock.Type { + case types.RSA_PRIVATE_KEY.String(): + key, err = x509.ParsePKCS1PrivateKey(pkBlock.Bytes) + case types.ECDSA_PRIVATE_KEY.String(): + key, err = x509.ParseECPrivateKey(pkBlock.Bytes) default: - return nil, fmt.Errorf("unsupported key type") + return nil, fmt.Errorf("unsupported key type %s", pkBlock.Type) + } + + var signer crypto.Signer + var ok bool + + if signer, ok = key.(crypto.Signer); !ok { + return nil, fmt.Errorf("failed to parse private key: %v", err) } + return signer, nil } diff --git a/pkg/crypto/pk_test.go b/pkg/crypto/pk_test.go deleted file mode 100644 index 2badf9e..0000000 --- a/pkg/crypto/pk_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package crypto - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "reflect" - "testing" -) - -func TestRSASign(t *testing.T) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("failed to generate private key: %v", err) - } - rsaKey := &RSA{ - PublicKey: &privateKey.PublicKey, - PrivateKey: privateKey, - } - data := []byte("_example") - signature, err := rsaKey.Sign(data) - if err != nil { - t.Fatalf("failed to sign data: %v", err) - } - if len(signature) == 0 { - t.Fatalf("expected non-empty signature") - } -} - -func TestECDSASign(t *testing.T) { - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("failed to generate private key: %v", err) - } - ecdsaKey := &ECDSA{ - PublicKey: &privateKey.PublicKey, - PrivateKey: privateKey, - } - data := []byte("_example") - signature, err := ecdsaKey.Sign(data) - if err != nil { - t.Fatalf("failed to sign data: %v", err) - } - if len(signature) == 0 { - t.Fatalf("expected non-empty signature") - } -} - -func TestReturnPrivateKey(t *testing.T) { - rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("failed to generate rsa private key: %v", err) - } - - ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("failed to generate ecdsa private key: %v", err) - } - - tests := []struct { - key AsymmetricKey - expected interface{} - }{ - {&RSA{PrivateKey: rsaPrivateKey}, rsaPrivateKey}, - {&ECDSA{PrivateKey: ecdsaPrivateKey}, ecdsaPrivateKey}, - {nil, nil}, - } - - for _, test := range tests { - got, err := ReturnPrivateKey(test.key) - if err != nil && test.key != nil { - t.Fatalf("unexpected error: %v", err) - } - if !reflect.DeepEqual(got, test.expected) { - t.Errorf("expected %v, but got %v", test.expected, got) - } - } -} - -func TestCertificateAuthorityInitialization(t *testing.T) { - ca := &CertificateAuthority{ - SerialNumber: "0000000000", - } - if ca.SerialNumber != "0000000000" { - t.Errorf("expected serial number to be '0000000000', but got '%s'", ca.SerialNumber) - } -} diff --git a/pkg/types/certificate.go b/pkg/types/certificate.go index ff5a5ee..bc11539 100644 --- a/pkg/types/certificate.go +++ b/pkg/types/certificate.go @@ -2,12 +2,17 @@ package types import ( "bytes" + "crypto/elliptic" + "crypto/x509" "encoding/pem" + + "github.com/aws/aws-sdk-go-v2/service/acmpca/types" ) type SigningRequest struct { - CSR *bytes.Buffer - PrivateKey *pem.Block + CSR *bytes.Buffer + PrivateKey *pem.Block + EncodedPKCS8 []byte } type SignedCertificate struct { @@ -15,3 +20,145 @@ type SignedCertificate struct { IntermediateCertificateChainPath string RootCertificateChainPath string } + +type PublicKeyAlgorithm struct { + Algorithm x509.PublicKeyAlgorithm + KeySize map[int]interface{} + Signature map[string]bool + SigningAlgorithm map[x509.SignatureAlgorithm]bool +} + +var PublicKeyAlgorithms = map[KeyType]PublicKeyAlgorithm{ + RSA: { + Algorithm: x509.RSA, + KeySize: map[int]interface{}{ + 2048: true, + 4096: true, + }, + Signature: map[string]bool{ + "SHA256WITHRSA": true, + "SHA384WITHRSA": true, + "SHA512WITHRSA": true, + "SHA256WITHRSAPSS": true, + "SHA384WITHRSAPSS": true, + "SHA512WithRSAPSS": true, + }, + SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ + x509.SHA256WithRSA: true, + x509.SHA384WithRSA: true, + x509.SHA512WithRSA: true, + x509.SHA256WithRSAPSS: true, + x509.SHA384WithRSAPSS: true, + x509.SHA512WithRSAPSS: true, + }, + }, + ECDSA: { + Algorithm: x509.ECDSA, + KeySize: map[int]interface{}{ + 256: elliptic.P256(), + 384: elliptic.P384(), + 521: elliptic.P521(), + }, + Signature: map[string]bool{ + "SHA256WITHECDSA": true, + "SHA384WITHECDSA": true, + "SHA512WITHECDSA": true, + }, + SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ + x509.ECDSAWithSHA256: true, + x509.ECDSAWithSHA384: true, + x509.ECDSAWithSHA512: true, + }, + }, + // TODO: Support Ed25519 + Ed25519: { + Algorithm: x509.Ed25519, + KeySize: map[int]interface{}{ + 256: true, + }, + }, +} + +var PublicKeyAlgorithmStrings = map[string]PublicKeyAlgorithm{ + RSA.String(): { + Algorithm: x509.RSA, + KeySize: map[int]interface{}{ + 2048: true, + 4096: true, + }, + Signature: map[string]bool{ + "SHA256WITHRSA": true, + "SHA384WITHRSA": true, + "SHA512WITHRSA": true, + "SHA256WITHRSAPSS": true, + "SHA384WITHRSAPSS": true, + "SHA512WithRSAPSS": true, + }, + SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ + x509.SHA256WithRSA: true, + x509.SHA384WithRSA: true, + x509.SHA512WithRSA: true, + x509.SHA256WithRSAPSS: true, + x509.SHA384WithRSAPSS: true, + x509.SHA512WithRSAPSS: true, + }, + }, + ECDSA.String(): { + Algorithm: x509.ECDSA, + KeySize: map[int]interface{}{ + 256: elliptic.P256(), + 384: elliptic.P384(), + 521: elliptic.P521(), + }, + Signature: map[string]bool{ + "SHA256WITHECDSA": true, + "SHA384WITHECDSA": true, + "SHA512WITHECDSA": true, + }, + SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ + x509.ECDSAWithSHA256: true, + x509.ECDSAWithSHA384: true, + x509.ECDSAWithSHA512: true, + }, + }, + // TODO: Support Ed25519 + Ed25519.String(): { + Algorithm: x509.Ed25519, + KeySize: map[int]interface{}{ + 256: true, + }, + }, +} + +type SigningAlgorithm struct { + Common x509.SignatureAlgorithm + PCA types.SigningAlgorithm +} + +var ValidSignatures = map[string]SigningAlgorithm{ + "SHA256WITHECDSA": { + Common: x509.ECDSAWithSHA256, + PCA: types.SigningAlgorithmSha256withecdsa, + }, + "SHA384WITHECDSA": { + Common: x509.ECDSAWithSHA384, + PCA: types.SigningAlgorithmSha384withecdsa, + }, + "SHA512WITHECDSA": { + Common: x509.ECDSAWithSHA512, + PCA: types.SigningAlgorithmSha512withecdsa, + }, + "SHA256WITHRSA": { + Common: x509.SHA256WithRSA, + PCA: types.SigningAlgorithmSha256withrsa, + }, + "SHA384WITHRSA": { + Common: x509.SHA384WithRSA, + PCA: types.SigningAlgorithmSha384withrsa, + }, + "SHA512WITHRSA": { + Common: x509.SHA512WithRSA, + PCA: types.SigningAlgorithmSha512withrsa, + }, + // TODO: Support Probabilistic Element to the Signature Scheme [SHA256WithRSAPSS] +} diff --git a/pkg/types/pk.go b/pkg/types/pk.go index b0252d3..69d362d 100644 --- a/pkg/types/pk.go +++ b/pkg/types/pk.go @@ -1,58 +1,26 @@ package types -import ( - "crypto/elliptic" - "crypto/x509" -) +type KeyType uint -type PublicKeyAlgorithm struct { - Algorithm x509.PublicKeyAlgorithm - KeySize map[int]any - Signature map[string]bool - SigningAlgorithm map[x509.SignatureAlgorithm]bool -} +const ( + RSA_PRIVATE_KEY KeyType = iota + ECDSA_PRIVATE_KEY + PKCS8_PRIVATE_KEY + CERTIFICATE + CERTIFICATE_REQUEST + RSA + ECDSA + Ed25519 +) -var PublicKeyAlgorithms = map[string]PublicKeyAlgorithm{ - "RSA": { - Algorithm: x509.RSA, - KeySize: map[int]interface{}{ - 2048: true, - 4096: true, - }, - Signature: map[string]bool{ - "SHA256WITHRSA": true, - "SHA384WITHRSA": true, - "SHA512WITHRSA": true, - }, - SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ - x509.SHA256WithRSA: true, - x509.SHA384WithRSA: true, - x509.SHA512WithRSA: true, - }, - }, - "ECDSA": { - Algorithm: x509.ECDSA, - KeySize: map[int]interface{}{ - 256: elliptic.P256(), - 384: elliptic.P384(), - 521: elliptic.P521(), - }, - Signature: map[string]bool{ - "SHA256WITHECDSA": true, - "SHA384WITHECDSA": true, - "SHA512WITHECDSA": true, - }, - SigningAlgorithm: map[x509.SignatureAlgorithm]bool{ - x509.ECDSAWithSHA256: true, - x509.ECDSAWithSHA384: true, - x509.ECDSAWithSHA512: true, - }, - }, - // TODO: Support Ed25519 - "Ed25519": { - Algorithm: x509.Ed25519, - KeySize: map[int]interface{}{ - 256: true, - }, - }, +func (k KeyType) String() string { + return [...]string{ + "RSA PRIVATE KEY", + "EC PRIVATE KEY", + "PRIVATE KEY", + "CERTIFICATE", + "CERTIFICATE REQUEST", + "RSA", + "ECDSA", + "Ed25519"}[k] } diff --git a/pkg/types/sign.go b/pkg/types/sign.go index c08d88e..2d0d82b 100644 --- a/pkg/types/sign.go +++ b/pkg/types/sign.go @@ -1,6 +1,13 @@ package types -import "crypto/x509" +import ( + "crypto" + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "hash" + "io" +) type TrustChain struct { CommonName string @@ -8,9 +15,56 @@ type TrustChain struct { CertificateAuthorityFiles []string } +type Path struct { + File string + Buffer int +} + +type Reader struct { + Interface io.Reader + Buffer int +} + +type Data struct { + Raw *[]byte + Path Path + Reader Reader +} + type Manifest struct { CertificateChain []*x509.Certificate - Signature []byte - Data []byte SigningAlgorithm x509.SignatureAlgorithm + Signature *[]byte + Hash *[]byte + Data Data +} + +var SignatureAlgorithm = map[x509.SignatureAlgorithm]func() (hash.Hash, crypto.Hash){ + x509.ECDSAWithSHA256: func() (hash.Hash, crypto.Hash) { + return sha256.New(), crypto.SHA256 + }, + x509.ECDSAWithSHA384: func() (hash.Hash, crypto.Hash) { + return sha512.New384(), crypto.SHA384 + }, + x509.ECDSAWithSHA512: func() (hash.Hash, crypto.Hash) { + return sha512.New(), crypto.SHA512 + }, + x509.SHA256WithRSA: func() (hash.Hash, crypto.Hash) { + return sha256.New(), crypto.SHA256 + }, + x509.SHA384WithRSA: func() (hash.Hash, crypto.Hash) { + return sha512.New384(), crypto.SHA384 + }, + x509.SHA512WithRSA: func() (hash.Hash, crypto.Hash) { + return sha512.New(), crypto.SHA512 + }, + x509.SHA256WithRSAPSS: func() (hash.Hash, crypto.Hash) { + return sha256.New(), crypto.SHA256 + }, + x509.SHA384WithRSAPSS: func() (hash.Hash, crypto.Hash) { + return sha512.New384(), crypto.SHA384 + }, + x509.SHA512WithRSAPSS: func() (hash.Hash, crypto.Hash) { + return sha512.New(), crypto.SHA512 + }, }