Skip to content

Commit

Permalink
cleanup resources properly
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe committed Apr 11, 2024
1 parent ca33f80 commit 7cd8a5c
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 67 deletions.
2 changes: 1 addition & 1 deletion impl/cmd/cli/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ var identityGetCmd = &cobra.Command{
}

// get the identity from the dht
gotResp, err := d.Get(context.Background(), id)
gotResp, err := d.GetFull(context.Background(), id)
if err != nil {
logrus.WithError(err).Error("failed to get identity from dht")
return err
Expand Down
1 change: 1 addition & 0 deletions impl/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ require (
go.opentelemetry.io/otel/metric v1.25.0 // indirect
go.opentelemetry.io/proto/otlp v1.1.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/goleak v1.3.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.7.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions impl/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ go.opentelemetry.io/proto/otlp v1.1.0 h1:2Di21piLrCqJ3U3eXGCTPHE9R8Nh+0uglSnOyxi
go.opentelemetry.io/proto/otlp v1.1.0/go.mod h1:GpBHCBWiqvVLDqmHZsoMM3C5ySeKTC7ej/RNTae6MdY=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
Expand Down
82 changes: 73 additions & 9 deletions impl/internal/dht/get.go → impl/internal/dht/getput.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"crypto/sha1"
"errors"
"math"
"sync"

"github.com/anacrolix/log"
k_nearest_nodes "github.com/anacrolix/dht/v2/k-nearest-nodes"
"github.com/anacrolix/torrent/bencode"
"github.com/sirupsen/logrus"

"github.com/anacrolix/dht/v2"
"github.com/anacrolix/dht/v2/bep44"
Expand All @@ -16,7 +18,7 @@ import (
)

// Copied from https://github.com/anacrolix/dht/blob/master/exts/getput/getput.go and modified
// to return signature data
// to return signature data and allow for context cancellations

type FullGetResult struct {
Seq int64
Expand All @@ -26,7 +28,7 @@ type FullGetResult struct {
}

func startGetTraversal(
target bep44.Target, s *dht.Server, seq *int64, salt []byte,
ctx context.Context, target bep44.Target, s *dht.Server, seq *int64, salt []byte,
) (
vChan chan FullGetResult, op *traversal.Operation, err error,
) {
Expand All @@ -35,11 +37,10 @@ func startGetTraversal(
Alpha: 15,
Target: target,
DoQuery: func(ctx context.Context, addr krpc.NodeAddr) traversal.QueryResult {
logger := log.ContextLogger(ctx)
res := s.Get(ctx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{})
err := res.ToError()
err = res.ToError()
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, dht.TransactionTimeout) {
logger.Levelf(log.Debug, "error querying %v: %v", addr, err)
logrus.WithContext(ctx).WithError(err).Debugf("error querying %v", addr)
}
if r := res.Reply.R; r != nil {
rv := r.V
Expand All @@ -64,7 +65,7 @@ func startGetTraversal(
case <-ctx.Done():
}
} else if rv != nil {
logger.Levelf(log.Debug, "get response item hash didn't match target: %q", rv)
logrus.WithContext(ctx).Debugf("get response item hash didn't match target: %q", rv)
}
}
tqr := res.TraversalQueryResult(addr)
Expand All @@ -78,6 +79,16 @@ func startGetTraversal(
},
NodeFilter: s.TraversalNodeFilter,
})

// list for context cancellation or stalled traversal
go func() {
select {
case <-ctx.Done():
op.Stop()
case <-op.Stalled():
}
}()

nodes, err := s.TraversalStartingNodes()
op.AddNodes(nodes)
return
Expand All @@ -88,7 +99,7 @@ func Get(
) (
ret FullGetResult, stats *traversal.Stats, err error,
) {
vChan, op, err := startGetTraversal(target, s, seq, salt)
vChan, op, err := startGetTraversal(ctx, target, s, seq, salt)
if err != nil {
return
}
Expand All @@ -101,7 +112,7 @@ receiveResults:
err = errors.New("value not found")
}
case v := <-vChan:
log.ContextLogger(ctx).Levelf(log.Debug, "received %#v", v)
logrus.WithContext(ctx).Debugf("received %#v", v)
gotValue = true
if !v.Mutable {
ret = v
Expand All @@ -118,3 +129,56 @@ receiveResults:
stats = op.Stats()
return
}

type SeqToPut func(seq int64) bep44.Put

func Put(
ctx context.Context, target krpc.ID, s *dht.Server, salt []byte, seqToPut SeqToPut,
) (
stats *traversal.Stats, err error,
) {
vChan, op, err := startGetTraversal(ctx, target, s,
// When we do a get traversal for a put, we don't care what seq the peers have?
nil,
// This is duplicated with the put, but we need it to filter responses for autoSeq.
salt)
if err != nil {
return
}
var autoSeq int64
notDone:
select {
case v := <-vChan:
if v.Mutable && v.Seq > autoSeq {
autoSeq = v.Seq
}
// There are more optimizations that can be done here. We can set CAS automatically, and we
// can skip updating the sequence number if the existing content already matches (and
// presumably republish the existing seq).
goto notDone
case <-op.Stalled():
case <-ctx.Done():
err = ctx.Err()
}
op.Stop()
var wg sync.WaitGroup
put := seqToPut(autoSeq)
op.Closest().Range(func(elem k_nearest_nodes.Elem) {
wg.Add(1)
go func() {
defer wg.Done()
// This is enforced by startGetTraversal.
token := elem.Data.(string)
res := s.Put(ctx, dht.NewAddr(elem.Addr.UDP()), put, token, dht.QueryRateLimiting{})
err = res.ToError()
if err != nil {
logrus.WithContext(ctx).WithError(err).Warnf("error putting to %v [token=%q]", elem.Addr, token)
} else {
logrus.WithContext(ctx).WithError(err).Debugf("put to %v [token=%q]", elem.Addr, token)
}
}()
})
wg.Wait()
stats = op.Stats()
return
}
20 changes: 10 additions & 10 deletions impl/internal/did/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ func TestClient(t *testing.T) {
t.Logf("time to put and get: %s", since)
}

func TestGet(t *testing.T) {
client, err := NewGatewayClient("http://localhost:8305")

require.NoError(t, err)
require.NotNil(t, client)

doc, _, _, err := client.GetDIDDocument("did:dht:uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy")
require.NoError(t, err)
require.NotNil(t, doc)
}
// func TestGet(t *testing.T) {
// client, err := NewGatewayClient("http://localhost:8305")
//
// require.NoError(t, err)
// require.NotNil(t, client)
//
// doc, _, _, err := client.GetDIDDocument("did:dht:uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy")
// require.NoError(t, err)
// require.NotNil(t, doc)
// }

func TestClientInvalidGateway(t *testing.T) {
g, err := NewGatewayClient("\n")
Expand Down
20 changes: 1 addition & 19 deletions impl/pkg/dht/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
errutil "github.com/TBD54566975/ssi-sdk/util"
"github.com/anacrolix/dht/v2"
"github.com/anacrolix/dht/v2/bep44"
"github.com/anacrolix/dht/v2/exts/getput"
"github.com/anacrolix/log"
"github.com/anacrolix/torrent/types/infohash"
"github.com/pkg/errors"
Expand Down Expand Up @@ -93,7 +92,7 @@ func (d *DHT) Put(ctx context.Context, request bep44.Put) (string, error) {
}

key := util.Z32Encode(request.K[:])
t, err := getput.Put(ctx, request.Target(), d.Server, nil, func(int64) bep44.Put {
t, err := dhtint.Put(ctx, request.Target(), d.Server, nil, func(int64) bep44.Put {
return request
})
if err != nil {
Expand All @@ -107,23 +106,6 @@ func (d *DHT) Put(ctx context.Context, request bep44.Put) (string, error) {
return util.Z32Encode(request.K[:]), nil
}

// Get returns the BEP-44 result for the given key from the DHT.
// The key is a z32-encoded string, such as "yj47pezutnpw9pyudeeai8cx8z8d6wg35genrkoqf9k3rmfzy58o".
func (d *DHT) Get(ctx context.Context, key string) (*getput.GetResult, error) {
ctx, span := telemetry.GetTracer().Start(ctx, "DHT.Get")
defer span.End()

z32Decoded, err := util.Z32Decode(key)
if err != nil {
return nil, errors.Wrapf(err, "failed to decode key [%s]", key)
}
res, t, err := getput.Get(ctx, infohash.HashBytes(z32Decoded), d.Server, nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to get key[%s] from dht; tried %d nodes, got %d responses", key, t.NumAddrsTried, t.NumResponses)
}
return &res, nil
}

// GetFull returns the full BEP-44 result for the given key from the DHT, using our modified
// implementation of getput.Get. It should ONLY be used when it's needed to get the signature
// data for a record.
Expand Down
15 changes: 6 additions & 9 deletions impl/pkg/dht/dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@ import (
"github.com/anacrolix/torrent/bencode"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/TBD54566975/did-dht-method/internal/util"
dhtclient "github.com/TBD54566975/did-dht-method/pkg/dht"
)

func TestGetPutDHT(t *testing.T) {
ctx := context.Background()
defer goleak.VerifyNone(t)

ctx := context.Background()
d := dhtclient.NewTestDHT(t)
defer d.Close()

pubKey, privKey, err := util.GenerateKeypair()
require.NoError(t, err)
Expand All @@ -34,18 +37,12 @@ func TestGetPutDHT(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, id)

got, err := d.Get(ctx, id)
got, err := d.GetFull(ctx, id)
require.NoError(t, err)
require.NotEmpty(t, got)
require.Equal(t, bencode.Bytes(put.V.([]byte)), got.V[2:])
require.Equal(t, put.Seq, got.Seq)

full, err := d.GetFull(ctx, id)
require.NoError(t, err)
require.NotEmpty(t, full)
require.Equal(t, bencode.Bytes(put.V.([]byte)), full.V[2:])
require.Equal(t, put.Seq, full.Seq)
require.False(t, full.Mutable)
require.True(t, got.Mutable)

var payload string
err = bencode.Unmarshal(got.V, &payload)
Expand Down
5 changes: 3 additions & 2 deletions impl/pkg/dht/pkarr.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (

"github.com/TBD54566975/ssi-sdk/util"
"github.com/anacrolix/dht/v2/bep44"
"github.com/anacrolix/dht/v2/exts/getput"
"github.com/anacrolix/torrent/bencode"
"github.com/miekg/dns"

"github.com/TBD54566975/did-dht-method/internal/dht"
)

// CreatePkarrPublishRequest creates a put request for the given records. Requires a public/private keypair and the records to put.
Expand Down Expand Up @@ -50,7 +51,7 @@ func CreatePkarrPublishRequest(privateKey ed25519.PrivateKey, msg dns.Msg) (*bep

// ParsePkarrGetResponse parses the response from a get request.
// The response is expected to be a slice of DNS resource records.
func ParsePkarrGetResponse(response getput.GetResult) (*dns.Msg, error) {
func ParsePkarrGetResponse(response dht.FullGetResult) (*dns.Msg, error) {
var payload string
if err := bencode.Unmarshal(response.V, &payload); err != nil {
return nil, util.LoggingErrorMsg(err, "failed to unmarshal payload value")
Expand Down
12 changes: 7 additions & 5 deletions impl/pkg/dht/pkarr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import (
"github.com/TBD54566975/did-dht-method/internal/util"
)

func TestGetPutPKARRDHT(t *testing.T) {
d := NewTestDHT(t)
func TestGetPutPkarrDHT(t *testing.T) {
dht := NewTestDHT(t)
defer dht.Close()

_, privKey, err := util.GenerateKeypair()
require.NoError(t, err)
Expand All @@ -44,11 +45,11 @@ func TestGetPutPKARRDHT(t *testing.T) {
put, err := CreatePkarrPublishRequest(privKey, msg)
require.NoError(t, err)

id, err := d.Put(context.Background(), *put)
id, err := dht.Put(context.Background(), *put)
require.NoError(t, err)
require.NotEmpty(t, id)

got, err := d.Get(context.Background(), id)
got, err := dht.GetFull(context.Background(), id)
require.NoError(t, err)
require.NotEmpty(t, got)

Expand All @@ -61,6 +62,7 @@ func TestGetPutPKARRDHT(t *testing.T) {

func TestGetPutDIDDHT(t *testing.T) {
dht := NewTestDHT(t)
defer dht.Close()

pubKey, _, err := crypto.GenerateSECP256k1Key()
require.NoError(t, err)
Expand Down Expand Up @@ -108,7 +110,7 @@ func TestGetPutDIDDHT(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, gotID)

got, err := dht.Get(context.Background(), gotID)
got, err := dht.GetFull(context.Background(), gotID)
require.NoError(t, err)
require.NotEmpty(t, got)

Expand Down
13 changes: 7 additions & 6 deletions impl/pkg/server/pkarr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ import (
"github.com/TBD54566975/did-dht-method/pkg/storage"
)

func TestPKARRRouter(t *testing.T) {
pkarrSvc := testPKARRService(t)
func TestPkarrRouter(t *testing.T) {
pkarrSvc := testPkarrService(t)
pkarrRouter, err := NewPkarrRouter(&pkarrSvc)
require.NoError(t, err)
require.NotEmpty(t, pkarrRouter)

defer pkarrSvc.Close()

t.Run("test put record", func(t *testing.T) {
didID, reqData := generateDIDPutRequest(t)

Expand Down Expand Up @@ -146,16 +148,15 @@ func TestPKARRRouter(t *testing.T) {
})
}

func testPKARRService(t *testing.T) service.PkarrService {
func testPkarrService(t *testing.T) service.PkarrService {
defaultConfig := config.GetDefaultConfig()

db, err := storage.NewStorage(defaultConfig.ServerConfig.StorageURI)
require.NoError(t, err)
require.NotEmpty(t, db)

d := dht.NewTestDHT(t)

pkarrService, err := service.NewPkarrService(&defaultConfig, db, d)
dht := dht.NewTestDHT(t)
pkarrService, err := service.NewPkarrService(&defaultConfig, db, dht)
require.NoError(t, err)
require.NotEmpty(t, pkarrService)

Expand Down
4 changes: 4 additions & 0 deletions impl/pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ func TestHealthCheckAPI(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, server)

defer server.Close()

req := httptest.NewRequest(http.MethodGet, testServerURL+"/health", nil)
w := httptest.NewRecorder()

Expand All @@ -41,6 +43,8 @@ func TestHealthCheckAPI(t *testing.T) {
err = json.NewDecoder(w.Body).Decode(&resp)
assert.NoError(t, err)
assert.Equal(t, HealthOK, resp.Status)

shutdown <- os.Interrupt
}

// Is2xxResponse returns true if the given status code is a 2xx response
Expand Down
Loading

0 comments on commit 7cd8a5c

Please sign in to comment.