From 7cd8a5cc24fdf214c9fe57081c55afe231ee4f05 Mon Sep 17 00:00:00 2001 From: gabe Date: Thu, 11 Apr 2024 14:52:12 -0400 Subject: [PATCH] cleanup resources properly --- impl/cmd/cli/identity.go | 2 +- impl/go.mod | 1 + impl/go.sum | 2 + impl/internal/dht/{get.go => getput.go} | 82 ++++++++++++++++++++++--- impl/internal/did/client_test.go | 20 +++--- impl/pkg/dht/dht.go | 20 +----- impl/pkg/dht/dht_test.go | 15 ++--- impl/pkg/dht/pkarr.go | 5 +- impl/pkg/dht/pkarr_test.go | 12 ++-- impl/pkg/server/pkarr_test.go | 13 ++-- impl/pkg/server/server_test.go | 4 ++ impl/pkg/service/pkarr.go | 28 ++++++++- impl/pkg/service/pkarr_test.go | 15 +++-- 13 files changed, 152 insertions(+), 67 deletions(-) rename impl/internal/dht/{get.go => getput.go} (53%) diff --git a/impl/cmd/cli/identity.go b/impl/cmd/cli/identity.go index bfb36d9a..e1516645 100644 --- a/impl/cmd/cli/identity.go +++ b/impl/cmd/cli/identity.go @@ -164,7 +164,7 @@ var identityGetCmd = &cobra.Command{ } // get the identity from the dht - gotResp, err := d.Get(context.Background(), id) + gotResp, err := d.GetFull(context.Background(), id) if err != nil { logrus.WithError(err).Error("failed to get identity from dht") return err diff --git a/impl/go.mod b/impl/go.mod index 949d2383..025a9f8d 100644 --- a/impl/go.mod +++ b/impl/go.mod @@ -135,6 +135,7 @@ require ( go.opentelemetry.io/otel/metric v1.25.0 // indirect go.opentelemetry.io/proto/otlp v1.1.0 // indirect go.uber.org/atomic v1.9.0 // indirect + go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.7.0 // indirect golang.org/x/crypto v0.21.0 // indirect diff --git a/impl/go.sum b/impl/go.sum index 3a7a10b3..304c106a 100644 --- a/impl/go.sum +++ b/impl/go.sum @@ -570,6 +570,8 @@ go.opentelemetry.io/proto/otlp v1.1.0 h1:2Di21piLrCqJ3U3eXGCTPHE9R8Nh+0uglSnOyxi go.opentelemetry.io/proto/otlp v1.1.0/go.mod h1:GpBHCBWiqvVLDqmHZsoMM3C5ySeKTC7ej/RNTae6MdY= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= diff --git a/impl/internal/dht/get.go b/impl/internal/dht/getput.go similarity index 53% rename from impl/internal/dht/get.go rename to impl/internal/dht/getput.go index 10ac68b1..71d9a906 100644 --- a/impl/internal/dht/get.go +++ b/impl/internal/dht/getput.go @@ -5,9 +5,11 @@ import ( "crypto/sha1" "errors" "math" + "sync" - "github.com/anacrolix/log" + k_nearest_nodes "github.com/anacrolix/dht/v2/k-nearest-nodes" "github.com/anacrolix/torrent/bencode" + "github.com/sirupsen/logrus" "github.com/anacrolix/dht/v2" "github.com/anacrolix/dht/v2/bep44" @@ -16,7 +18,7 @@ import ( ) // Copied from https://github.com/anacrolix/dht/blob/master/exts/getput/getput.go and modified -// to return signature data +// to return signature data and allow for context cancellations type FullGetResult struct { Seq int64 @@ -26,7 +28,7 @@ type FullGetResult struct { } func startGetTraversal( - target bep44.Target, s *dht.Server, seq *int64, salt []byte, + ctx context.Context, target bep44.Target, s *dht.Server, seq *int64, salt []byte, ) ( vChan chan FullGetResult, op *traversal.Operation, err error, ) { @@ -35,11 +37,10 @@ func startGetTraversal( Alpha: 15, Target: target, DoQuery: func(ctx context.Context, addr krpc.NodeAddr) traversal.QueryResult { - logger := log.ContextLogger(ctx) res := s.Get(ctx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{}) - err := res.ToError() + err = res.ToError() if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, dht.TransactionTimeout) { - logger.Levelf(log.Debug, "error querying %v: %v", addr, err) + logrus.WithContext(ctx).WithError(err).Debugf("error querying %v", addr) } if r := res.Reply.R; r != nil { rv := r.V @@ -64,7 +65,7 @@ func startGetTraversal( case <-ctx.Done(): } } else if rv != nil { - logger.Levelf(log.Debug, "get response item hash didn't match target: %q", rv) + logrus.WithContext(ctx).Debugf("get response item hash didn't match target: %q", rv) } } tqr := res.TraversalQueryResult(addr) @@ -78,6 +79,16 @@ func startGetTraversal( }, NodeFilter: s.TraversalNodeFilter, }) + + // list for context cancellation or stalled traversal + go func() { + select { + case <-ctx.Done(): + op.Stop() + case <-op.Stalled(): + } + }() + nodes, err := s.TraversalStartingNodes() op.AddNodes(nodes) return @@ -88,7 +99,7 @@ func Get( ) ( ret FullGetResult, stats *traversal.Stats, err error, ) { - vChan, op, err := startGetTraversal(target, s, seq, salt) + vChan, op, err := startGetTraversal(ctx, target, s, seq, salt) if err != nil { return } @@ -101,7 +112,7 @@ receiveResults: err = errors.New("value not found") } case v := <-vChan: - log.ContextLogger(ctx).Levelf(log.Debug, "received %#v", v) + logrus.WithContext(ctx).Debugf("received %#v", v) gotValue = true if !v.Mutable { ret = v @@ -118,3 +129,56 @@ receiveResults: stats = op.Stats() return } + +type SeqToPut func(seq int64) bep44.Put + +func Put( + ctx context.Context, target krpc.ID, s *dht.Server, salt []byte, seqToPut SeqToPut, +) ( + stats *traversal.Stats, err error, +) { + vChan, op, err := startGetTraversal(ctx, target, s, + // When we do a get traversal for a put, we don't care what seq the peers have? + nil, + // This is duplicated with the put, but we need it to filter responses for autoSeq. + salt) + if err != nil { + return + } + var autoSeq int64 +notDone: + select { + case v := <-vChan: + if v.Mutable && v.Seq > autoSeq { + autoSeq = v.Seq + } + // There are more optimizations that can be done here. We can set CAS automatically, and we + // can skip updating the sequence number if the existing content already matches (and + // presumably republish the existing seq). + goto notDone + case <-op.Stalled(): + case <-ctx.Done(): + err = ctx.Err() + } + op.Stop() + var wg sync.WaitGroup + put := seqToPut(autoSeq) + op.Closest().Range(func(elem k_nearest_nodes.Elem) { + wg.Add(1) + go func() { + defer wg.Done() + // This is enforced by startGetTraversal. + token := elem.Data.(string) + res := s.Put(ctx, dht.NewAddr(elem.Addr.UDP()), put, token, dht.QueryRateLimiting{}) + err = res.ToError() + if err != nil { + logrus.WithContext(ctx).WithError(err).Warnf("error putting to %v [token=%q]", elem.Addr, token) + } else { + logrus.WithContext(ctx).WithError(err).Debugf("put to %v [token=%q]", elem.Addr, token) + } + }() + }) + wg.Wait() + stats = op.Stats() + return +} diff --git a/impl/internal/did/client_test.go b/impl/internal/did/client_test.go index 957b7fa1..5051b453 100644 --- a/impl/internal/did/client_test.go +++ b/impl/internal/did/client_test.go @@ -42,16 +42,16 @@ func TestClient(t *testing.T) { t.Logf("time to put and get: %s", since) } -func TestGet(t *testing.T) { - client, err := NewGatewayClient("http://localhost:8305") - - require.NoError(t, err) - require.NotNil(t, client) - - doc, _, _, err := client.GetDIDDocument("did:dht:uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy") - require.NoError(t, err) - require.NotNil(t, doc) -} +// func TestGet(t *testing.T) { +// client, err := NewGatewayClient("http://localhost:8305") +// +// require.NoError(t, err) +// require.NotNil(t, client) +// +// doc, _, _, err := client.GetDIDDocument("did:dht:uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy") +// require.NoError(t, err) +// require.NotNil(t, doc) +// } func TestClientInvalidGateway(t *testing.T) { g, err := NewGatewayClient("\n") diff --git a/impl/pkg/dht/dht.go b/impl/pkg/dht/dht.go index 0a580606..42404ee3 100644 --- a/impl/pkg/dht/dht.go +++ b/impl/pkg/dht/dht.go @@ -10,7 +10,6 @@ import ( errutil "github.com/TBD54566975/ssi-sdk/util" "github.com/anacrolix/dht/v2" "github.com/anacrolix/dht/v2/bep44" - "github.com/anacrolix/dht/v2/exts/getput" "github.com/anacrolix/log" "github.com/anacrolix/torrent/types/infohash" "github.com/pkg/errors" @@ -93,7 +92,7 @@ func (d *DHT) Put(ctx context.Context, request bep44.Put) (string, error) { } key := util.Z32Encode(request.K[:]) - t, err := getput.Put(ctx, request.Target(), d.Server, nil, func(int64) bep44.Put { + t, err := dhtint.Put(ctx, request.Target(), d.Server, nil, func(int64) bep44.Put { return request }) if err != nil { @@ -107,23 +106,6 @@ func (d *DHT) Put(ctx context.Context, request bep44.Put) (string, error) { return util.Z32Encode(request.K[:]), nil } -// Get returns the BEP-44 result for the given key from the DHT. -// The key is a z32-encoded string, such as "yj47pezutnpw9pyudeeai8cx8z8d6wg35genrkoqf9k3rmfzy58o". -func (d *DHT) Get(ctx context.Context, key string) (*getput.GetResult, error) { - ctx, span := telemetry.GetTracer().Start(ctx, "DHT.Get") - defer span.End() - - z32Decoded, err := util.Z32Decode(key) - if err != nil { - return nil, errors.Wrapf(err, "failed to decode key [%s]", key) - } - res, t, err := getput.Get(ctx, infohash.HashBytes(z32Decoded), d.Server, nil, nil) - if err != nil { - return nil, fmt.Errorf("failed to get key[%s] from dht; tried %d nodes, got %d responses", key, t.NumAddrsTried, t.NumResponses) - } - return &res, nil -} - // GetFull returns the full BEP-44 result for the given key from the DHT, using our modified // implementation of getput.Get. It should ONLY be used when it's needed to get the signature // data for a record. diff --git a/impl/pkg/dht/dht_test.go b/impl/pkg/dht/dht_test.go index bff387b9..bc7cae2a 100644 --- a/impl/pkg/dht/dht_test.go +++ b/impl/pkg/dht/dht_test.go @@ -10,15 +10,18 @@ import ( "github.com/anacrolix/torrent/bencode" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "github.com/TBD54566975/did-dht-method/internal/util" dhtclient "github.com/TBD54566975/did-dht-method/pkg/dht" ) func TestGetPutDHT(t *testing.T) { - ctx := context.Background() + defer goleak.VerifyNone(t) + ctx := context.Background() d := dhtclient.NewTestDHT(t) + defer d.Close() pubKey, privKey, err := util.GenerateKeypair() require.NoError(t, err) @@ -34,18 +37,12 @@ func TestGetPutDHT(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, id) - got, err := d.Get(ctx, id) + got, err := d.GetFull(ctx, id) require.NoError(t, err) require.NotEmpty(t, got) require.Equal(t, bencode.Bytes(put.V.([]byte)), got.V[2:]) require.Equal(t, put.Seq, got.Seq) - - full, err := d.GetFull(ctx, id) - require.NoError(t, err) - require.NotEmpty(t, full) - require.Equal(t, bencode.Bytes(put.V.([]byte)), full.V[2:]) - require.Equal(t, put.Seq, full.Seq) - require.False(t, full.Mutable) + require.True(t, got.Mutable) var payload string err = bencode.Unmarshal(got.V, &payload) diff --git a/impl/pkg/dht/pkarr.go b/impl/pkg/dht/pkarr.go index aaba9d8a..2265c5c3 100644 --- a/impl/pkg/dht/pkarr.go +++ b/impl/pkg/dht/pkarr.go @@ -6,9 +6,10 @@ import ( "github.com/TBD54566975/ssi-sdk/util" "github.com/anacrolix/dht/v2/bep44" - "github.com/anacrolix/dht/v2/exts/getput" "github.com/anacrolix/torrent/bencode" "github.com/miekg/dns" + + "github.com/TBD54566975/did-dht-method/internal/dht" ) // CreatePkarrPublishRequest creates a put request for the given records. Requires a public/private keypair and the records to put. @@ -50,7 +51,7 @@ func CreatePkarrPublishRequest(privateKey ed25519.PrivateKey, msg dns.Msg) (*bep // ParsePkarrGetResponse parses the response from a get request. // The response is expected to be a slice of DNS resource records. -func ParsePkarrGetResponse(response getput.GetResult) (*dns.Msg, error) { +func ParsePkarrGetResponse(response dht.FullGetResult) (*dns.Msg, error) { var payload string if err := bencode.Unmarshal(response.V, &payload); err != nil { return nil, util.LoggingErrorMsg(err, "failed to unmarshal payload value") diff --git a/impl/pkg/dht/pkarr_test.go b/impl/pkg/dht/pkarr_test.go index 6b5a1643..f83992b1 100644 --- a/impl/pkg/dht/pkarr_test.go +++ b/impl/pkg/dht/pkarr_test.go @@ -16,8 +16,9 @@ import ( "github.com/TBD54566975/did-dht-method/internal/util" ) -func TestGetPutPKARRDHT(t *testing.T) { - d := NewTestDHT(t) +func TestGetPutPkarrDHT(t *testing.T) { + dht := NewTestDHT(t) + defer dht.Close() _, privKey, err := util.GenerateKeypair() require.NoError(t, err) @@ -44,11 +45,11 @@ func TestGetPutPKARRDHT(t *testing.T) { put, err := CreatePkarrPublishRequest(privKey, msg) require.NoError(t, err) - id, err := d.Put(context.Background(), *put) + id, err := dht.Put(context.Background(), *put) require.NoError(t, err) require.NotEmpty(t, id) - got, err := d.Get(context.Background(), id) + got, err := dht.GetFull(context.Background(), id) require.NoError(t, err) require.NotEmpty(t, got) @@ -61,6 +62,7 @@ func TestGetPutPKARRDHT(t *testing.T) { func TestGetPutDIDDHT(t *testing.T) { dht := NewTestDHT(t) + defer dht.Close() pubKey, _, err := crypto.GenerateSECP256k1Key() require.NoError(t, err) @@ -108,7 +110,7 @@ func TestGetPutDIDDHT(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, gotID) - got, err := dht.Get(context.Background(), gotID) + got, err := dht.GetFull(context.Background(), gotID) require.NoError(t, err) require.NotEmpty(t, got) diff --git a/impl/pkg/server/pkarr_test.go b/impl/pkg/server/pkarr_test.go index c1d1ed16..f635600b 100644 --- a/impl/pkg/server/pkarr_test.go +++ b/impl/pkg/server/pkarr_test.go @@ -19,12 +19,14 @@ import ( "github.com/TBD54566975/did-dht-method/pkg/storage" ) -func TestPKARRRouter(t *testing.T) { - pkarrSvc := testPKARRService(t) +func TestPkarrRouter(t *testing.T) { + pkarrSvc := testPkarrService(t) pkarrRouter, err := NewPkarrRouter(&pkarrSvc) require.NoError(t, err) require.NotEmpty(t, pkarrRouter) + defer pkarrSvc.Close() + t.Run("test put record", func(t *testing.T) { didID, reqData := generateDIDPutRequest(t) @@ -146,16 +148,15 @@ func TestPKARRRouter(t *testing.T) { }) } -func testPKARRService(t *testing.T) service.PkarrService { +func testPkarrService(t *testing.T) service.PkarrService { defaultConfig := config.GetDefaultConfig() db, err := storage.NewStorage(defaultConfig.ServerConfig.StorageURI) require.NoError(t, err) require.NotEmpty(t, db) - d := dht.NewTestDHT(t) - - pkarrService, err := service.NewPkarrService(&defaultConfig, db, d) + dht := dht.NewTestDHT(t) + pkarrService, err := service.NewPkarrService(&defaultConfig, db, dht) require.NoError(t, err) require.NotEmpty(t, pkarrService) diff --git a/impl/pkg/server/server_test.go b/impl/pkg/server/server_test.go index 4b59373e..60be8599 100644 --- a/impl/pkg/server/server_test.go +++ b/impl/pkg/server/server_test.go @@ -30,6 +30,8 @@ func TestHealthCheckAPI(t *testing.T) { assert.NoError(t, err) assert.NotEmpty(t, server) + defer server.Close() + req := httptest.NewRequest(http.MethodGet, testServerURL+"/health", nil) w := httptest.NewRecorder() @@ -41,6 +43,8 @@ func TestHealthCheckAPI(t *testing.T) { err = json.NewDecoder(w.Body).Decode(&resp) assert.NoError(t, err) assert.Equal(t, HealthOK, resp.Status) + + shutdown <- os.Interrupt } // Is2xxResponse returns true if the given status code is a 2xx response diff --git a/impl/pkg/service/pkarr.go b/impl/pkg/service/pkarr.go index 08172bfc..5be841f8 100644 --- a/impl/pkg/service/pkarr.go +++ b/impl/pkg/service/pkarr.go @@ -207,13 +207,16 @@ func (s *PkarrService) republish() { return } seenRecords += int32(len(recordsBatch)) - if len(recordsBatch) == 0 { logrus.WithContext(ctx).Info("no records to republish") return } - logrus.WithContext(ctx).WithField("record_count", len(recordsBatch)).Infof("republishing records in batch: %d", batchCnt) + logrus.WithContext(ctx).WithFields(logrus.Fields{ + "record_count": len(recordsBatch), + "batch_number": batchCnt, + "total_seen": seenRecords, + }).Infof("republishing next batch of records") batchCnt++ var wg sync.WaitGroup @@ -247,3 +250,24 @@ func (s *PkarrService) republish() { "total": seenRecords, }).Infof("republishing complete with [%d] batches", batchCnt) } + +// Close closes the Pkarr service gracefully +func (s *PkarrService) Close() { + if s == nil { + return + } + if s.scheduler != nil { + s.scheduler.Stop() + } + if s.cache != nil { + if err := s.cache.Close(); err != nil { + logrus.WithError(err).Error("failed to close cache") + } + } + if err := s.db.Close(); err != nil { + logrus.WithError(err).Error("failed to close db") + } + if s.dht != nil { + s.dht.Close() + } +} diff --git a/impl/pkg/service/pkarr_test.go b/impl/pkg/service/pkarr_test.go index fd59ef6a..3a08e1d8 100644 --- a/impl/pkg/service/pkarr_test.go +++ b/impl/pkg/service/pkarr_test.go @@ -6,11 +6,10 @@ import ( "os" "testing" + anacrolixdht "github.com/anacrolix/dht/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - anacrolixdht "github.com/anacrolix/dht/v2" - "github.com/TBD54566975/did-dht-method/config" "github.com/TBD54566975/did-dht-method/internal/did" "github.com/TBD54566975/did-dht-method/pkg/dht" @@ -125,6 +124,8 @@ func TestPkarrService(t *testing.T) { assert.Equal(t, putMsg.Sig, got.Sig) assert.Equal(t, putMsg.Seq, got.Seq) }) + + t.Cleanup(func() { svc.Close() }) } func TestDHT(t *testing.T) { @@ -164,12 +165,17 @@ func TestDHT(t *testing.T) { assert.Equal(t, putMsg.V, gotFrom2.V) assert.Equal(t, putMsg.Sig, gotFrom2.Sig) assert.Equal(t, putMsg.Seq, gotFrom2.Seq) + + t.Cleanup(func() { + svc1.Close() + svc2.Close() + }) } func TestNoConfig(t *testing.T) { svc, err := NewPkarrService(nil, nil, nil) assert.EqualError(t, err, "config is required") - assert.Nil(t, svc) + assert.Empty(t, svc) svc, err = NewPkarrService(&config.Config{ PkarrConfig: config.PkarrServiceConfig{ @@ -186,6 +192,8 @@ func TestNoConfig(t *testing.T) { }, nil, nil) assert.EqualError(t, err, "failed to start republisher: gocron: cron expression failed to be parsed: failed to parse int from not: strconv.Atoi: parsing \"not\": invalid syntax") assert.Nil(t, svc) + + t.Cleanup(func() { svc.Close() }) } func newPkarrService(t *testing.T, id string, bootstrapPeers ...anacrolixdht.Addr) PkarrService { @@ -198,7 +206,6 @@ func newPkarrService(t *testing.T, id string, bootstrapPeers ...anacrolixdht.Add t.Cleanup(func() { os.Remove(fmt.Sprintf("diddht-test-%s.db", id)) }) d := dht.NewTestDHT(t, bootstrapPeers...) - pkarrService, err := NewPkarrService(&defaultConfig, db, d) require.NoError(t, err) require.NotEmpty(t, pkarrService)