From 75f57332c7834bb2630339e9e877d6e9eb08cf0c Mon Sep 17 00:00:00 2001 From: gabe Date: Fri, 12 Apr 2024 12:52:58 -0400 Subject: [PATCH] add test for spam filter --- impl/internal/did/client_test.go | 21 +++++++++++++++++++++ impl/pkg/server/pkarr.go | 2 +- impl/pkg/server/pkarr_test.go | 15 +++++++++++++++ impl/pkg/service/pkarr.go | 4 ++-- 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/impl/internal/did/client_test.go b/impl/internal/did/client_test.go index 42ceecd3..3ad74552 100644 --- a/impl/internal/did/client_test.go +++ b/impl/internal/did/client_test.go @@ -48,6 +48,27 @@ func TestClientInvalidGateway(t *testing.T) { assert.Nil(t, g) } +func TestClientGet(t *testing.T) { + client, err := NewGatewayClient("https://diddht.tbddev.org") + require.NoError(t, err) + require.NotNil(t, client) + + // get the same DID 20 different times and log how long it takes each time + // aggregate the average time to get the DID after the loop + var total time.Duration + for i := 0; i < 20; i++ { + start := time.Now() + _, _, _, err := client.GetDIDDocument("did:dht:i9xkp8ddcbcg8jwq54ox699wuzxyifsqx4jru45zodqu453ksz6y") + require.NoError(t, err) + since := time.Since(start) + t.Logf("time to get DID: %s in round %d", since, i) + total += since + + } + average := total / 20 + t.Logf("average time to get DID: %s", average) +} + func TestInvalidDIDDocument(t *testing.T) { client, err := NewGatewayClient("https://diddht.tbddev.test") require.NoError(t, err) diff --git a/impl/pkg/server/pkarr.go b/impl/pkg/server/pkarr.go index a3ea8323..0b932985 100644 --- a/impl/pkg/server/pkarr.go +++ b/impl/pkg/server/pkarr.go @@ -59,7 +59,7 @@ func (r *PkarrRouter) GetRecord(c *gin.Context) { resp, err := r.service.GetPkarr(c, *id) if err != nil { // TODO(gabe): provide a more maintainable way to handle custom errors - if strings.Contains("spam", err.Error()) { + if strings.Contains(err.Error(), "spam") { LoggingRespondErrMsg(c, fmt.Sprintf("too many requests for bad key %s", *id), http.StatusTooManyRequests) return } diff --git a/impl/pkg/server/pkarr_test.go b/impl/pkg/server/pkarr_test.go index f87bd5c6..fe43655c 100644 --- a/impl/pkg/server/pkarr_test.go +++ b/impl/pkg/server/pkarr_test.go @@ -146,6 +146,21 @@ func TestPkarrRouter(t *testing.T) { pkarrRouter.GetRecord(c) assert.Equal(t, http.StatusNotFound, w.Result().StatusCode, "unexpected %s", w.Result().Status) }) + + t.Run("test get not found spam", func(t *testing.T) { + w := httptest.NewRecorder() + suffix := "cz13drbfxy3ih6xun4mw3cyiexrtfcs9gyp46o4469e93y36zhsy" + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/%s", testServerURL, suffix), nil) + c := newRequestContextWithParams(w, req, map[string]string{IDParam: suffix}) + pkarrRouter.GetRecord(c) + assert.Equal(t, http.StatusNotFound, w.Result().StatusCode, "unexpected %s", w.Result().Status) + + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/%s", testServerURL, suffix), nil) + c = newRequestContextWithParams(w, req, map[string]string{IDParam: suffix}) + pkarrRouter.GetRecord(c) + assert.Equal(t, http.StatusTooManyRequests, w.Result().StatusCode, "unexpected %s", w.Result().Status) + }) } func testPkarrService(t *testing.T) service.PkarrService { diff --git a/impl/pkg/service/pkarr.go b/impl/pkg/service/pkarr.go index 30c9368a..1d677d21 100644 --- a/impl/pkg/service/pkarr.go +++ b/impl/pkg/service/pkarr.go @@ -53,8 +53,8 @@ func NewPkarrService(cfg *config.Config, db storage.Storage, d *dht.DHT) (*Pkarr } // create a new cache for bad gets to prevent spamming the DHT - cacheConfig.LifeWindow = 120 * time.Second - cacheConfig.CleanWindow = 60 * time.Second + cacheConfig.LifeWindow = 60 * time.Second + cacheConfig.CleanWindow = 30 * time.Second badGetCache, err := bigcache.New(context.Background(), cacheConfig) if err != nil { return nil, ssiutil.LoggingErrorMsg(err, "failed to instantiate badGetCache")