more concurrency fixes (#177)
* more

* better

* move

* done

* update err

* fix test

* comments
decentralgabe authored Apr 11, 2024
1 parent 9eb197a commit bb6a10f
Showing 10 changed files with 135 additions and 45 deletions.
4 changes: 2 additions & 2 deletions impl/cmd/main.go
@@ -109,8 +109,8 @@ func configureLogger(level string) {
if level != "" {
logLevel, err := logrus.ParseLevel(level)
if err != nil {
- logrus.WithError(err).WithField("level", level).Error("could not parse log level, setting to info")
- logrus.SetLevel(logrus.InfoLevel)
+ logrus.WithError(err).WithField("level", level).Error("could not parse log level, setting to debug")
+ logrus.SetLevel(logrus.DebugLevel)
} else {
logrus.SetLevel(logLevel)
}
2 changes: 1 addition & 1 deletion impl/config/config.go
@@ -89,7 +89,7 @@ func GetDefaultConfig() Config {
CacheSizeLimitMB: 500,
},
Log: LogConfig{
- Level: logrus.InfoLevel.String(),
+ Level: logrus.DebugLevel.String(),
},
}
}
2 changes: 1 addition & 1 deletion impl/integrationtest/main.go
@@ -22,7 +22,7 @@ var (
)

func main() {
- logrus.SetLevel(logrus.InfoLevel)
+ logrus.SetLevel(logrus.DebugLevel)
if len(os.Args) < 2 {
logrus.Fatal("must specify 1 argument (server URL)")
}
14 changes: 8 additions & 6 deletions impl/internal/dht/getput.go
@@ -6,6 +6,7 @@ import (
"errors"
"math"
"sync"
"time"

k_nearest_nodes "github.com/anacrolix/dht/v2/k-nearest-nodes"
"github.com/anacrolix/torrent/bencode"
@@ -37,7 +38,10 @@ func startGetTraversal(
Alpha: 15,
Target: target,
DoQuery: func(ctx context.Context, addr krpc.NodeAddr) traversal.QueryResult {
- res := s.Get(ctx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{})
+ queryCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
+ defer cancel()
+
+ res := s.Get(queryCtx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{})
err := res.ToError()
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, dht.TransactionTimeout) {
logrus.WithContext(ctx).WithError(err).Debugf("error querying %v", addr)
@@ -52,7 +56,7 @@
Sig: r.Sig,
Mutable: false,
}:
- case <-ctx.Done():
+ case <-queryCtx.Done():
}
} else if sha1.Sum(append(r.K[:], salt...)) == target && bep44.Verify(r.K[:], salt, *r.Seq, bv, r.Sig[:]) {
select {
@@ -62,15 +66,13 @@
Sig: r.Sig,
Mutable: true,
}:
- case <-ctx.Done():
+ case <-queryCtx.Done():
}
} else if rv != nil {
logrus.WithContext(ctx).Debugf("get response item hash didn't match target: %q", rv)
}
}
tqr := res.TraversalQueryResult(addr)
- // Filter replies from nodes that don't have a string token. This doesn't look prettier
- // with generics. "The token value should be a short binary string." ¯\_(ツ)_/¯ (BEP 5).
tqr.ClosestData, _ = tqr.ClosestData.(string)
if tqr.ClosestData == nil {
tqr.ResponseFrom = nil
@@ -80,7 +82,7 @@
NodeFilter: s.TraversalNodeFilter,
})

- // list for context cancellation or stalled traversal
+ // Listen for context cancellation or stalled traversal
go func() {
select {
case <-ctx.Done():
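A note on the hunk above: each DoQuery call now derives its own 8-second queryCtx from the traversal context, and the result-channel sends select on queryCtx.Done() rather than ctx.Done(), so a blocked send cannot outlive that query's deadline. Below is a minimal, self-contained sketch of the same pattern; the function names, address, and durations are illustrative, not the repo's API.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// queryNode stands in for a single DHT query. The parent ctx can still cancel
// everything early; the derived queryCtx only caps this one query.
func queryNode(ctx context.Context, addr string, results chan<- string) error {
	queryCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
	defer cancel()

	// Stand-in for the network round trip.
	select {
	case <-time.After(100 * time.Millisecond):
	case <-queryCtx.Done():
		return queryCtx.Err()
	}

	// Select on queryCtx.Done(), not ctx.Done(), so a full results channel
	// cannot block this goroutine past the query deadline.
	select {
	case results <- "response from " + addr:
		return nil
	case <-queryCtx.Done():
		return queryCtx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	results := make(chan string, 1)
	if err := queryNode(ctx, "203.0.113.7:6881", results); err != nil {
		fmt.Println("query failed:", err)
		return
	}
	fmt.Println(<-results)
}
```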
5 changes: 3 additions & 2 deletions impl/pkg/dht/logging.go
@@ -1,12 +1,13 @@
package dht

import (
"strings"

"github.com/anacrolix/log"
"github.com/sirupsen/logrus"
)

func init() {
logrus.SetFormatter(&logrus.JSONFormatter{})
log.Default.Handlers = []log.Handler{logrusHandler{}}
}

@@ -16,7 +17,7 @@ type logrusHandler struct{}
// It intentionally downgrades the log level to reduce verbosity.
func (logrusHandler) Handle(record log.Record) {
entry := logrus.WithFields(logrus.Fields{"names": record.Names})
- msg := record.Msg.String()
+ msg := strings.Replace(record.Msg.String(), "\n", "", -1)

switch record.Level {
case log.Debug:
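The handler above bridges anacrolix/log records into logrus; stripping embedded newlines keeps each JSON-formatted entry on a single line, which matters for collectors that treat one line as one event. A small sketch of the effect, with a made-up message:

```go
package main

import (
	"strings"

	"github.com/sirupsen/logrus"
)

func main() {
	logrus.SetFormatter(&logrus.JSONFormatter{})

	// A multi-line message from an upstream library would otherwise smear one
	// JSON entry across several lines; flatten it before logging.
	msg := "announce failed:\nretrying shortly"
	logrus.Info(strings.Replace(msg, "\n", "", -1)) // the same call the diff uses
}
```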
17 changes: 14 additions & 3 deletions impl/pkg/server/pkarr.go
@@ -43,7 +43,18 @@ func (r *PkarrRouter) GetRecord(c *gin.Context) {
return
}

- resp, err := r.service.GetPkarr(c.Request.Context(), *id)
+ // make sure the key is valid
+ key, err := util.Z32Decode(*id)
+ if err != nil {
+ LoggingRespondErrWithMsg(c, err, "invalid record id", http.StatusInternalServerError)
+ return
+ }
+ if len(key) != ed25519.PublicKeySize {
+ LoggingRespondErrMsg(c, "invalid z32 encoded ed25519 public key", http.StatusBadRequest)
+ return
+ }
+
+ resp, err := r.service.GetPkarr(c, *id)
if err != nil {
LoggingRespondErrWithMsg(c, err, "failed to get pkarr record", http.StatusInternalServerError)
return
@@ -82,7 +93,7 @@ func (r *PkarrRouter) PutRecord(c *gin.Context) {
}
key, err := util.Z32Decode(*id)
if err != nil {
LoggingRespondErrWithMsg(c, err, "failed to read id", http.StatusInternalServerError)
LoggingRespondErrWithMsg(c, err, "invalid record id", http.StatusInternalServerError)
return
}
if len(key) != ed25519.PublicKeySize {
@@ -114,7 +125,7 @@ func (r *PkarrRouter) PutRecord(c *gin.Context) {
return
}

- if err = r.service.PublishPkarr(c.Request.Context(), *id, *request); err != nil {
+ if err = r.service.PublishPkarr(c, *id, *request); err != nil {
LoggingRespondErrWithMsg(c, err, "failed to publish pkarr record", http.StatusInternalServerError)
return
}
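GetRecord now performs the same up-front validation as PutRecord: decode the z-base-32 id and confirm it is exactly an ed25519 public key before any service work. A standalone sketch of that check, assuming the tv42/zbase32 package the repo already uses elsewhere (util.Z32Decode is the repo's own wrapper):

```go
package main

import (
	"crypto/ed25519"
	"fmt"

	"github.com/tv42/zbase32"
)

// validateRecordID rejects ids that do not decode to a 32-byte ed25519 key.
func validateRecordID(id string) ([]byte, error) {
	key, err := zbase32.DecodeString(id)
	if err != nil {
		return nil, fmt.Errorf("invalid record id: %w", err)
	}
	if len(key) != ed25519.PublicKeySize {
		return nil, fmt.Errorf("expected %d-byte ed25519 key, got %d bytes", ed25519.PublicKeySize, len(key))
	}
	return key, nil
}

func main() {
	// "aaa" decodes fine but to fewer than 32 bytes, so the length check
	// rejects it; this is why the not-found test below needs a full-length id.
	_, err := validateRecordID("aaa")
	fmt.Println(err)
}
```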
2 changes: 1 addition & 1 deletion impl/pkg/server/pkarr_test.go
@@ -140,7 +140,7 @@ func TestPkarrRouter(t *testing.T) {

t.Run("test get not found", func(t *testing.T) {
w := httptest.NewRecorder()
suffix := "aaa"
suffix := "uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy"
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/%s", testServerURL, suffix), nil)
c := newRequestContextWithParams(w, req, map[string]string{IDParam: suffix})
pkarrRouter.GetRecord(c)
3 changes: 2 additions & 1 deletion impl/pkg/server/server.go
@@ -70,9 +70,10 @@ func NewServer(cfg *config.Config, shutdown chan os.Signal, d *dht.DHT) (*Server
Server: &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.ServerConfig.APIHost, cfg.ServerConfig.APIPort),
Handler: handler,
- ReadTimeout: time.Second * 10,
+ ReadTimeout: time.Second * 15,
ReadHeaderTimeout: time.Second * 10,
WriteTimeout: time.Second * 10,
+ MaxHeaderBytes: 1 << 20,
},
cfg: cfg,
svc: pkarrService,
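For context on the hunk above: ReadTimeout caps how long the entire request (headers plus body) may take to arrive, ReadHeaderTimeout separately guards against slowly trickled headers (slowloris), WriteTimeout bounds writing the response, and MaxHeaderBytes (1 << 20, i.e. 1 MiB) caps request header size. A minimal sketch with the same values; the address and handler are placeholders:

```go
package main

import (
	"net/http"
	"time"
)

func main() {
	srv := &http.Server{
		Addr:              ":8080",            // placeholder address
		Handler:           http.NewServeMux(), // placeholder handler
		ReadTimeout:       time.Second * 15,   // full request must arrive within 15s
		ReadHeaderTimeout: time.Second * 10,   // headers alone within 10s
		WriteTimeout:      time.Second * 10,   // response written within 10s
		MaxHeaderBytes:    1 << 20,            // cap request headers at 1 MiB
	}
	_ = srv.ListenAndServe()
}
```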
118 changes: 91 additions & 27 deletions impl/pkg/service/pkarr.go
@@ -28,11 +28,12 @@ const recordSizeLimit = 1000

// PkarrService is the Pkarr service responsible for managing the Pkarr DHT and reading/writing records
type PkarrService struct {
- cfg *config.Config
- db storage.Storage
- dht *dht.DHT
- cache *bigcache.BigCache
- scheduler *dhtint.Scheduler
+ cfg *config.Config
+ db storage.Storage
+ dht *dht.DHT
+ cache *bigcache.BigCache
+ badGetCache *bigcache.BigCache
+ scheduler *dhtint.Scheduler
}

// NewPkarrService returns a new instance of the Pkarr service
@@ -41,7 +42,7 @@ func NewPkarrService(cfg *config.Config, db storage.Storage, d *dht.DHT) (*Pkarr
return nil, ssiutil.LoggingNewError("config is required")
}

- // create and start cache and scheduler
+ // create and start get cache
cacheTTL := time.Duration(cfg.PkarrConfig.CacheTTLSeconds) * time.Second
cacheConfig := bigcache.DefaultConfig(cacheTTL)
cacheConfig.MaxEntrySize = recordSizeLimit
@@ -51,13 +52,24 @@
if err != nil {
return nil, ssiutil.LoggingErrorMsg(err, "failed to instantiate cache")
}

+ // create a new cache for bad gets to prevent spamming the DHT
+ cacheConfig.LifeWindow = 120 * time.Second
+ cacheConfig.CleanWindow = 60 * time.Second
+ badGetCache, err := bigcache.New(context.Background(), cacheConfig)
+ if err != nil {
+ return nil, ssiutil.LoggingErrorMsg(err, "failed to instantiate badGetCache")
+ }

+ // start scheduler for republishing
scheduler := dhtint.NewScheduler()
svc := PkarrService{
- cfg: cfg,
- db: db,
- dht: d,
- cache: cache,
- scheduler: &scheduler,
+ cfg: cfg,
+ db: db,
+ dht: d,
+ cache: cache,
+ badGetCache: badGetCache,
+ scheduler: &scheduler,
}
if err = scheduler.Schedule(cfg.PkarrConfig.RepublishCRON, svc.republish); err != nil {
return nil, ssiutil.LoggingErrorMsg(err, "failed to start republisher")
@@ -70,6 +82,11 @@ func (s *PkarrService) PublishPkarr(ctx context.Context, id string, record pkarr
ctx, span := telemetry.GetTracer().Start(ctx, "PkarrService.PublishPkarr")
defer span.End()

+ // make sure the key is valid
+ if _, err := util.Z32Decode(id); err != nil {
+ return ssiutil.LoggingCtxErrorMsgf(ctx, err, "failed to decode z-base-32 encoded ID: %s", id)
+ }

if err := record.IsValid(); err != nil {
return err
}
@@ -115,6 +132,16 @@ func (s *PkarrService) GetPkarr(ctx context.Context, id string) (*pkarr.Response
ctx, span := telemetry.GetTracer().Start(ctx, "PkarrService.GetPkarr")
defer span.End()

+ // make sure the key is valid
+ if _, err := util.Z32Decode(id); err != nil {
+ return nil, ssiutil.LoggingCtxErrorMsgf(ctx, err, "failed to decode z-base-32 encoded ID: %s", id)
+ }
+
+ // if the key is in the badGetCache, return an error
+ if _, err := s.badGetCache.Get(id); err == nil {
+ return nil, ssiutil.LoggingCtxErrorMsgf(ctx, err, "key [%s] looked up too frequently, please wait a bit before trying again", id)
+ }

// first do a cache lookup
if got, err := s.cache.Get(id); err == nil {
var resp pkarr.Response
@@ -138,7 +165,13 @@

record, err := s.db.ReadRecord(ctx, rawID)
if err != nil || record == nil {
logrus.WithContext(ctx).WithError(err).WithField("record", id).Error("failed to resolve pkarr record from storage")
logrus.WithContext(ctx).WithError(err).WithField("record", id).Error("failed to resolve pkarr record from storage; adding to badGetCache")

// add the key to the badGetCache to prevent spamming the DHT
if err = s.badGetCache.Set(id, []byte{0}); err != nil {
logrus.WithContext(ctx).WithError(err).WithField("record", id).Error("failed to set key in badGetCache")
}

return nil, err
}

@@ -193,67 +226,93 @@ func (s *PkarrService) republish() {
recordCnt, err := s.db.RecordCount(ctx)
if err != nil {
logrus.WithContext(ctx).WithError(err).Error("failed to get record count before republishing")
return
} else {
logrus.WithContext(ctx).WithField("record_count", recordCnt).Info("republishing records")
}

var nextPageToken []byte
var recordsBatch []pkarr.Record
- var seenRecords, batchCnt, successCnt, errCnt int32 = 0, 0, 0, 0
+ var seenRecords, batchCnt, successCnt, errCnt int32 = 0, 1, 0, 0

for {
recordsBatch, nextPageToken, err = s.db.ListRecords(ctx, nextPageToken, 1000)
if err != nil {
logrus.WithContext(ctx).WithError(err).Error("failed to list record(s) for republishing")
return
}
- seenRecords += int32(len(recordsBatch))
- if len(recordsBatch) == 0 {
+ batchSize := len(recordsBatch)
+ seenRecords += int32(batchSize)
+ if batchSize == 0 {
logrus.WithContext(ctx).Info("no records to republish")
return
}

logrus.WithContext(ctx).WithFields(logrus.Fields{
"record_count": len(recordsBatch),
"record_count": batchSize,
"batch_number": batchCnt,
"total_seen": seenRecords,
}).Infof("republishing next batch of records")
}).Infof("republishing batch [%d] of [%d] records", batchCnt, batchSize)
batchCnt++

var wg sync.WaitGroup
- wg.Add(len(recordsBatch))
+ wg.Add(batchSize)

+ var batchErrCnt, batchSuccessCnt int32 = 0, 0
for _, record := range recordsBatch {
- go func(record pkarr.Record) {
+ go func(ctx context.Context, record pkarr.Record) {
defer wg.Done()

recordID := zbase32.EncodeToString(record.Key[:])
logrus.WithContext(ctx).Debugf("republishing record: %s", recordID)

// Create a new context with a timeout of 10 seconds
- putCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ putCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

- if _, err = s.dht.Put(putCtx, record.BEP44()); err != nil {
- logrus.WithContext(ctx).WithError(err).Errorf("failed to republish record: %s", recordID)
- atomic.AddInt32(&errCnt, 1)
+ if _, putErr := s.dht.Put(putCtx, record.BEP44()); putErr != nil {
+ logrus.WithContext(putCtx).WithError(putErr).Errorf("failed to republish record: %s", recordID)
+ atomic.AddInt32(&batchErrCnt, 1)
} else {
- atomic.AddInt32(&successCnt, 1)
+ atomic.AddInt32(&batchSuccessCnt, 1)
}
- }(record)
+ }(ctx, record)
}

+ // Wait for all goroutines in this batch to finish before moving on to the next batch
wg.Wait()

+ // Update the success and error counts
+ atomic.AddInt32(&successCnt, batchSuccessCnt)
+ atomic.AddInt32(&errCnt, batchErrCnt)
+
+ successRate := float64(batchSuccessCnt) / float64(batchSize)

+ logrus.WithContext(ctx).WithFields(logrus.Fields{
+ "batch_number": batchCnt,
+ "success": successCnt,
+ "errors": errCnt,
+ }).Infof("batch [%d] completed with a [%02f] percent success rate", batchCnt, successRate)

+ if successRate < 0.8 {
+ logrus.WithContext(ctx).WithFields(logrus.Fields{
+ "batch_number": batchCnt,
+ "success": successCnt,
+ "errors": errCnt,
+ }).Errorf("batch [%d] failed to meet success rate threshold; exiting republishing early", batchCnt)
+ break
+ }

if nextPageToken == nil {
break
}
}

+ successRate := float64(successCnt) / float64(seenRecords)
logrus.WithContext(ctx).WithFields(logrus.Fields{
"success": seenRecords - errCnt,
"errors": errCnt,
"total": seenRecords,
}).Infof("republishing complete with [%d] batches", batchCnt)
}).Infof("republishing complete with [%d] batches of [%d] total records with an [%02f] percent success rate", batchCnt, seenRecords, successRate*100)
}

// Close closes the Pkarr service gracefully
@@ -269,6 +328,11 @@ func (s *PkarrService) Close() {
logrus.WithError(err).Error("failed to close cache")
}
}
+ if s.badGetCache != nil {
+ if err := s.badGetCache.Close(); err != nil {
+ logrus.WithError(err).Error("failed to close badGetCache")
+ }
+ }
if err := s.db.Close(); err != nil {
logrus.WithError(err).Error("failed to close db")
}
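Two patterns are worth calling out in this file. First, badGetCache is a negative cache: when a lookup misses both the cache and storage, the key is remembered for the 120-second LifeWindow so repeat requests for the same missing key fail fast instead of hitting the DHT again. Second, the republish loop now tracks a per-batch success rate and bails out when it falls below 0.8 (for example, 850 successes in a batch of 1,000 is 0.85 and the loop continues; 700 of 1,000 is 0.70 and it exits early). Below is a self-contained sketch of the negative-cache half; resolveFromDHT is a made-up stand-in for the real lookup.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/allegro/bigcache/v3"
)

var errNotFound = errors.New("record not found")

// resolveFromDHT is a hypothetical stand-in for the real storage/DHT lookup.
func resolveFromDHT(id string) ([]byte, error) { return nil, errNotFound }

func get(badGetCache *bigcache.BigCache, id string) ([]byte, error) {
	// A hit in the negative cache means this key failed recently; bail out
	// before doing any expensive work.
	if _, err := badGetCache.Get(id); err == nil {
		return nil, fmt.Errorf("key [%s] looked up too frequently, please wait before retrying", id)
	}
	val, err := resolveFromDHT(id)
	if err != nil {
		// Remember the miss; the entry expires after the cache's LifeWindow.
		_ = badGetCache.Set(id, []byte{0})
		return nil, err
	}
	return val, nil
}

func main() {
	cfg := bigcache.DefaultConfig(120 * time.Second) // LifeWindow, as in the commit
	cfg.CleanWindow = 60 * time.Second
	badGetCache, _ := bigcache.New(context.Background(), cfg)

	_, err1 := get(badGetCache, "missing-key") // does the lookup, records the miss
	_, err2 := get(badGetCache, "missing-key") // short-circuits on the negative cache
	fmt.Println(err1)
	fmt.Println(err2)
}
```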