diff --git a/ca_certs/uzi_ca_certs.go b/ca_certs/uzi_ca_certs.go index a1cdc9a..7e6a913 100644 --- a/ca_certs/uzi_ca_certs.go +++ b/ca_certs/uzi_ca_certs.go @@ -29,12 +29,12 @@ func GetCertPools(includeTest bool) (root *x509.CertPool, intermediate *x509.Cer return downloadUziPool(pool) } -func GetCerts(includeTest bool) (*[]x509.Certificate, error) { +func GetCerts(includeTest bool) ([]*x509.Certificate, error) { pool := prepareAndCombinePools(includeTest) return downloadUziPoolCerts(pool) } -func GetDERs(includeTest bool) (*[][]byte, error) { +func GetDERs(includeTest bool) ([][]byte, error) { pool := prepareAndCombinePools(includeTest) return downloadUziPoolDERs(pool) } @@ -50,19 +50,19 @@ func prepareAndCombinePools(includeTest bool) UziCaPool { return pool } -func downloadUziPoolDERs(pool UziCaPool) (*[][]byte, error) { +func downloadUziPoolDERs(pool UziCaPool) ([][]byte, error) { var rv = [][]byte{} certs, err := downloadUziPoolCerts(pool) if err != nil { return nil, err } - for _, cert := range *certs { + for _, cert := range certs { rv = append(rv, cert.Raw) } - return &rv, err + return rv, err } -func GetTestCerts() (*[]x509.Certificate, error) { +func GetTestCerts() ([]*x509.Certificate, error) { return downloadUziPoolCerts(TestUziCaPool) } @@ -81,7 +81,7 @@ func downloadUziPool(pool UziCaPool) (*x509.CertPool, *x509.CertPool, error) { return roots, intermediates, nil } -func downloadUziPoolCerts(pool UziCaPool) (*[]x509.Certificate, error) { +func downloadUziPoolCerts(pool UziCaPool) ([]*x509.Certificate, error) { allUrls := append(pool.rootCaUrls, pool.intermediateCaUrls...) all, err := downloadCerts(allUrls) if err != nil { @@ -103,16 +103,16 @@ func downloadPool(urls []string) (*x509.CertPool, error) { return roots, nil } -func downloadCerts(urls []string) (*[]x509.Certificate, error) { - certs := make([]x509.Certificate, 0) +func downloadCerts(urls []string) ([]*x509.Certificate, error) { + certs := make([]*x509.Certificate, 0) for _, url := range urls { certificate, err := readCertificateFromUrl(url) if err != nil { return nil, err } - certs = append(certs, *certificate) + certs = append(certs, certificate) } - return &certs, nil + return certs, nil } func readCertificateFromUrl(url string) (*x509.Certificate, error) { diff --git a/did_x509/did_x509.go b/did_x509/did_x509.go index 701a8f1..1738740 100644 --- a/did_x509/did_x509.go +++ b/did_x509/did_x509.go @@ -21,13 +21,9 @@ type X509Did struct { // FormatDid constructs a decentralized identifier (DID) from a certificate chain and an optional policy. // It returns the formatted DID string or an error if the root certificate or hash calculation fails. -func FormatDid(chain *[]x509.Certificate, policy string) (string, error) { - root, err := FindRootCertificate(chain) - if err != nil { - return "", err - } +func FormatDid(caCert *x509.Certificate, policy string) (string, error) { alg := "sha512" - rootHash, err := x509_cert.Hash(root.Raw, alg) + rootHash, err := x509_cert.Hash(caCert.Raw, alg) if err != nil { return "", err } @@ -42,17 +38,13 @@ func FormatDid(chain *[]x509.Certificate, policy string) (string, error) { // CreateDid generates a Decentralized Identifier (DID) from a given certificate chain. // It extracts the Unique Registration Address (URA) from the chain, creates a policy with it, and formats the DID. // Returns the generated DID or an error if any step fails. -func CreateDid(chain *[]x509.Certificate) (string, error) { - certificate, _, err := x509_cert.FindSigningCertificate(chain) - if err != nil || certificate == nil { - return "", err - } - otherNameValue, sanType, err := x509_cert.FindOtherName(certificate) +func CreateDid(signingCert, caCert *x509.Certificate) (string, error) { + otherNameValue, sanType, err := x509_cert.FindOtherName(signingCert) if err != nil { return "", err } policy := CreatePolicy(otherNameValue, sanType) - formattedDid, err := FormatDid(chain, policy) + formattedDid, err := FormatDid(caCert, policy) return formattedDid, err } func ParseDid(didString string) (*X509Did, error) { @@ -83,10 +75,10 @@ func CreatePolicy(otherNameValue string, sanType x509_cert.SanTypeName) string { } // FindRootCertificate traverses a chain of x509 certificates and returns the first certificate that is a CA. -func FindRootCertificate(chain *[]x509.Certificate) (*x509.Certificate, error) { - for _, cert := range *chain { - if cert.IsCA { - return &cert, nil +func FindRootCertificate(chain []*x509.Certificate) (*x509.Certificate, error) { + for _, cert := range chain { + if x509_cert.IsRootCa(cert) { + return cert, nil } } return nil, fmt.Errorf("cannot find root certificate") diff --git a/did_x509/did_x509_mock.go b/did_x509/did_x509_mock.go index e0cd1e6..99dece3 100644 --- a/did_x509/did_x509_mock.go +++ b/did_x509/did_x509_mock.go @@ -40,16 +40,54 @@ func (m *MockDidCreator) EXPECT() *MockDidCreatorMockRecorder { } // CreateDid mocks base method. -func (m *MockDidCreator) CreateDid(chain *[]x509.Certificate) (string, error) { +func (m *MockDidCreator) CreateDid(signingCert, caCert *x509.Certificate) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateDid", chain) + ret := m.ctrl.Call(m, "CreateDid", signingCert, caCert) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateDid indicates an expected call of CreateDid. -func (mr *MockDidCreatorMockRecorder) CreateDid(chain any) *gomock.Call { +func (mr *MockDidCreatorMockRecorder) CreateDid(signingCert, caCert any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDid", reflect.TypeOf((*MockDidCreator)(nil).CreateDid), chain) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDid", reflect.TypeOf((*MockDidCreator)(nil).CreateDid), signingCert, caCert) +} + +// MockDidParser is a mock of DidParser interface. +type MockDidParser struct { + ctrl *gomock.Controller + recorder *MockDidParserMockRecorder +} + +// MockDidParserMockRecorder is the mock recorder for MockDidParser. +type MockDidParserMockRecorder struct { + mock *MockDidParser +} + +// NewMockDidParser creates a new mock instance. +func NewMockDidParser(ctrl *gomock.Controller) *MockDidParser { + mock := &MockDidParser{ctrl: ctrl} + mock.recorder = &MockDidParserMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDidParser) EXPECT() *MockDidParserMockRecorder { + return m.recorder +} + +// ParseDid mocks base method. +func (m *MockDidParser) ParseDid(did string) (*X509Did, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ParseDid", did) + ret0, _ := ret[0].(*X509Did) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ParseDid indicates an expected call of ParseDid. +func (mr *MockDidParserMockRecorder) ParseDid(did any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ParseDid", reflect.TypeOf((*MockDidParser)(nil).ParseDid), did) } diff --git a/did_x509/did_x509_test.go b/did_x509/did_x509_test.go index fd1f611..3a79f5a 100644 --- a/did_x509/did_x509_test.go +++ b/did_x509/did_x509_test.go @@ -14,7 +14,7 @@ func TestDefaultDidCreator_CreateDid(t *testing.T) { type fields struct { } type args struct { - chain *[]x509.Certificate + chain []*x509.Certificate } chain, _, rootCert, _, _, err := x509_cert.BuildCertChain("A_BIG_STRING") if err != nil { @@ -34,22 +34,6 @@ func TestDefaultDidCreator_CreateDid(t *testing.T) { want string errMsg string }{ - { - name: "Test case 1", - fields: fields{}, - args: args{chain: &[]x509.Certificate{}}, - want: "", - errMsg: "no certificates provided", - }, - { - name: "Test case 2", - fields: fields{}, - args: args{chain: &[]x509.Certificate{ - {}, - }}, - want: "", - errMsg: "no certificate found in the SAN attributes, please check if the certificate is an UZI Server Certificate", - }, { name: "Happy path", fields: fields{}, @@ -60,7 +44,7 @@ func TestDefaultDidCreator_CreateDid(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := CreateDid(tt.args.chain) + got, err := CreateDid(tt.args.chain[0], tt.args.chain[len(tt.args.chain)-1]) wantErr := tt.errMsg != "" if (err != nil) != wantErr { t.Errorf("DefaultDidProcessor.CreateDid() error = %v, errMsg %v", err, tt.errMsg) diff --git a/main.go b/main.go index dbc4feb..fde4757 100644 --- a/main.go +++ b/main.go @@ -4,19 +4,28 @@ import ( "fmt" "github.com/alecthomas/kong" "github.com/nuts-foundation/uzi-did-x509-issuer/uzi_vc_issuer" + "github.com/nuts-foundation/uzi-did-x509-issuer/x509_cert" "os" ) type VC struct { - CertificateFile string `arg:"" name:"certificate_file" help:"Certificate PEM file." type:"existingfile"` + CertificateFile string `arg:"" name:"certificate_file" help:"Certificate PEM file. If the file contains a chain, the chain will be used for signing." type:"existingfile"` SigningKey string `arg:"" name:"signing_key" help:"PEM key for signing." type:"existingfile"` SubjectDID string `arg:"" name:"subject_did" help:"The subject DID of the VC." type:"key"` - Test bool `short:"t" help:"Allow test certificates."` + Test bool `short:"t" help:"Allow for certificates signed by the TEST UZI Root CA."` +} + +type TestCert struct { + Uzi string `arg:"" name:"uzi" help:"The UZI number for the test certificate."` + Ura string `arg:"" name:"ura" help:"The URA number for the test certificate."` + Agb string `arg:"" name:"agb" help:"The AGB code for the test certificate."` + SubjectDID string `arg:"" default:"did:web:example.com:test" name:"subject_did" help:"The subject DID of the VC." type:"key"` } var CLI struct { - Version string `help:"Show version."` - Vc VC `cmd:"" help:"Create a new VC."` + Version string `help:"Show version."` + Vc VC `cmd:"" help:"Create a new VC."` + TestCert TestCert `cmd:"" help:"Create a new test certificate."` } func main() { @@ -25,17 +34,69 @@ func main() { if err != nil { panic(err) } - _, err = parser.Parse(os.Args[1:]) + ctx, err := parser.Parse(os.Args[1:]) if err != nil { parser.FatalIfErrorf(err) } - vc := cli.Vc - jwt, err := issueVc(vc) - if err != nil { - fmt.Println(err) + + switch ctx.Command() { + case "vc ": + vc := cli.Vc + jwt, err := issueVc(vc) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } + println(jwt) + case "test-cert ", "test-cert ": + // Format is 2.16.528.1.1007.99.2110-1-900030787-S-90000380-00.000-11223344 + // ------ + // 2.16.528.1.1007.99.2110-1--S--00.000- + otherName := fmt.Sprintf("2.16.528.1.1007.99.2110-1-%s-S-%s-00.000-%s", cli.TestCert.Uzi, cli.TestCert.Ura, cli.TestCert.Agb) + fmt.Println("Building certificate chain for identifier:", otherName) + chain, _, _, privKey, _, err := x509_cert.BuildCertChain(otherName) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } + + chainPems, err := x509_cert.EncodeCertificates(chain...) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } + signingKeyPem, err := x509_cert.EncodeRSAPrivateKey(privKey) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } + + err = os.WriteFile("chain.pem", chainPems, 0644) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } + err = os.WriteFile("signing_key.pem", signingKeyPem, 0644) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } + vc := VC{ + CertificateFile: "chain.pem", + SigningKey: "signing_key.pem", + SubjectDID: cli.TestCert.SubjectDID, + Test: false, + } + jwt, err := issueVc(vc) + if err != nil { + fmt.Println(err) + os.Exit(-1) + } + println(jwt) + default: + fmt.Println("Unknown command") os.Exit(-1) } - println(jwt) } func issueVc(vc VC) (string, error) { diff --git a/pem/pem_reader.go b/pem/pem_reader.go index 03e0db7..3a3ec8a 100644 --- a/pem/pem_reader.go +++ b/pem/pem_reader.go @@ -2,11 +2,12 @@ package pem import ( "encoding/pem" + "fmt" "os" ) // ParseFileOrPath processes a file or directory at the given path and extracts PEM blocks of the specified pemType. -func ParseFileOrPath(path string, pemType string) (*[][]byte, error) { +func ParseFileOrPath(path string, pemType string) ([][]byte, error) { fileInfo, err := os.Stat(path) if err != nil { return nil, err @@ -25,9 +26,9 @@ func ParseFileOrPath(path string, pemType string) (*[][]byte, error) { if err != nil { return nil, err } - files = append(files, *blocks...) + files = append(files, blocks...) } - return &files, nil + return files, nil } else { blocks, err := readFile(path, pemType) return blocks, err @@ -36,7 +37,8 @@ func ParseFileOrPath(path string, pemType string) (*[][]byte, error) { } // readFile reads a file from the given filename, parses it for PEM blocks of the specified type, and returns the blocks. -func readFile(filename string, pemType string) (*[][]byte, error) { +func readFile(filename string, pemType string) ([][]byte, error) { + fmt.Println("filename: ", filename) files := make([][]byte, 0) content, err := os.ReadFile(filename) if err != nil { @@ -44,13 +46,13 @@ func readFile(filename string, pemType string) (*[][]byte, error) { } if looksLineCert(content, pemType) { foundBlocks := ParsePemBlocks(content, pemType) - files = append(files, *foundBlocks...) + files = append(files, foundBlocks...) } - return &files, nil + return files, nil } // ParsePemBlocks extracts specified PEM blocks from the provided certificate bytes and returns them as a pointer to a slice of byte slices. -func ParsePemBlocks(cert []byte, pemType string) *[][]byte { +func ParsePemBlocks(cert []byte, pemType string) [][]byte { blocks := make([][]byte, 0) for { pemBlock, tail := pem.Decode(cert) @@ -66,7 +68,7 @@ func ParsePemBlocks(cert []byte, pemType string) *[][]byte { cert = tail } - return &blocks + return blocks } // looksLineCert checks if the given certificate data is a valid PEM block of the specified type. diff --git a/pem/pem_reader_test.go b/pem/pem_reader_test.go index 92b2ae6..62b4c2f 100644 --- a/pem/pem_reader_test.go +++ b/pem/pem_reader_test.go @@ -75,9 +75,9 @@ func TestParseFileOrPath(t *testing.T) { } data, err := ParseFileOrPath(file.Name(), pemType) assert.NoError(t, err) - for i := 0; i < len(*data); i++ { - bytes := (*data)[i] - certificate := (*certs)[i] + for i := 0; i < len(data); i++ { + bytes := (data)[i] + certificate := (certs)[i] ok := assert.Equal(t, bytes, certificate.Raw) if !ok { t.Fail() @@ -110,14 +110,14 @@ func TestParseFileOrPath(t *testing.T) { data, err := ParseFileOrPath(tempDir, pemType) assert.NoError(t, err) dataMap := make(map[string][]byte) - for i := 0; i < len(*data); i++ { - bytes := (*data)[i] + for i := 0; i < len(data); i++ { + bytes := (data)[i] hash, err := x509_cert.Hash(bytes, "sha512") assert.NoError(t, err) dataMap[base64.RawURLEncoding.EncodeToString(hash)] = bytes } - for i := 0; i < len(*certs); i++ { - bytes := (*certs)[i].Raw + for i := 0; i < len(certs); i++ { + bytes := (certs)[i].Raw hash, err := x509_cert.Hash(bytes, "sha512") assert.NoError(t, err) fileBytes := dataMap[base64.RawURLEncoding.EncodeToString(hash)] diff --git a/uzi_vc_issuer/ura_issuer.go b/uzi_vc_issuer/ura_issuer.go index 87cf79d..ba0bba3 100644 --- a/uzi_vc_issuer/ura_issuer.go +++ b/uzi_vc_issuer/ura_issuer.go @@ -10,82 +10,79 @@ import ( "encoding/pem" "errors" "fmt" + "regexp" + "time" + "github.com/google/uuid" "github.com/lestrrat-go/jwx/v2/cert" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" ssi "github.com/nuts-foundation/go-did" + "github.com/nuts-foundation/go-did/vc" "github.com/nuts-foundation/uzi-did-x509-issuer/ca_certs" "github.com/nuts-foundation/uzi-did-x509-issuer/did_x509" pem2 "github.com/nuts-foundation/uzi-did-x509-issuer/pem" "github.com/nuts-foundation/uzi-did-x509-issuer/uzi_vc_validator" "github.com/nuts-foundation/uzi-did-x509-issuer/x509_cert" - "regexp" - "time" ) -import "github.com/nuts-foundation/go-did/vc" -var RegexOtherNameValue = regexp.MustCompile(`2\.16\.528\.1\.1007.\d+\.\d+-\d+-\d+-S-(\d+)-00\.000-\d+`) +// RegexOtherNameValue matches thee OtherName field: ----- +// e.g.: 1-123456789-S-88888801-00.000-12345678 +// var RegexOtherNameValue = regexp.MustCompile(`2\.16\.528\.1\.1007.\d+\.\d+-\d+-\d+-S-(\d+)-00\.000-\d+`) +var RegexOtherNameValue = regexp.MustCompile(`\d+-\d+-S-(\d+)-00\.000-\d+`) // Issue generates a URA Verifiable Credential using provided certificate, signing key, subject DID, and subject name. -func Issue(certificateFile string, signingKeyFile string, subjectDID string, test bool) (string, error) { - certificate, err := pem2.ParseFileOrPath(certificateFile, "CERTIFICATE") +func Issue(certificateFile string, signingKeyFile string, subjectDID string, allowTestUraCa bool) (string, error) { + pemBlocks, err := pem2.ParseFileOrPath(certificateFile, "CERTIFICATE") if err != nil { return "", err } - _certificates, err := x509_cert.ParseCertificates(certificate) - if err != nil { - return "", err - } - if len(*_certificates) != 1 { - err = fmt.Errorf("did not find exactly one certificate in file %s", certificateFile) - return "", err - } - - chain, err := ca_certs.GetDERs(test) - if err != nil { - return "", err + allowSelfSignedCa := len(pemBlocks) > 1 + if len(pemBlocks) == 1 { + certificate := pemBlocks[0] + pemBlocks, err = ca_certs.GetDERs(allowTestUraCa) + if err != nil { + return "", err + } + pemBlocks = append(pemBlocks, certificate) } - _chain := append(*chain, *certificate...) - chain = &_chain signingKeys, err := pem2.ParseFileOrPath(signingKeyFile, "PRIVATE KEY") if err != nil { return "", err } - if signingKeys == nil { + if len(signingKeys) == 0 { err := fmt.Errorf("no signing keys found") return "", err - } - var signingKey *[]byte - if len(*signingKeys) == 1 { - signingKey = &(*signingKeys)[0] - } else { - err := fmt.Errorf("no signing keys found") + privateKey, err := x509_cert.ParsePrivateKey(signingKeys[0]) + if err != nil { return "", err } - privateKey, err := x509_cert.ParsePrivateKey(signingKey) + + certs, err := x509_cert.ParseCertificates(pemBlocks) if err != nil { return "", err } - certChain, err := x509_cert.ParseCertificates(chain) + chain := BuildCertificateChain(certs) + err = validateChain(chain) if err != nil { + fmt.Println("error validating chain: ", err) return "", err } - credential, err := BuildUraVerifiableCredential(certChain, privateKey, subjectDID) + credential, err := BuildUraVerifiableCredential(chain, privateKey, subjectDID) if err != nil { return "", err } - marshal, err := json.Marshal(credential) + credentialJSON, err := json.Marshal(credential) if err != nil { return "", err } - validator := uzi_vc_validator.NewUraValidator(test) - jwtString := string(marshal) + validator := uzi_vc_validator.NewUraValidator(allowTestUraCa, allowSelfSignedCa) + jwtString := string(credentialJSON) jwtString = jwtString[1:] // Chop start jwtString = jwtString[:len(jwtString)-1] // Chop end err = validator.Validate(jwtString) @@ -96,25 +93,25 @@ func Issue(certificateFile string, signingKeyFile string, subjectDID string, tes } // BuildUraVerifiableCredential constructs a verifiable credential with specified certificates, signing key, subject DID. -func BuildUraVerifiableCredential(certificates *[]x509.Certificate, signingKey *rsa.PrivateKey, subjectDID string) (*vc.VerifiableCredential, error) { - signingCert, otherNameValue, err := x509_cert.FindSigningCertificate(certificates) - if err != nil { - return nil, err - } - chain := BuildCertificateChain(certificates, signingCert) - err = validateChain(chain) - if err != nil { - return nil, err +func BuildUraVerifiableCredential(chain []*x509.Certificate, signingKey *rsa.PrivateKey, subjectDID string) (*vc.VerifiableCredential, error) { + if len(chain) == 0 { + return nil, errors.New("empty certificate chain") } - did, err := did_x509.CreateDid(chain) + did, err := did_x509.CreateDid(chain[0], chain[len(chain)-1]) if err != nil { return nil, err } + // signing cert is at the start of the chain + signingCert := chain[0] serialNumber := signingCert.Subject.SerialNumber if serialNumber == "" { return nil, errors.New("serialNumber not found in signing certificate ") } uzi := serialNumber + otherNameValue, _, err := x509_cert.FindOtherName(signingCert) + if err != nil { + return nil, err + } template, err := uraCredential(did, otherNameValue, uzi, subjectDID) if err != nil { return nil, err @@ -137,7 +134,7 @@ func BuildUraVerifiableCredential(certificates *[]x509.Certificate, signingKey * } // x5c - serializedCert, err := marshalChain(chain) + serializedCert, err := marshalChain(chain...) if err != nil { return "", err } @@ -164,57 +161,77 @@ func BuildUraVerifiableCredential(certificates *[]x509.Certificate, signingKey * // marshalChain converts a slice of x509.Certificate instances to a cert.Chain, encoding each certificate as PEM. // It returns the PEM-encoded cert.Chain and an error if the encoding or header fixation fails. -func marshalChain(certificates *[]x509.Certificate) (*cert.Chain, error) { - rv := &cert.Chain{} - certs := *certificates - for i, _ := range certs { - certificate := certs[len(certs)-i-1] - err := rv.Add(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw})) +func marshalChain(certificates ...*x509.Certificate) (*cert.Chain, error) { + chainPems := &cert.Chain{} + for _, certificate := range certificates { + err := chainPems.Add(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certificate.Raw})) if err != nil { return nil, err } } - rv, err := x509_cert.FixChainHeaders(rv) - return rv, err + headers, err := x509_cert.FixChainHeaders(chainPems) + return headers, err } -func validateChain(certificates *[]x509.Certificate) error { - certs := *certificates +func validateChain(certs []*x509.Certificate) error { var prev *x509.Certificate = nil for i := range certs { certificate := certs[len(certs)-i-1] if prev != nil { - err := prev.CheckSignatureFrom(&certificate) + err := prev.CheckSignatureFrom(certificate) if err != nil { return err } } - if x509_cert.IsRootCa(&certificate) { + if x509_cert.IsRootCa(certificate) { return nil } - prev = &certificate + prev = certificate } return errors.New("failed to find a path to the root certificate in the chain, are you using a (Test) URA server certificate (Hint: the --test mode is required for Test URA server certificates)") } // BuildCertificateChain constructs a certificate chain from a given list of certificates and a starting signing certificate. // It recursively finds parent certificates for non-root CAs and appends them to the chain. -func BuildCertificateChain(certs *[]x509.Certificate, signingCert *x509.Certificate) *[]x509.Certificate { - var chain []x509.Certificate +// It assumes the list might not be in order. +// The returning chain contains the signing cert at the start and the root cert at the end. +func BuildCertificateChain(certs []*x509.Certificate) []*x509.Certificate { + var signingCert *x509.Certificate + for _, c := range certs { + if !c.IsCA { + signingCert = c + break + } + } if signingCert == nil { - return &chain + fmt.Println("failed to find signing certificate") + return nil } - if !x509_cert.IsRootCa(signingCert) { - for _, parent := range *certs { - err := signingCert.CheckSignatureFrom(&parent) + + var chain []*x509.Certificate + chain = append(chain, signingCert) + + certToCheck := signingCert + for !x509_cert.IsRootCa(certToCheck) { + found := false + for _, c := range certs { + if c.Equal(signingCert) { + continue + } + err := certToCheck.CheckSignatureFrom(c) if err == nil { - parentChain := BuildCertificateChain(certs, &parent) - chain = append(chain, *parentChain...) + chain = append(chain, c) + certToCheck = c + found = true + break } } + if !found { + fmt.Println("failed to find path from signingCert to root") + return nil + } } - chain = append(chain, *signingCert) - return &chain + return chain } // convertClaims converts a map of claims to a JWT token. diff --git a/uzi_vc_issuer/ura_issuer_test.go b/uzi_vc_issuer/ura_issuer_test.go index b213054..04804bf 100644 --- a/uzi_vc_issuer/ura_issuer_test.go +++ b/uzi_vc_issuer/ura_issuer_test.go @@ -22,14 +22,14 @@ func TestBuildUraVerifiableCredential(t *testing.T) { tests := []struct { name string - in func() (*[]x509.Certificate, *rsa.PrivateKey, string) + in func() ([]*x509.Certificate, *rsa.PrivateKey, string) want func(error) bool }{ { name: "invalid signing certificate", - in: func() (*[]x509.Certificate, *rsa.PrivateKey, string) { - certs := []x509.Certificate{*cert} - return &certs, privKey, "did:example:123" + in: func() ([]*x509.Certificate, *rsa.PrivateKey, string) { + certs := []*x509.Certificate{cert} + return certs, privKey, "did:example:123" }, want: func(err error) bool { return err != nil diff --git a/uzi_vc_validator/ura_validator.go b/uzi_vc_validator/ura_validator.go index dc140a8..5951e3f 100644 --- a/uzi_vc_validator/ura_validator.go +++ b/uzi_vc_validator/ura_validator.go @@ -22,7 +22,12 @@ type UraValidator interface { } type UraValidatorImpl struct { - test bool + allowUziTestCa bool + allowSelfSignedCa bool +} + +func NewUraValidator(allowUziTestCa bool, allowSelfSignedCa bool) *UraValidatorImpl { + return &UraValidatorImpl{allowUziTestCa, allowSelfSignedCa} } type JwtHeaderValues struct { @@ -34,7 +39,8 @@ type JwtHeaderValues struct { func (u UraValidatorImpl) Validate(jwtString string) error { credential := &vc.VerifiableCredential{} - err := json.Unmarshal([]byte(fmt.Sprintf("\"%s\"", jwtString)), credential) + marshal, _ := json.Marshal(jwtString) + err := json.Unmarshal(marshal, credential) if err != nil { return err } @@ -57,7 +63,7 @@ func (u UraValidatorImpl) Validate(jwtString string) error { return err } - err = validateChain(signingCert, chainCertificates, u.test) + err = validateChain(signingCert, chainCertificates, u.allowUziTestCa, u.allowSelfSignedCa) if err != nil { return err } @@ -76,7 +82,7 @@ func (u UraValidatorImpl) Validate(jwtString string) error { } if ura != parseDid.Ura { - return fmt.Errorf("Ura in credential does not match Ura in signing certificate") + return fmt.Errorf("URA in credential does not match Ura in signing certificate") } if sanType != parseDid.SanType { return fmt.Errorf("SanType in credential does not match SanType in signing certificate") @@ -85,31 +91,27 @@ func (u UraValidatorImpl) Validate(jwtString string) error { return nil } -func validateChain(signingCert *x509.Certificate, certificates *[]x509.Certificate, includeTest bool) error { - var err error - intermediates := x509.NewCertPool() +// func validateChain(signingCert *x509.Certificate, certificates []*x509.Certificate, includeTest bool) error { +func validateChain(signingCert *x509.Certificate, chain []*x509.Certificate, allowUziTestCa bool, allowSelfSignedCa bool) error { + roots := x509.NewCertPool() - for _, c := range *certificates { - if x509_cert.IsIntermediateCa(&c) { - intermediates.AddCert(&c) - } else if x509_cert.IsRootCa(&c) { - roots.AddCert(&c) + intermediates := x509.NewCertPool() + var err error + + if allowSelfSignedCa { + roots.AddCert(chain[len(chain)-1]) + for i := 1; i < len(chain)-1; i++ { + intermediates.AddCert(chain[i]) + } + } else { + roots, intermediates, err = ca_certs.GetCertPools(allowUziTestCa) + if err != nil { + return err } } - - // First validate against the own provided pool err = validate(signingCert, roots, intermediates) if err != nil { - err = fmt.Errorf("could not validate against own provided pool: %s", err.Error()) - return err - } - root, intermediate, err := ca_certs.GetCertPools(includeTest) - if err != nil { - return err - } - err = validate(signingCert, root, intermediate) - if err != nil { - err = fmt.Errorf("could not validate against the CA pool from zorgcsp (includeTest=%v): %s", includeTest, err.Error()) + err = fmt.Errorf("could not validate against the CA pool. %s", err.Error()) return err } return nil @@ -127,34 +129,35 @@ func validate(signingCert *x509.Certificate, roots *x509.CertPool, intermediates return nil } -func findSigningCertificate(certificates *[]x509.Certificate, thumbprint string) (*x509.Certificate, error) { - for _, c := range *certificates { +func findSigningCertificate(certificates []*x509.Certificate, thumbprint string) (*x509.Certificate, error) { + for _, c := range certificates { hashSha1 := sha1.Sum(c.Raw) - if base64.RawURLEncoding.EncodeToString(hashSha1[:]) == thumbprint { - return &c, nil + hashedCert := base64.RawURLEncoding.EncodeToString(hashSha1[:]) + if hashedCert == thumbprint { + return c, nil } } return nil, fmt.Errorf("Could not find certificate with thumbprint %s", thumbprint) } -func parseCertificate(chain *cert.Chain) (*[]x509.Certificate, error) { - var certificates []x509.Certificate +func parseCertificate(chain *cert.Chain) ([]*x509.Certificate, error) { + var certificates []*x509.Certificate for i := 0; i < chain.Len(); i++ { bytes, _ := chain.Get(i) blocks := pem2.ParsePemBlocks(bytes, "CERTIFICATE") - for _, block := range *blocks { + for _, block := range blocks { found, err := x509.ParseCertificates(block) if err != nil { return nil, err } for _, c := range found { if c != nil { - certificates = append(certificates, *c) + certificates = append(certificates, c) } } } } - return &certificates, nil + return certificates, nil } func parseJwtHeaderValues(jwtString string) (*JwtHeaderValues, error) { @@ -172,7 +175,3 @@ func parseJwtHeaderValues(jwtString string) (*JwtHeaderValues, error) { } return metadata, nil } - -func NewUraValidator(test bool) *UraValidatorImpl { - return &UraValidatorImpl{test} -} diff --git a/x509_cert/x509_cert.go b/x509_cert/x509_cert.go index 7b61ba4..beb352d 100644 --- a/x509_cert/x509_cert.go +++ b/x509_cert/x509_cert.go @@ -44,33 +44,37 @@ func Hash(data []byte, alg string) ([]byte, error) { // ParseCertificates parses a slice of DER-encoded byte arrays into a slice of x509.Certificate. // It returns an error if any of the certificates cannot be parsed. -func ParseCertificates(derChain *[][]byte) (*[]x509.Certificate, error) { +func ParseCertificates(derChain [][]byte) ([]*x509.Certificate, error) { if derChain == nil { return nil, fmt.Errorf("derChain is nil") } - chain := make([]x509.Certificate, len(*derChain)) + chain := make([]*x509.Certificate, len(derChain)) - for i, certBytes := range *derChain { + for i, certBytes := range derChain { certificate, err := x509.ParseCertificate(certBytes) if err != nil { return nil, err } - chain[i] = *certificate + chain[i] = certificate } - return &chain, nil + return chain, nil } // ParsePrivateKey parses a DER-encoded private key into an *rsa.PrivateKey. // It returns an error if the key is not in PKCS8 format or not an RSA key. -func ParsePrivateKey(der *[]byte) (*rsa.PrivateKey, error) { +func ParsePrivateKey(der []byte) (*rsa.PrivateKey, error) { if der == nil { return nil, fmt.Errorf("der is nil") } - key, err := x509.ParsePKCS8PrivateKey(*der) + key, err := x509.ParsePKCS8PrivateKey(der) if err != nil { - return nil, err + key, err = x509.ParsePKCS1PrivateKey(der) + if err != nil { + return nil, err + } } + if _, ok := key.(*rsa.PrivateKey); !ok { return nil, fmt.Errorf("key is not RSA") } diff --git a/x509_cert/x509_cert_mock.go b/x509_cert/x509_cert_mock.go index 9e9fb35..0abe8f2 100644 --- a/x509_cert/x509_cert_mock.go +++ b/x509_cert/x509_cert_mock.go @@ -1,12 +1,12 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: uzi_vc_issuer/x509_cert.go +// Source: x509_cert/x509_cert.go // // Generated by this command: // -// mockgen -destination=uzi_vc_issuer/x509_cert_mock.go -package=uzi_vc_issuer -source=uzi_vc_issuer/x509_cert.go +// mockgen -destination=x509_cert/x509_cert_mock.go -package=x509_cert -source=x509_cert/x509_cert.go // -// Package uzi_vc_issuer is a generated GoMock package. +// Package x509_cert is a generated GoMock package. package x509_cert import ( @@ -41,10 +41,10 @@ func (m *MockChainParser) EXPECT() *MockChainParserMockRecorder { } // ParseCertificates mocks base method. -func (m *MockChainParser) ParseCertificates(derChain *[][]byte) (*[]x509.Certificate, error) { +func (m *MockChainParser) ParseCertificates(derChain [][]byte) ([]*x509.Certificate, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ParseCertificates", derChain) - ret0, _ := ret[0].(*[]x509.Certificate) + ret0, _ := ret[0].([]*x509.Certificate) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -56,7 +56,7 @@ func (mr *MockChainParserMockRecorder) ParseCertificates(derChain any) *gomock.C } // ParsePrivateKey mocks base method. -func (m *MockChainParser) ParsePrivateKey(der *[]byte) (*rsa.PrivateKey, error) { +func (m *MockChainParser) ParsePrivateKey(der []byte) (*rsa.PrivateKey, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ParsePrivateKey", der) ret0, _ := ret[0].(*rsa.PrivateKey) diff --git a/x509_cert/x509_cert_test.go b/x509_cert/x509_cert_test.go index 9686127..8aa7285 100644 --- a/x509_cert/x509_cert_test.go +++ b/x509_cert/x509_cert_test.go @@ -91,12 +91,12 @@ func TestParseChain(t *testing.T) { testCases := []struct { name string - derChain *[][]byte + derChain [][]byte errMsg string }{ { name: "Valid Certificates", - derChain: &derChains, + derChain: derChains, }, { name: "Nil ChainPem", @@ -126,16 +126,16 @@ func TestParsePrivateKey(t *testing.T) { pkcs1PrivateKey := x509.MarshalPKCS1PrivateKey(privateKey) testCases := []struct { name string - der *[]byte + der []byte errMsg string }{ { name: "ValidPrivateKey", - der: &privateKeyBytes, + der: privateKeyBytes, }, { name: "InvalidPrivateKey", - der: &pkcs1PrivateKey, + der: pkcs1PrivateKey, errMsg: "x509: failed to parse private key (use ParsePKCS1PrivateKey instead for this key format)", }, { diff --git a/x509_cert/x509_test_utils.go b/x509_cert/x509_test_utils.go index 8fb50a1..d2f825f 100644 --- a/x509_cert/x509_test_utils.go +++ b/x509_cert/x509_test_utils.go @@ -1,6 +1,7 @@ package x509_cert import ( + "bytes" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -8,20 +9,43 @@ import ( "encoding/asn1" "encoding/pem" "fmt" - "github.com/lestrrat-go/jwx/v2/cert" "math/big" "time" + + "github.com/lestrrat-go/jwx/v2/cert" +) + +const ( + CertificateBlockType = "CERTIFICATE" + RSAPrivKeyBlockType = "PRIVATE KEY" ) +func EncodeRSAPrivateKey(key *rsa.PrivateKey) ([]byte, error) { + b := bytes.Buffer{} + err := pem.Encode(&b, &pem.Block{Type: RSAPrivKeyBlockType, Bytes: x509.MarshalPKCS1PrivateKey(key)}) + if err != nil { + return []byte{}, err + } + return b.Bytes(), nil +} + +func EncodeCertificates(certs ...*x509.Certificate) ([]byte, error) { + b := bytes.Buffer{} + for _, c := range certs { + if err := pem.Encode(&b, &pem.Block{Type: CertificateBlockType, Bytes: c.Raw}); err != nil { + return []byte{}, err + } + } + return b.Bytes(), nil +} + // BuildCertChain generates a certificate chain, including root, intermediate, and signing certificates. -func BuildCertChain(identifier string) (*[]x509.Certificate, *cert.Chain, *x509.Certificate, *rsa.PrivateKey, *x509.Certificate, error) { - chain := [4]x509.Certificate{} - chainPems := &cert.Chain{} +func BuildCertChain(identifier string) ([]*x509.Certificate, *cert.Chain, *x509.Certificate, *rsa.PrivateKey, *x509.Certificate, error) { rootKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, nil, nil, nil, err } - rootCertTmpl, err := CertTemplate(nil) + rootCertTmpl, err := CertTemplate(nil, "Root CA") if err != nil { return nil, nil, nil, nil, nil, err } @@ -32,17 +56,12 @@ func BuildCertChain(identifier string) (*[]x509.Certificate, *cert.Chain, *x509. if err != nil { return nil, nil, nil, nil, nil, err } - chain[0] = *rootCert - err = chainPems.Add(rootPem) - if err != nil { - return nil, nil, nil, nil, nil, err - } intermediateL1Key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, nil, nil, nil, err } - intermediateL1Tmpl, err := CertTemplate(nil) + intermediateL1Tmpl, err := CertTemplate(nil, "Intermediate CA Level 1") if err != nil { return nil, nil, nil, nil, nil, err } @@ -52,17 +71,12 @@ func BuildCertChain(identifier string) (*[]x509.Certificate, *cert.Chain, *x509. if err != nil { return nil, nil, nil, nil, nil, err } - chain[1] = *intermediateL1Cert - err = chainPems.Add(intermediateL1Pem) - if err != nil { - return nil, nil, nil, nil, nil, err - } intermediateL2Key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, nil, nil, nil, err } - intermediateL2Tmpl, err := CertTemplate(nil) + intermediateL2Tmpl, err := CertTemplate(nil, "Intermediate CA Level 2") if err != nil { return nil, nil, nil, nil, nil, err } @@ -72,11 +86,6 @@ func BuildCertChain(identifier string) (*[]x509.Certificate, *cert.Chain, *x509. if err != nil { return nil, nil, nil, nil, nil, err } - chain[2] = *intermediateL2Cert - err = chainPems.Add(intermediateL2Pem) - if err != nil { - return nil, nil, nil, nil, nil, err - } signingKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -93,30 +102,40 @@ func BuildCertChain(identifier string) (*[]x509.Certificate, *cert.Chain, *x509. if err != nil { return nil, nil, nil, nil, nil, err } - chain[3] = *signingCert - err = chainPems.Add(signingPEM) - if err != nil { - return nil, nil, nil, nil, nil, err + + chain := [4]*x509.Certificate{} + for i, c := range []*x509.Certificate{signingCert, intermediateL2Cert, intermediateL1Cert, rootCert} { + chain[i] = c + } + + chainPems := &cert.Chain{} + for _, p := range [][]byte{signingPEM, intermediateL2Pem, intermediateL1Pem, rootPem} { + err = chainPems.Add(p) + if err != nil { + return nil, nil, nil, nil, nil, err + } } + chainPems, err = FixChainHeaders(chainPems) if err != nil { return nil, nil, nil, nil, nil, err } _chain := chain[:] - return &_chain, chainPems, rootCert, signingKey, signingCert, nil + return _chain, chainPems, rootCert, signingKey, signingCert, nil } // CertTemplate generates a template for a x509 certificate with a given serial number. If no serial number is provided, a random one is generated. // The certificate is valid for one month and uses SHA256 with RSA for the signature algorithm. -func CertTemplate(serialNumber *big.Int) (*x509.Certificate, error) { +func CertTemplate(serialNumber *big.Int, organization string) (*x509.Certificate, error) { // generate a random serial number (a real cert authority would have some logic behind this) if serialNumber == nil { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 8) serialNumber, _ = rand.Int(rand.Reader, serialNumberLimit) } tmpl := x509.Certificate{ + IsCA: true, SerialNumber: serialNumber, - Subject: pkix.Name{Organization: []string{"JaegerTracing"}}, + Subject: pkix.Name{Organization: []string{organization}}, SignatureAlgorithm: x509.SHA256WithRSA, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour * 24 * 30), // valid for a month @@ -161,11 +180,11 @@ func SigningCertTemplate(serialNumber *big.Int, identifier string) (*x509.Certif tmpl := x509.Certificate{ SerialNumber: serialNumber, - Subject: pkix.Name{Organization: []string{"JaegerTracing"}}, + Subject: pkix.Name{Organization: []string{"FauxCare"}}, SignatureAlgorithm: x509.SHA256WithRSA, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour * 24 * 30), // valid for a month - EmailAddresses: []string{"roland@edia.nl"}, + EmailAddresses: []string{"roland@headease.nl"}, BasicConstraintsValid: true, ExtraExtensions: []pkix.Extension{ { @@ -212,6 +231,7 @@ func CreateCert(template, parent *x509.Certificate, pub interface{}, parentPriv } // DebugUnmarshall recursively unmarshalls ASN.1 encoded data and prints the structure with parsed values. +// Keep this method for debug purposes in the future. func DebugUnmarshall(data []byte, depth int) error { for len(data) > 0 { var x asn1.RawValue diff --git a/x509_cert/x509_utils.go b/x509_cert/x509_utils.go index 78292a5..4bf3674 100644 --- a/x509_cert/x509_utils.go +++ b/x509_cert/x509_utils.go @@ -30,7 +30,7 @@ func FindOtherName(certificate *x509.Certificate) (string, SanTypeName, error) { if otherNameValue != "" { return otherNameValue, SAN_TYPE_OTHER_NAME, nil } - err = errors.New("no certificate found in the SAN attributes, please check if the certificate is an UZI Server Certificate") + err = errors.New("no otherName found in the SAN attributes, please check if the certificate is an UZI Server Certificate") return "", "", err } @@ -118,19 +118,21 @@ func IsIntermediateCa(signingCert *x509.Certificate) bool { // FindSigningCertificate searches the provided certificate chain for a certificate with a specific SAN and Permanent Identifier. // It returns the found certificate, its IdentifierValue, and an error if no matching certificate is found. -func FindSigningCertificate(chain *[]x509.Certificate) (*x509.Certificate, string, error) { - if len(*chain) == 0 { +func FindSigningCertificate(chain []*x509.Certificate) (*x509.Certificate, string, error) { + if len(chain) == 0 { return nil, "", fmt.Errorf("no certificates provided") } var err error var otherNameValue string - for _, c := range *chain { - otherNameValue, _, err = FindOtherName(&c) + for _, c := range chain { + otherNameValue, _, err = FindOtherName(c) if err != nil { + fmt.Printf("info: no SAN in certificate: %v\n", err) continue } if otherNameValue != "" { - return &c, otherNameValue, nil + fmt.Printf("info: found SAN in certificate: %v\n", otherNameValue) + return c, otherNameValue, nil } } return nil, "", err