Skip to content

Commit

Permalink
update vectors and spec error; make tests more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe committed Apr 2, 2024
1 parent a962d4e commit 937a919
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 53 deletions.
51 changes: 46 additions & 5 deletions impl/internal/did/did.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ func (DHT) Method() did.Method {
}

type CreateDIDDHTOpts struct {
// AuthoritativeGateways is a list of authoritative gateways for the DID Document
AuthoritativeGateways []string
// Controller is the DID Controller, can be a list of DIDs
Controller []string
// AlsoKnownAs is a list of alternative identifiers for the DID Document
Expand Down Expand Up @@ -311,9 +313,9 @@ func (d DHT) ToDNSPacket(doc did.Document, types []TypeIndex) (*dns.Msg, error)
txtRecord := fmt.Sprintf("id=%s;t=%d;k=%s", vmKeyFragment, keyType, keyBase64URL)

// only include the alg if it's not the default alg for the key type
forKeyType := algIsDefaultForKeyType(*vm.PublicKeyJWK)
forKeyType := algIsDefaultForJWK(*vm.PublicKeyJWK)
if !forKeyType {
txtRecord += fmt.Sprintf(";alg=%s", vm.PublicKeyJWK.ALG)
txtRecord += fmt.Sprintf(";a=%s", vm.PublicKeyJWK.ALG)
}

// note the controller if it differs from the DID
Expand Down Expand Up @@ -492,7 +494,12 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
switch record := rr.(type) {
case *dns.TXT:
if strings.HasPrefix(record.Hdr.Name, "_cnt") {
doc.Controller = strings.Split(record.Txt[0], ",")
controllers := strings.Split(record.Txt[0], ",")
if len(controllers) == 1 {
doc.Controller = controllers[0]
} else {
doc.Controller = controllers
}
}
if strings.HasPrefix(record.Hdr.Name, "_aka") {
doc.AlsoKnownAs = strings.Split(record.Txt[0], ",")
Expand All @@ -503,6 +510,7 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
keyType := keyTypeLookUp(data["t"])
keyBase64URL := data["k"]
controller := data["c"]
alg := data["a"]

// set the controller to the DID if it's not provided
if controller == "" {
Expand All @@ -524,6 +532,17 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) {
return nil, nil, err
}

// set the algorithm if it's not the default for the key type
if alg == "" {
defaultAlg := defaultAlgForJWK(*pubKeyJWK)
if defaultAlg == "" {
return nil, nil, fmt.Errorf("unable to provide default alg for unsupported key type: %s", keyType)
}
pubKeyJWK.ALG = defaultAlg
} else {
pubKeyJWK.ALG = alg
}

// make sure the controller of the identity key matches the DID
if vmID == "0" && controller != d.String() {
return nil, nil, fmt.Errorf("controller of identity key must be the DID itself, instead it is: %s", controller)
Expand Down Expand Up @@ -656,9 +675,9 @@ func parseTxtData(data string) map[string]string {
return result
}

// algIsDefaultForKeyType returns true if the given JWK ALG is the default for the given key type
// algIsDefaultForJWK returns true if the given JWK ALG is the default for the given key type
// according to the key type index https://did-dht.com/registry/#key-type-index
func algIsDefaultForKeyType(jwk jwx.PublicKeyJWK) bool {
func algIsDefaultForJWK(jwk jwx.PublicKeyJWK) bool {
// Ed25519 : Ed25519
if jwk.CRV == crypto.Ed25519.String() && jwk.KTY == jwa.OKP.String() {
return jwk.ALG == string(crypto.Ed25519DSA)
Expand All @@ -678,6 +697,28 @@ func algIsDefaultForKeyType(jwk jwx.PublicKeyJWK) bool {
return false
}

// defaultAlgForJWK returns the default signature algorithm for the given JWK based on the key type index
// https://did-dht.com/registry/#key-type-index
func defaultAlgForJWK(jwk jwx.PublicKeyJWK) string {
// Ed25519 : Ed25519
if jwk.CRV == crypto.Ed25519.String() && jwk.KTY == jwa.OKP.String() {
return string(crypto.Ed25519DSA)
}
// secp256k1 : ES256K
if jwk.CRV == crypto.SECP256k1.String() && jwk.KTY == jwa.EC.String() {
return string(crypto.ES256K)
}
// P-256 : ES256
if jwk.CRV == crypto.P256.String() && jwk.KTY == jwa.EC.String() {
return string(crypto.ES256)
}
// X25519 : ECDH-ES+A256KW
if jwk.CRV == crypto.X25519.String() && jwk.KTY == jwa.OKP.String() {
return string(crypto.ECDHESA256KW)
}
return ""
}

// keyTypeLookUp returns the key type for the given key type index
// https://did-dht.com/registry/#key-type-index
func keyTypeLookUp(keyType string) crypto.KeyType {
Expand Down
135 changes: 111 additions & 24 deletions impl/internal/did/did_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package did
import (
"crypto/ed25519"
"fmt"
"strings"
"testing"

"github.com/TBD54566975/ssi-sdk/cryptosuite"
Expand Down Expand Up @@ -222,6 +223,7 @@ func TestToDNSPacket(t *testing.T) {

func TestVectors(t *testing.T) {
type testVectorDNSRecord struct {
Name string `json:"name"`
RecordType string `json:"type"`
TTL string `json:"ttl"`
Record string `json:"rdata"`
Expand All @@ -240,25 +242,57 @@ func TestVectors(t *testing.T) {

var expectedDIDDocument did.Document
retrieveTestVectorAs(t, vector1DIDDocument, &expectedDIDDocument)
assert.EqualValues(t, expectedDIDDocument, *doc)

docJSON, err := json.Marshal(doc)
require.NoError(t, err)

expectedDIDDocJSON, err := json.Marshal(expectedDIDDocument)
require.NoError(t, err)

assert.JSONEq(t, string(expectedDIDDocJSON), string(docJSON))

didID := DHT(doc.ID)
packet, err := didID.ToDNSPacket(*doc, nil)
require.NoError(t, err)
require.NotEmpty(t, packet)

var expectedDNSRecords map[string]testVectorDNSRecord
var expectedDNSRecords []testVectorDNSRecord
retrieveTestVectorAs(t, vector1DNSRecords, &expectedDNSRecords)

// Initialize a map to track matched records
matchedRecords := make(map[int]bool)
for i := range expectedDNSRecords {
matchedRecords[i] = false // Initialize all expected records as unmatched
}

for _, record := range packet.Answer {
expectedRecord, ok := expectedDNSRecords[record.Header().Name]
require.True(t, ok)
for i, expectedRecord := range expectedDNSRecords {
if record.Header().Name == expectedRecord.Name {
s := record.String()
if strings.Contains(s, expectedRecord.RecordType) &&
strings.Contains(s, expectedRecord.TTL) &&
strings.Contains(s, expectedRecord.Record) {
matchedRecords[i] = true // Mark as matched
break
}
}
}
}

s := record.String()
assert.Contains(t, s, expectedRecord.RecordType)
assert.Contains(t, s, expectedRecord.TTL)
assert.Contains(t, s, expectedRecord.Record)
// Check if all expected records have been matched
for i, matched := range matchedRecords {
require.True(t, matched, fmt.Sprintf("Expected DNS record %d: %+v not matched", i, expectedDNSRecords[i]))
}

// Make sure going back to DID Document is consistent
decodedDoc, types, err := didID.FromDNSPacket(packet)
require.NoError(t, err)
require.NotEmpty(t, decodedDoc)
require.Empty(t, types)

decodedDocJSON, err := json.Marshal(decodedDoc)
require.NoError(t, err)
assert.JSONEq(t, string(expectedDIDDocJSON), string(decodedDocJSON))
})

t.Run("test vector 2", func(t *testing.T) {
Expand All @@ -272,6 +306,10 @@ func TestVectors(t *testing.T) {
retrieveTestVectorAs(t, vector2PublicKeyJWK2, &secpJWK)

doc, err := CreateDIDDHTDID(pubKey.(ed25519.PublicKey), CreateDIDDHTOpts{
AuthoritativeGateways: []string{
"gateway1.example-did-dht-gateway.com.",
"gateway2.example-did-dht-gateway.com.",
},
Controller: []string{"did:example:abcd"},
AlsoKnownAs: []string{"did:example:efgh", "did:example:ijkl"},
VerificationMethods: []VerificationMethod{
Expand Down Expand Up @@ -311,18 +349,44 @@ func TestVectors(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, packet)

var expectedDNSRecords map[string]testVectorDNSRecord
var expectedDNSRecords []testVectorDNSRecord
retrieveTestVectorAs(t, vector2DNSRecords, &expectedDNSRecords)

// Initialize a map to track matched records
matchedRecords := make(map[int]bool)
for i := range expectedDNSRecords {
matchedRecords[i] = false // Initialize all expected records as unmatched
}

for _, record := range packet.Answer {
expectedRecord, ok := expectedDNSRecords[record.Header().Name]
require.True(t, ok, "record not found: %s", record.Header().Name)
for i, expectedRecord := range expectedDNSRecords {
if record.Header().Name == expectedRecord.Name {
s := record.String()
if strings.Contains(s, expectedRecord.RecordType) &&
strings.Contains(s, expectedRecord.TTL) &&
strings.Contains(s, expectedRecord.Record) {
matchedRecords[i] = true // Mark as matched
break
}
}
}
}

s := record.String()
assert.Contains(t, s, expectedRecord.RecordType)
assert.Contains(t, s, expectedRecord.TTL)
assert.Contains(t, s, expectedRecord.Record)
// Check if all expected records have been matched
for i, matched := range matchedRecords {
require.True(t, matched, fmt.Sprintf("Expected DNS record %d: %+v not matched", i, expectedDNSRecords[i]))
}

// Make sure going back to DID Document is consistent
decodedDoc, types, err := didID.FromDNSPacket(packet)
require.NoError(t, err)
require.NotEmpty(t, decodedDoc)
require.NotEmpty(t, types)
require.Equal(t, types, []TypeIndex{1, 2, 3})

decodedDocJSON, err := json.Marshal(decodedDoc)
require.NoError(t, err)
assert.JSONEq(t, string(expectedDIDDocJSON), string(decodedDocJSON))
})

t.Run("test vector 3", func(t *testing.T) {
Expand Down Expand Up @@ -366,20 +430,43 @@ func TestVectors(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, packet)

println(packet.String())

var expectedDNSRecords map[string]testVectorDNSRecord
var expectedDNSRecords []testVectorDNSRecord
retrieveTestVectorAs(t, vector3DNSRecords, &expectedDNSRecords)

// Initialize a map to track matched records
matchedRecords := make(map[int]bool)
for i := range expectedDNSRecords {
matchedRecords[i] = false // Initialize all expected records as unmatched
}

for _, record := range packet.Answer {
expectedRecord, ok := expectedDNSRecords[record.Header().Name]
require.True(t, ok, "record not found: %s", record.Header().Name)
for i, expectedRecord := range expectedDNSRecords {
if record.Header().Name == expectedRecord.Name {
s := record.String()
if strings.Contains(s, expectedRecord.RecordType) &&
strings.Contains(s, expectedRecord.TTL) &&
strings.Contains(s, expectedRecord.Record) {
matchedRecords[i] = true // Mark as matched
break
}
}
}
}

s := record.String()
assert.Contains(t, s, expectedRecord.RecordType)
assert.Contains(t, s, expectedRecord.TTL)
assert.Contains(t, s, expectedRecord.Record)
// Check if all expected records have been matched
for i, matched := range matchedRecords {
require.True(t, matched, fmt.Sprintf("Expected DNS record %d: %+v not matched", i, expectedDNSRecords[i]))
}

// Make sure going back to DID Document is consistent
decodedDoc, types, err := didID.FromDNSPacket(packet)
require.NoError(t, err)
require.NotEmpty(t, decodedDoc)
require.Empty(t, types)

decodedDocJSON, err := json.Marshal(decodedDoc)
require.NoError(t, err)
assert.JSONEq(t, string(expectedDIDDocJSON), string(decodedDocJSON))
})
}

Expand Down
12 changes: 7 additions & 5 deletions impl/internal/did/testdata/vector-1-dns-records.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
{
"_did.": {
[
{
"name": "_did.",
"type": "TXT",
"ttl": "7200",
"rdata": "vm=k0;auth=k0;asm=k0;inv=k0;del=k0"
"rdata": "v=0;vm=k0;auth=k0;asm=k0;inv=k0;del=k0"
},
"_k0._did.": {
{
"name": "_k0._did.",
"type": "TXT",
"ttl": "7200",
"rdata": "id=0;t=0;k=YCcHYL2sYNPDlKaALcEmll2HHyT968M4UWbr-9CFGWE"
}
}
]
27 changes: 17 additions & 10 deletions impl/internal/did/testdata/vector-2-dns-records.json
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
{
"_did.": {
[
{
"name": "_did.",
"type": "TXT",
"ttl": "7200",
"rdata": "vm=k0,k1;svc=s0;auth=k0;asm=k0,k1;inv=k0,k1;del=k0"
"rdata": "v=0;vm=k0,k1;svc=s0;auth=k0;asm=k0,k1;inv=k0,k1;del=k0"
},
"_cnt._did.": {
{
"name": "_cnt._did.",
"type": "TXT",
"ttl": "7200",
"rdata": "did:example:abcd"
},
"_aka._did.": {
{
"name": "_aka._did.",
"type": "TXT",
"ttl": "7200",
"rdata": "did:example:efgh,did:example:ijkl"
},
"_k0._did.": {
{
"name": "_k0._did.",
"type": "TXT",
"ttl": "7200",
"rdata": "id=0;t=0;k=YCcHYL2sYNPDlKaALcEmll2HHyT968M4UWbr-9CFGWE"
},
"_k1._did.": {
{
"name": "_k1._did.",
"type": "TXT",
"ttl": "7200",
"rdata": "t=1;k=Atf6NCChxjWpnrfPt1WDVE4ipYVSvi4pXCq4SUjx0jT9"
},
"_s0._did.": {
{
"name": "_s0._did.",
"type": "TXT",
"ttl": "7200",
"rdata": "id=service-1;t=TestService;se=https://test-service.com/1,https://test-service.com/2"
},
"_typ._did.": {
{
"name": "_typ._did.",
"type": "TXT",
"ttl": "7200",
"rdata": "id=1,2,3"
}
}
]
Loading

0 comments on commit 937a919

Please sign in to comment.