diff --git a/go.mod b/go.mod index 52a4d0369f..001f7ccc03 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.4 require ( cloud.google.com/go/storage v1.43.0 github.com/ALTree/bigfloat v0.2.0 + github.com/bits-and-blooms/bloom/v3 v3.7.0 github.com/chaos-mesh/chaos-mesh/api v0.0.0-20240913055630-bfe8736306b4 github.com/cosmos/btcutil v1.0.5 github.com/go-llsqlite/crawshaw v0.5.5 @@ -35,6 +36,7 @@ require ( github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.59.1 github.com/quic-go/quic-go v0.46.0 + github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd github.com/rs/cors v1.11.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 @@ -80,6 +82,7 @@ require ( github.com/anacrolix/sync v0.3.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/bits-and-blooms/bitset v1.10.0 // indirect github.com/c0mm4nd/go-ripemd v0.0.0-20200326052756-bd1759ad7d10 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/cgroups v1.1.0 // indirect diff --git a/go.sum b/go.sum index d2dfc5edef..d0a137beed 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,10 @@ github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.10.0 h1:ePXTeiPEazB5+opbv5fr8umg2R/1NlzgDsyepwsSr88= +github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bloom/v3 v3.7.0 h1:VfknkqV4xI+PsaDIsoHueyxVDZrfvMn56jeWUzvzdls= +github.com/bits-and-blooms/bloom/v3 v3.7.0/go.mod h1:VKlUSvp0lFIYqxJjzdnSsZEw4iHb1kOL2tfHTgyJBHg= github.com/bradfitz/go-smtpd 
v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c h1:FUUopH4brHNO2kJoNN3pV+OBEYmgraLT/KHZrMM69r0= @@ -152,6 +156,8 @@ github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -551,6 +557,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd h1:wW6BtayFoKaaDeIvXRE3SZVPOscSKlYD+X3bB749+zk= +github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd/go.mod h1:ib9zVtNgRKiGuoMyUqqL5aNpk+r+++YlyiVIkclVqPg= github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= @@ -656,6 +664,8 @@ 
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= +github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= +github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.10/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= diff --git a/node/node.go b/node/node.go index 70ea73a48f..50b1c8d98b 100644 --- a/node/node.go +++ b/node/node.go @@ -74,6 +74,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" localmigrations "github.com/spacemeshos/go-spacemesh/sql/localsql/migrations" @@ -1979,6 +1980,15 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { } { warmupLog := app.log.Zap().Named("warmup") + app.log.Info("loading Bloom filters") + _, err := atxs.LoadBloomFilter(app.db, warmupLog) + if err != nil { + return fmt.Errorf("loading ATX Bloom filter: %w", err) + } + _, err = identities.LoadBloomFilter(app.db, warmupLog) + if err != nil { + return fmt.Errorf("loading malicious identity Bloom filter: %w", err) + } app.log.Info("starting cache warmup") applied, err := layers.GetLastApplied(app.db) if err != nil { diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index a94f3afa73..777f9c800e 100644 --- 
a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -2,10 +2,12 @@ package atxs import ( "context" + "errors" "fmt" "time" sqlite "github.com/go-llsqlite/crawshaw" + "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" @@ -15,6 +17,11 @@ import ( const ( CacheKindEpochATXs sql.QueryCacheKind = "epoch-atxs" CacheKindATXBlob sql.QueryCacheKind = "atx-blob" + // Bloom filter size is < 115 MiB while below 100M ATXs. + // TODO: adjust Bloom filter settings after ATX merge & checkpointing. + BloomFilterFalsePositiveRate = 0.01 + BloomFilterMinSize = 100_000_000 + BloomFilterExtraCoef = 1.2 ) // Query to retrieve ATXs. @@ -120,8 +127,28 @@ func GetByEpochAndNodeID( return id, nil } +// LoadBloomFilter initializes and loads the bloom filter for ATXs. +func LoadBloomFilter(db sql.StateDatabase, logger *zap.Logger) (*sql.DBBloomFilter, error) { + bf := sql.NewDBBloomFilter( + "atxs", "select id from atxs", "id", + BloomFilterMinSize, BloomFilterExtraCoef, BloomFilterFalsePositiveRate) + if err := bf.Load(db, logger); err != nil { + return nil, fmt.Errorf("load bloom filter: %w", err) + } + db.AddSet(bf) + return bf, nil +} + // Has checks if an ATX exists by a given ATX ID. +// It tries to do so using Bloom filter first, and falls back to a direct query if the filter is not available. 
func Has(db sql.Executor, id types.ATXID) (bool, error) { + has, err := sql.Contains(db, "atxs", id[:]) + if err == nil { + return has, nil + } else if !errors.Is(err, sql.ErrNoSet) { + return false, fmt.Errorf("check if have id %s: %w", id, err) + } + rows, err := db.Exec("select 1 from atxs where id = ?1;", func(stmt *sql.Statement) { stmt.BindBytes(1, id.Bytes()) @@ -486,7 +513,12 @@ func Add(db sql.Executor, atx *types.ActivationTx, blob types.AtxBlob) error { return fmt.Errorf("insert ATX ID %v: %w", atx.ID(), err) } - return AddBlob(db, atx.ID(), blob.Blob, blob.Version) + if err := AddBlob(db, atx.ID(), blob.Blob, blob.Version); err != nil { + return err + } + + sql.AddToSet(db, "atxs", atx.ID().Bytes()) + return err } func AddBlob(db sql.Executor, id types.ATXID, blob []byte, version types.AtxVersion) error { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 91707aeb52..fa85af87f2 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" "golang.org/x/exp/rand" "github.com/spacemeshos/go-spacemesh/common/types" @@ -30,7 +31,7 @@ func TestMain(m *testing.M) { } func TestGet(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -52,7 +53,7 @@ func TestGet(t *testing.T) { } func TestAll(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) var expected []types.ATXID for i := 0; i < 3; i++ { @@ -69,7 +70,7 @@ func TestAll(t *testing.T) { } func TestHasID(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -92,7 +93,7 @@ func TestHasID(t *testing.T) { } func Test_IdentityExists(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -110,7 +111,7 @@ func 
Test_IdentityExists(t *testing.T) { } func TestGetFirstIDByNodeID(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -143,7 +144,7 @@ func TestGetFirstIDByNodeID(t *testing.T) { } func TestLatestN(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -238,7 +239,7 @@ func TestLatestN(t *testing.T) { } func TestGetByEpochAndNodeID(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -270,7 +271,7 @@ func TestGetByEpochAndNodeID(t *testing.T) { } func TestGetLastIDByNodeID(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -304,7 +305,7 @@ func TestGetLastIDByNodeID(t *testing.T) { } func TestGetIDByEpochAndNodeID(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -348,7 +349,7 @@ func TestGetIDByEpochAndNodeID(t *testing.T) { } func TestGetIDsByEpoch(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ctx := context.Background() sig1, err := signing.NewEdSigner() @@ -458,7 +459,7 @@ func TestGetIDsByEpochCached(t *testing.T) { } func Test_IterateAtxsWithMalfeasance(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -488,7 +489,7 @@ func Test_IterateAtxsWithMalfeasance(t *testing.T) { } func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -519,7 +520,7 @@ func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { func TestVRFNonce(t *testing.T) { // Arrange - db := statesql.InMemory() + db := statesql.InMemoryTest(t) 
sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -552,7 +553,7 @@ func TestVRFNonce(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ctx := context.Background() sig, err := signing.NewEdSigner() @@ -610,7 +611,7 @@ func TestLoadBlob(t *testing.T) { } func TestLoadBlob_DefaultsToV1(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -745,7 +746,7 @@ func TestCachedBlobEviction(t *testing.T) { } func TestCheckpointATX(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ctx := context.Background() sig, err := signing.NewEdSigner() @@ -792,7 +793,7 @@ func TestCheckpointATX(t *testing.T) { } func TestAdd(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) nonExistingATXID := types.ATXID(types.CalcHash32([]byte("0"))) _, err := atxs.Get(db, nonExistingATXID) @@ -973,7 +974,7 @@ func TestGetIDWithMaxHeight(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) var sigs []*signing.EdSigner var ids []types.ATXID filtered := make(map[types.ATXID]struct{}) @@ -1014,7 +1015,7 @@ func TestLatest(t *testing.T) { {"out of order", []uint32{3, 4, 1, 2}, 4}, } { t.Run(tc.desc, func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) for i, epoch := range tc.epochs { full := &types.ActivationTx{ PublishEpoch: types.EpochID(epoch), @@ -1033,7 +1034,7 @@ func TestLatest(t *testing.T) { } func Test_PrevATXCollision(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1086,13 +1087,13 @@ func TestCoinbase(t *testing.T) { t.Parallel() t.Run("not found", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) _, err := atxs.Coinbase(db, types.NodeID{}) 
require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("found", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) atx := newAtx(t, sig, withCoinbase(types.Address{1, 2, 3})) @@ -1103,7 +1104,7 @@ func TestCoinbase(t *testing.T) { }) t.Run("picks last", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) atx1 := newAtx(t, sig, withPublishEpoch(1), withCoinbase(types.Address{1, 2, 3})) @@ -1120,13 +1121,13 @@ func TestUnits(t *testing.T) { t.Parallel() t.Run("ATX not found", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) _, err := atxs.Units(db, types.RandomATXID(), types.RandomNodeID()) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("smesher has no units in ATX", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atxID := types.RandomATXID() require.NoError(t, atxs.SetPost(db, atxID, types.EmptyATXID, 0, types.RandomNodeID(), 10, 0)) _, err := atxs.Units(db, atxID, types.RandomNodeID()) @@ -1134,7 +1135,7 @@ func TestUnits(t *testing.T) { }) t.Run("returns units for given smesher in given ATX", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atxID := types.RandomATXID() units := map[types.NodeID]uint32{ {1, 2, 3}: 10, @@ -1163,12 +1164,12 @@ func Test_AtxWithPrevious(t *testing.T) { prev := types.RandomATXID() t.Run("no atxs", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) _, err := atxs.AtxWithPrevious(db, prev, sig.NodeID()) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("finds other ATX with same previous", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) prev := types.RandomATXID() atx := newAtx(t, sig) @@ -1180,7 +1181,7 @@ func Test_AtxWithPrevious(t 
*testing.T) { require.Equal(t, atx.ID(), id) }) t.Run("finds other ATX with same previous (empty)", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atx := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) @@ -1191,7 +1192,7 @@ func Test_AtxWithPrevious(t *testing.T) { require.Equal(t, atx.ID(), id) }) t.Run("same previous used by 2 IDs in two ATXs", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig2, err := signing.NewEdSigner() require.NoError(t, err) @@ -1221,14 +1222,14 @@ func Test_FindDoublePublish(t *testing.T) { require.NoError(t, err) t.Run("no atxs", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) _, err := atxs.FindDoublePublish(db, types.RandomNodeID(), 0) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("no double publish", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) // one atx atx0 := newAtx(t, sig, withPublishEpoch(1)) @@ -1248,7 +1249,7 @@ func Test_FindDoublePublish(t *testing.T) { }) t.Run("double publish", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atx0 := newAtx(t, sig) require.NoError(t, atxs.Add(db, atx0, types.AtxBlob{})) @@ -1268,7 +1269,7 @@ func Test_FindDoublePublish(t *testing.T) { }) t.Run("double publish different smesher", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atx0Signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -1296,13 +1297,13 @@ func Test_MergeConflict(t *testing.T) { t.Parallel() t.Run("no atxs", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) _, err := atxs.MergeConflict(db, types.RandomATXID(), 0) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("no conflict", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) marriage := 
types.RandomATXID() atx := types.ActivationTx{MarriageATX: &marriage} @@ -1314,7 +1315,7 @@ func Test_MergeConflict(t *testing.T) { }) t.Run("finds conflict", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) marriage := types.RandomATXID() atx0 := types.ActivationTx{MarriageATX: &marriage} @@ -1374,12 +1375,12 @@ func Test_Previous(t *testing.T) { func TestPrevIDByNodeID(t *testing.T) { t.Run("no previous ATXs", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) _, err := atxs.PrevIDByNodeID(db, types.RandomNodeID(), 0) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("filters by epoch", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1403,7 +1404,7 @@ func TestPrevIDByNodeID(t *testing.T) { require.Equal(t, atx2.ID(), prevID) }) t.Run("the previous is merged and ID is not the signer", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) sig, err := signing.NewEdSigner() require.NoError(t, err) id := types.RandomNodeID() @@ -1427,3 +1428,52 @@ func TestPrevIDByNodeID(t *testing.T) { require.Equal(t, atx1.ID(), prevID) }) } + +func Test_BloomFilter(t *testing.T) { + db := statesql.InMemoryTest(t) + + atxList := make([]*types.ActivationTx, 0) + addSome := func() { + for i := 0; i < 3; i++ { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + atx := newAtx(t, sig, withPublishEpoch(types.EpochID(i))) + atxList = append(atxList, atx) + require.NoError(t, atxs.Add(db, atx, types.AtxBlob{})) + } + } + + check := func() { + for _, want := range atxList { + has, err := atxs.Has(db, want.ID()) + require.NoError(t, err) + require.True(t, has) + } + n := db.QueryCount() + for range 5 { + has, err := atxs.Has(db, types.RandomATXID()) + require.NoError(t, err) + require.False(t, has) + } + require.Equal(t, n, db.QueryCount()) + } + + addSome() + bf, err := 
atxs.LoadBloomFilter(db, zaptest.NewLogger(t)) + require.NoError(t, err) + check() + require.Equal(t, sql.BloomStats{ + Loaded: 3, + NumPositive: 3, + NumNegative: 5, + }, bf.Stats()) + + addSome() + check() + require.Equal(t, sql.BloomStats{ + Loaded: 3, + Added: 3, + NumPositive: 9, // 2nd pass rechecks the initial set of IDs + NumNegative: 10, + }, bf.Stats()) +} diff --git a/sql/bloom.go b/sql/bloom.go new file mode 100644 index 0000000000..0923302c70 --- /dev/null +++ b/sql/bloom.go @@ -0,0 +1,179 @@ +package sql + +import ( + "fmt" + "math" + "sync" + "sync/atomic" + + "github.com/bits-and-blooms/bloom/v3" + "go.uber.org/zap" + + "github.com/spacemeshos/go-spacemesh/sql/expr" +) + +// BloomStats represents bloom filter statistics. +type BloomStats struct { + Loaded int + Added int + NumPositive int + NumNegative int +} + +// DBBloomFilter reduces the number of database lookups for the keys that are not in the +// database. +type DBBloomFilter struct { + mtx sync.Mutex + f *bloom.BloomFilter + name string + sel expr.Statement + idCol expr.Expr + minSize int + fp float64 + extraCoef float64 + added atomic.Int64 + loaded atomic.Int64 + positive atomic.Int64 + negative atomic.Int64 +} + +var _ IDSet = &DBBloomFilter{} + +// NewDBBloomFilter creates a new Bloom filter for a database table. +// name identifies the filter, selectExpr is the SQL statement that selects the IDs to include, +// idCol is the column that contains the IDs, minSize and extraCoef control the sizing of the +// filter (see Load), and falsePositiveRate is the desired false positive rate. 
+func NewDBBloomFilter(
+	name, selectExpr, idCol string,
+	minSize int,
+	extraCoef, falsePositiveRate float64,
+) *DBBloomFilter {
+	// Both SQL fragments are parsed eagerly, in the same order as they are passed in.
+	// NOTE(review): the Must* helpers presumably panic on invalid SQL — confirm in sql/expr.
+	bf := new(DBBloomFilter)
+	bf.name = name
+	bf.sel = expr.MustParseStatement(selectExpr)
+	bf.idCol = expr.MustParse(idCol)
+	bf.minSize = minSize
+	bf.extraCoef = extraCoef
+	bf.fp = falsePositiveRate
+	return bf
+}
+
+// Name returns the name the filter was created with.
+func (bf *DBBloomFilter) Name() string {
+	return bf.name
+}
+
+// countSQL renders the query Load runs to estimate how many IDs the filter will hold.
+// It is derived from the configured select statement with the columns replaced by a row count.
+func (bf *DBBloomFilter) countSQL() string {
+	counting := expr.SelectBasedOn(bf.sel).Columns(expr.CountStar())
+	return counting.String()
+}
+
+// loadSQL renders the query Load runs to fetch the IDs: the configured select statement as is.
+func (bf *DBBloomFilter) loadSQL() string {
+	return bf.sel.String()
+}
+
+// hasSQL renders the query inDB runs to look up a single ID. It is derived from the
+// configured select statement, with an "idCol = <bind parameter>" condition combined with
+// whatever WHERE expression the statement already carries.
+func (bf *DBBloomFilter) hasSQL() string {
+	base := expr.SelectBasedOn(bf.sel)
+	existing := expr.WhereExpr(bf.sel)
+	matchID := expr.Op(bf.idCol, expr.EQ, expr.Bind())
+	return base.Where(expr.MaybeAnd(existing, matchID)).String()
+}
+
+// Add adds the specified key to the Bloom filter.
+// Keys passed in before the filter has been loaded are ignored.
+func (bf *DBBloomFilter) Add(id []byte) {
+	bf.mtx.Lock()
+	defer bf.mtx.Unlock()
+	if bf.f == nil {
+		return
+	}
+	bf.f.Add(id)
+	bf.added.Add(1)
+}
+
+// Load populates the Bloom filter from the database.
+func (bf *DBBloomFilter) Load(db Executor, logger *zap.Logger) error { + bf.mtx.Lock() + defer bf.mtx.Unlock() + if bf.f != nil { + return nil + } + logger.Info("estimating Bloom filter size", zap.String("table", bf.name)) + count := 0 + _, err := db.Exec(bf.countSQL(), nil, func(stmt *Statement) bool { + count = stmt.ColumnInt(0) + return true + }) + if err != nil { + return fmt.Errorf("get count of table %s: %w", bf.name, err) + } + size := int(math.Ceil(float64(count) * bf.extraCoef)) + if bf.minSize > 0 && size < bf.minSize { + size = bf.minSize + } + bf.f = bloom.NewWithEstimates(uint(size), bf.fp) + logger.Info("loading Bloom filter", + zap.String("table", bf.name), + zap.Int("count", count), + zap.Int("actualSize", size), + zap.Int("bytes", bf.f.BitSet().BinaryStorageSize()), + zap.Float64("falsePositiveRate", bf.fp)) + var bs []byte + nRows, err := db.Exec(bf.loadSQL(), nil, func(stmt *Statement) bool { + l := stmt.ColumnLen(0) + if cap(bs) < l { + bs = make([]byte, l) + } else { + bs = bs[:l] + } + stmt.ColumnBytes(0, bs) + bf.f.Add(bs) + bf.loaded.Add(1) + return true + }) + if err != nil { + return fmt.Errorf("populate Bloom filter for the table %s: %w", bf.name, err) + } + logger.Info("done loading Bloom filter", zap.String("table", bf.name), zap.Int("rows", nRows)) + return nil +} + +func (bf *DBBloomFilter) mayHave(id []byte) bool { + bf.mtx.Lock() + defer bf.mtx.Unlock() + if bf.f == nil { + return false + } + if bf.f.Test(id) { + bf.positive.Add(1) + return true + } + bf.negative.Add(1) + return false +} + +func (bf *DBBloomFilter) inDB(db Executor, id []byte) (bool, error) { + nRows, err := db.Exec(bf.hasSQL(), func(stmt *Statement) { + stmt.BindBytes(1, id) + }, nil) + if err != nil { + return false, fmt.Errorf("check if ID exists in table %s: %w", bf.name, err) + } + return nRows != 0, nil +} + +// Contains returns true if the ID is in the table. 
+func (bf *DBBloomFilter) Contains(db Executor, id []byte) (bool, error) { + if !bf.mayHave(id) { + // no false negatives in the Bloom filter + return false, nil + } + return bf.inDB(db, id) +} + +// Stats returns Bloom filter statistics. +func (bf *DBBloomFilter) Stats() BloomStats { + return BloomStats{ + Loaded: int(bf.loaded.Load()), + Added: int(bf.added.Load()), + NumPositive: int(bf.positive.Load()), + NumNegative: int(bf.negative.Load()), + } +} diff --git a/sql/bloom_test.go b/sql/bloom_test.go new file mode 100644 index 0000000000..dd603213fe --- /dev/null +++ b/sql/bloom_test.go @@ -0,0 +1,168 @@ +package sql + +import ( + "context" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func randomID() []byte { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + panic(err) + } + return b +} + +func TestDBBloomFilter(t *testing.T) { + const ( + numInsert = 1000 + falsePositiveRate = 0.01 + numChecks = 10000 + maxFalsePositiveCount = int(numChecks * falsePositiveRate * 2) + ) + + db := InMemoryTest(t) + _, err := db.Exec("CREATE TABLE test (id CHAR(32))", nil, nil) + require.NoError(t, err) + ids := make([][]byte, numInsert) + for i := range ids { + ids[i] = randomID() + _, err := db.Exec("INSERT INTO test (id) VALUES (?)", func(st *Statement) { + st.BindBytes(1, ids[i]) + }, nil) + require.NoError(t, err) + } + + c := db.QueryCount() + bf := NewDBBloomFilter("bloomTest", "select id from test", "id", numInsert, 1.5, falsePositiveRate) + require.NoError(t, bf.Load(db, zaptest.NewLogger(t))) + db.AddSet(bf) + require.Equal(t, BloomStats{ + Loaded: numInsert, + }, bf.Stats()) + + c += 2 + require.Equal(t, c, db.QueryCount()) + for _, id := range ids { + has, err := Contains(db, "bloomTest", id) + require.NoError(t, err) + require.True(t, has) + } + c += numInsert + require.Equal(t, c, db.QueryCount()) + require.Equal(t, BloomStats{ + Loaded: numInsert, + NumPositive: numInsert, + }, bf.Stats()) 
+ + check := func(ex Executor) { + for range 10 { + oldStats := bf.Stats() + for range numChecks { + id := randomID() + has, err := Contains(ex, "bloomTest", id) + require.NoError(t, err) + require.False(t, has) + } + count := db.QueryCount() - c + newStats := bf.Stats() + require.Equal(t, count, newStats.NumPositive-oldStats.NumPositive) + require.Equal(t, numChecks-count, newStats.NumNegative-oldStats.NumNegative) + t.Logf("query count: %d, maxFalsePositiveCount: %d", count, maxFalsePositiveCount) + require.GreaterOrEqual(t, maxFalsePositiveCount, count) + c = db.QueryCount() + } + } + check(db) + require.NoError(t, db.WithTx(context.Background(), func(tx Transaction) error { + check(tx) + return nil + })) + + for range 100 { + newID := randomID() + _, err = db.Exec("INSERT INTO test (id) VALUES (?)", func(st *Statement) { + st.BindBytes(1, newID) + }, nil) + require.NoError(t, err) + + c = db.QueryCount() + db.AddToSet("bloomTest", newID) + require.True(t, bf.mayHave(newID)) + has, err := Contains(db, "bloomTest", newID) + require.NoError(t, err) + require.True(t, has) + require.Equal(t, c+1, db.QueryCount()) + } + require.Equal(t, 100, bf.Stats().Added) +} + +func TestDBBloomFilterWhere(t *testing.T) { + const ( + numInsert = 1000 + numSkip = 100 + falsePositiveRate = 0.01 + numChecks = 10000 + maxFalsePositiveCount = int(numChecks * falsePositiveRate * 2) + ) + + db := InMemoryTest(t, WithConnections(10)) + _, err := db.Exec("CREATE TABLE test (id CHAR(32), include int)", nil, nil) + require.NoError(t, err) + ids := make([][]byte, numInsert) + for i := range ids { + ids[i] = randomID() + _, err := db.Exec("INSERT INTO test (id, include) VALUES (?, 1)", func(st *Statement) { + st.BindBytes(1, ids[i]) + }, nil) + require.NoError(t, err) + } + skip := make([][]byte, numSkip) + for i := range skip { + skip[i] = randomID() + _, err := db.Exec("INSERT INTO test (id, include) VALUES (?, 0)", func(st *Statement) { + st.BindBytes(1, skip[i]) + }, nil) + 
require.NoError(t, err) + } + + c := db.QueryCount() + bf := NewDBBloomFilter("bloomTest", "select id from test where include = 1", "id", + numInsert, 1.5, falsePositiveRate) + require.NoError(t, bf.Load(db, zaptest.NewLogger(t))) + db.AddSet(bf) + + c += 2 + require.Equal(t, c, db.QueryCount()) + for _, id := range ids { + has, err := Contains(db, "bloomTest", id) + require.NoError(t, err) + require.True(t, has) + } + c += numInsert + require.Equal(t, c, db.QueryCount()) + + for range 5 { + for range numChecks { + id := randomID() + has, err := Contains(db, "bloomTest", id) + require.NoError(t, err) + require.False(t, has) + } + count := db.QueryCount() - c + t.Logf("query count: %d, maxFalsePositiveCount: %d", count, maxFalsePositiveCount) + require.GreaterOrEqual(t, maxFalsePositiveCount, count) + c = db.QueryCount() + } + + for _, id := range skip { + has, err := Contains(db, "bloomTest", id) + require.NoError(t, err) + require.False(t, has, "skipped key included in filter") + } +} diff --git a/sql/database.go b/sql/database.go index a647086f4a..01b5ca716c 100644 --- a/sql/database.go +++ b/sql/database.go @@ -11,6 +11,7 @@ import ( "strings" "sync" "sync/atomic" + "testing" "time" sqlite "github.com/go-llsqlite/crawshaw" @@ -223,6 +224,14 @@ func OpenInMemory(opts ...Opt) (*sqliteDatabase, error) { return Open("file::memory:?mode=memory", opts...) } +// InMemoryTest returns an in-mem database for testing and ensures database is closed +// during `tb.Cleanup`. +func InMemoryTest(tb testing.TB, opts ...Opt) Database { + db := InMemory(append(opts, WithNoCheckSchemaDrift())...) + tb.Cleanup(func() { db.Close() }) + return db +} + // InMemory creates an in-memory database for testing and panics if // there's an error. func InMemory(opts ...Opt) *sqliteDatabase { @@ -550,10 +559,23 @@ func handleIncompleteCopyMigration(config *conf) error { // PushIntercept. The query will fail if Interceptor returns an error. 
type Interceptor func(query string) error +// IDSet verifies if the particular ID exists in the database. +type IDSet interface { + Name() string + Add(id []byte) + Contains(db Executor, id []byte) (bool, error) +} + +type IDSetCollection interface { + Contains(name string, id []byte) (bool, error) + AddToSet(name string, id []byte) +} + // Database represents a database. type Database interface { Executor QueryCache + IDSetCollection Close() error QueryCount() int QueryCache() QueryCache @@ -563,6 +585,7 @@ type Database interface { WithTxImmediate(ctx context.Context, exec func(Transaction) error) error Intercept(key string, fn Interceptor) RemoveInterceptor(key string) + AddSet(c IDSet) } // Transaction represents a transaction. @@ -584,6 +607,9 @@ type sqliteDatabase struct { interceptMtx sync.Mutex interceptors map[string]Interceptor + + setMtx sync.Mutex + sets map[string]IDSet } var _ Database = &sqliteDatabase{} @@ -981,6 +1007,43 @@ func (db *sqliteDatabase) QueryCache() QueryCache { return db.queryCache } +func (db *sqliteDatabase) AddSet(c IDSet) { + name := c.Name() + db.setMtx.Lock() + defer db.setMtx.Unlock() + _, found := db.sets[name] + if found { + panic("set already exists: " + name) + } + if db.sets == nil { + db.sets = make(map[string]IDSet) + } + db.sets[name] = c +} + +func (db *sqliteDatabase) getSet(name string) IDSet { + db.setMtx.Lock() + defer db.setMtx.Unlock() + return db.sets[name] +} + +func (db *sqliteDatabase) AddToSet(name string, id []byte) { + set := db.getSet(name) + if set != nil { + set.Add(id) + } +} + +var ErrNoSet = errors.New("no such set") + +func (db *sqliteDatabase) Contains(name string, id []byte) (bool, error) { + set := db.getSet(name) + if set == nil { + return false, ErrNoSet + } + return set.Contains(db, id) +} + func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (int, error) { stmt, err := conn.Prepare(query) if err != nil { @@ -1023,6 +1086,11 @@ type sqliteTx struct { err error } +var 
( + _ Transaction = &sqliteTx{} + _ IDSetCollection = &sqliteTx{} +) + func (tx *sqliteTx) begin(initstmt string) error { stmt := tx.conn.Prep(initstmt) _, err := stmt.Step() @@ -1070,6 +1138,36 @@ func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, e return exec(tx.conn, query, encoder, decoder) } +// AddToSet implements IDSetCollection. +func (tx *sqliteTx) AddToSet(name string, id []byte) { + tx.db.AddToSet(name, id) +} + +// Contains implements IDSetCollection. +func (tx *sqliteTx) Contains(name string, id []byte) (bool, error) { + set := tx.db.getSet(name) + if set == nil { + return false, ErrNoSet + } + return set.Contains(tx, id) +} + +// AddToSet registers the ID with the specified set to ensure Contains returns true for +// it. +func AddToSet(db Executor, name string, id []byte) { + if set, ok := db.(IDSetCollection); ok { + set.AddToSet(name, id) + } +} + +// Contains verifies that the ID exists within the specified set. +func Contains(db Executor, name string, id []byte) (bool, error) { + if set, ok := db.(IDSetCollection); ok { + return set.Contains(name, id) + } + return false, ErrNoSet +} + func mapSqliteError(err error) error { switch sqlite.ErrCode(err) { case sqlite.SQLITE_CONSTRAINT_PRIMARYKEY, sqlite.SQLITE_CONSTRAINT_UNIQUE: diff --git a/sql/database_test.go b/sql/database_test.go index 0104c08813..75687c6b95 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -612,3 +612,14 @@ func TestExclusive(t *testing.T) { }) } } + +func TestNoSet(t *testing.T) { + db := InMemoryTest(t) + AddToSet(db, "noSuchSet", []byte{42}) + _, err := Contains(db, "noSuchSet", []byte{42}) + require.ErrorIs(t, err, ErrNoSet) + wrapped := struct{ Executor }{db} + AddToSet(wrapped, "noSuchSet", []byte{42}) + _, err = Contains(wrapped, "noSuchSet", []byte{42}) + require.ErrorIs(t, err, ErrNoSet) +} diff --git a/sql/expr/expr.go b/sql/expr/expr.go new file mode 100644 index 0000000000..862891a5db --- /dev/null +++ b/sql/expr/expr.go @@ 
-0,0 +1,212 @@ +// Package expr provides a simple SQL expression parser and builder. +// It wraps the rqlite/sql package and provides a more convenient API that contains only +// what's needed for the go-spacemesh codebase. +package expr + +import ( + "strings" + + rsql "github.com/rqlite/sql" +) + +// SQL operations +const ( + NE = rsql.NE // != + EQ = rsql.EQ // = + LE = rsql.LE // <= + LT = rsql.LT // < + GT = rsql.GT // > + GE = rsql.GE // >= + BITAND = rsql.BITAND // & + BITOR = rsql.BITOR // | + BITNOT = rsql.BITNOT // ! + LSHIFT = rsql.LSHIFT // << + RSHIFT = rsql.RSHIFT // >> + PLUS = rsql.PLUS // + + MINUS = rsql.MINUS // - + STAR = rsql.STAR // * + SLASH = rsql.SLASH // / + REM = rsql.REM // % + CONCAT = rsql.CONCAT // || + DOT = rsql.DOT // . + AND = rsql.AND + OR = rsql.OR + NOT = rsql.NOT +) + +// Expr represents a parsed SQL expression. +type Expr = rsql.Expr + +// Statement represents a parsed SQL statement. +type Statement = rsql.Statement + +// MustParse parses an SQL expression and panics if there's an error. +func MustParse(s string) rsql.Expr { + expr, err := rsql.ParseExprString(s) + if err != nil { + panic("error parsing SQL expression: " + err.Error()) + } + return expr +} + +// MustParseStatement parses an SQL statement and panics if there's an error. +func MustParseStatement(s string) rsql.Statement { + st, err := rsql.NewParser(strings.NewReader(s)).ParseStatement() + if err != nil { + panic("error parsing SQL statement: " + err.Error()) + } + return st +} + +// MaybeAnd joins together several SQL expressions with AND, ignoring any nil exprs. +// If no non-nil expressions are passed, nil is returned. +// If a single non-nil expression is passed, that single expression is returned.
+// Otherwise, the expressions are joined together with ANDs: +// a AND b AND c AND d +func MaybeAnd(exprs ...Expr) Expr { + var r Expr + for _, expr := range exprs { + switch { + case expr == nil: + case r == nil: + r = expr + default: + r = Op(r, AND, expr) + } + } + return r +} + +// Ident constructs SQL identifier expression for the identifier with the specified name. +func Ident(name string) *rsql.Ident { + return &rsql.Ident{Name: name} +} + +// Number constructs a number literal. +func Number(value string) *rsql.NumberLit { + return &rsql.NumberLit{Value: value} +} + +// TableSource constructs a Source clause for SELECT statement that corresponds to +// selecting from a single table with the specified name. +func TableSource(name string) rsql.Source { + return &rsql.QualifiedTableName{Name: Ident(name)} +} + +// Op constructs a binary expression such as x + y or x < y. +func Op(x Expr, op rsql.Token, y Expr) Expr { + return &rsql.BinaryExpr{ + X: x, + Op: op, + Y: y, + } +} + +// Bind constructs the unnamed bind expression (?). +func Bind() Expr { + return &rsql.BindExpr{Name: "?"} +} + +// Between constructs BETWEEN expression: x BETWEEN a AND b. +func Between(x, a, b Expr) Expr { + return Op(x, rsql.BETWEEN, &rsql.Range{X: a, Y: b}) +} + +// Call constructs a call expression with specified arguments such as max(x). +func Call(name string, args ...Expr) Expr { + return &rsql.Call{Name: Ident(name), Args: args} +} + +// CountStar returns a COUNT(*) expression. +func CountStar() Expr { + return &rsql.Call{Name: Ident("count"), Star: rsql.Pos{Offset: 1}} +} + +// Asc constructs an ascending ORDER BY term. +func Asc(expr Expr) *rsql.OrderingTerm { + return &rsql.OrderingTerm{X: expr} +} + +// Desc constructs a descending ORDER BY term. +func Desc(expr Expr) *rsql.OrderingTerm { + return &rsql.OrderingTerm{X: expr, Desc: rsql.Pos{Offset: 1}} +} + +// SelectBuilder is used to construct a SELECT statement.
+type SelectBuilder struct { + st *rsql.SelectStatement +} + +// Select returns a SELECT statement builder. +func Select(columns ...any) SelectBuilder { + sb := SelectBuilder{st: &rsql.SelectStatement{}} + return sb.Columns(columns...) +} + +// SelectBasedOn returns a SELECT statement builder based on the specified SELECT statement. +// The statement must be a SELECT statement, otherwise SelectBasedOn panics. +// The builder methods can be used to alter the statement. +func SelectBasedOn(st Statement) SelectBuilder { + st = rsql.CloneStatement(st) + return SelectBuilder{st: st.(*rsql.SelectStatement)} +} + +// Get returns the underlying SELECT statement. +func (sb SelectBuilder) Get() *rsql.SelectStatement { + return sb.st +} + +// String returns the underlying SELECT statement as a string. +func (sb SelectBuilder) String() string { + return sb.st.String() +} + +// Columns sets columns in the SELECT statement. +func (sb SelectBuilder) Columns(columns ...any) SelectBuilder { + sb.st.Columns = make([]*rsql.ResultColumn, len(columns)) + for n, column := range columns { + switch c := column.(type) { + case *rsql.ResultColumn: + sb.st.Columns[n] = c + case Expr: + sb.st.Columns[n] = &rsql.ResultColumn{Expr: c} + default: + panic("unexpected column type") + } + } + return sb +} + +// From adds FROM clause to the SELECT statement. +func (sb SelectBuilder) From(s rsql.Source) SelectBuilder { + sb.st.Source = s + return sb +} + +// Where adds WHERE clause to the SELECT statement. +func (sb SelectBuilder) Where(s Expr) SelectBuilder { + sb.st.WhereExpr = s + return sb +} + +// OrderBy adds ORDER BY clause to the SELECT statement. +func (sb SelectBuilder) OrderBy(terms ...*rsql.OrderingTerm) SelectBuilder { + sb.st.OrderingTerms = terms + return sb +} + +// Limit adds LIMIT clause to the SELECT statement. +func (sb SelectBuilder) Limit(limit Expr) SelectBuilder { + sb.st.LimitExpr = limit + return sb +} + +// ColumnExpr returns nth column expression from the SELECT statement.
+func ColumnExpr(st Statement, n int) Expr { + return st.(*rsql.SelectStatement).Columns[n].Expr +} + +// WhereExpr returns WHERE expression from the SELECT statement. +func WhereExpr(st Statement) Expr { + return st.(*rsql.SelectStatement).WhereExpr +} diff --git a/sql/expr/expr_test.go b/sql/expr/expr_test.go new file mode 100644 index 0000000000..8a87bc8395 --- /dev/null +++ b/sql/expr/expr_test.go @@ -0,0 +1,137 @@ +package expr + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExpr(t *testing.T) { + for _, tc := range []struct { + Expr Expr + Expected string + }{ + { + Expr: MustParse("a = ? OR x < 10"), + Expected: `"a" = ? OR "x" < 10`, + }, + { + Expr: Number("1"), + Expected: `1`, + }, + { + Expr: CountStar(), + Expected: `count(*)`, + }, + { + Expr: Op(Ident("x"), EQ, Ident("y")), + Expected: `"x" = "y"`, + }, + { + Expr: MaybeAnd(Op(Ident("x"), EQ, Ident("y"))), + Expected: `"x" = "y"`, + }, + { + Expr: MaybeAnd(Op(Ident("x"), EQ, Ident("y")), nil, nil), + Expected: `"x" = "y"`, + }, + { + Expr: MaybeAnd(Op(Ident("x"), EQ, Ident("y")), + Op(Ident("a"), EQ, Bind())), + Expected: `"x" = "y" AND "a" = ?`, + }, + { + Expr: MaybeAnd(Op(Ident("x"), EQ, Ident("y")), + nil, + Op(Ident("a"), EQ, Bind())), + Expected: `"x" = "y" AND "a" = ?`, + }, + { + Expr: MaybeAnd(), + Expected: "", + }, + { + Expr: Between(Ident("x"), Ident("y"), Bind()), + Expected: `"x" BETWEEN "y" AND ?`, + }, + { + Expr: Call("max", Ident("x")), + Expected: `max("x")`, + }, + { + Expr: MustParse("a.id"), + Expected: `"a"."id"`, + }, + } { + if tc.Expected == "" { + require.Nil(t, tc.Expr) + } else { + require.Equal(t, tc.Expected, tc.Expr.String()) + require.Equal(t, tc.Expected, MustParse(tc.Expected).String()) + } + } +} + +func TestStatement(t *testing.T) { + for _, tc := range []struct { + Statement SelectBuilder + Expected string + Columns []string + }{ + { + Statement: Select(Number("1")), + Expected: `SELECT 1`, + Columns: []string{"1"}, + }, + { + 
Statement: Select(Call("max", Ident("n"))).From(TableSource("mytable")), + Expected: `SELECT max("n") FROM "mytable"`, + Columns: []string{`max("n")`}, + }, + { + Statement: Select(Ident("id"), Ident("n")). + From(TableSource("mytable")). + Where(Op(Ident("n"), GE, Bind())). + OrderBy(Asc(Ident("n"))). + Limit(Bind()), + Expected: `SELECT "id", "n" FROM "mytable" WHERE "n" >= ? ORDER BY "n" LIMIT ?`, + Columns: []string{`"id"`, `"n"`}, + }, + { + Statement: Select(Ident("id")). + From(TableSource("mytable")). + OrderBy(Desc(Ident("id"))). + Limit(Number("10")), + Expected: `SELECT "id" FROM "mytable" ORDER BY "id" DESC LIMIT 10`, + Columns: []string{`"id"`}, + }, + { + Statement: Select(CountStar()).From(TableSource("mytable")), + Expected: `SELECT count(*) FROM "mytable"`, + Columns: []string{`count(*)`}, + }, + { + Statement: SelectBasedOn( + MustParseStatement("select a.id from a left join b on a.x = b.x")). + Where(Op(Ident("id"), EQ, Bind())), + Expected: `SELECT "a"."id" FROM "a" LEFT JOIN "b" ON "a"."x" = "b"."x" WHERE "id" = ?`, + Columns: []string{`"a"."id"`}, + }, + { + Statement: SelectBasedOn( + MustParseStatement("select a.id from a inner join b on a.x = b.x")). + Columns(CountStar()). 
+ Where(Op(Ident("id"), EQ, Bind())), + Expected: `SELECT count(*) FROM "a" INNER JOIN "b" ON "a"."x" = "b"."x" WHERE "id" = ?`, + Columns: []string{`count(*)`}, + }, + } { + require.Equal(t, tc.Expected, tc.Statement.String()) + st := tc.Statement.Get() + require.Equal(t, tc.Expected, st.String()) + for n, col := range tc.Columns { + require.Equal(t, col, ColumnExpr(st, n).String()) + } + require.Equal(t, tc.Expected, MustParseStatement(tc.Expected).String()) + } +} diff --git a/sql/identities/identities.go b/sql/identities/identities.go index 257bec03ee..941bccd81c 100644 --- a/sql/identities/identities.go +++ b/sql/identities/identities.go @@ -2,15 +2,40 @@ package identities import ( "context" + "errors" "fmt" "time" sqlite "github.com/go-llsqlite/crawshaw" + "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" ) +const ( + // Bloom filter size is < 234 KiB while below 100k identities + BloomFilterFalsePositiveRate = 0.0001 + BloomFilterMinSize = 100000 + BloomFilterExtraCoef = 1.2 +) + +// LoadBloomFilter initializes and loads the bloom filter for malicious identities. +func LoadBloomFilter(db sql.StateDatabase, logger *zap.Logger) (*sql.DBBloomFilter, error) { + bf := sql.NewDBBloomFilter( + "malicious", + `SELECT i1.pubkey FROM identities i1 + LEFT JOIN identities i2 ON i1.marriage_atx = i2.marriage_atx + WHERE i1.proof IS NOT NULL OR i2.proof IS NOT NULL`, + "i1.pubkey", + BloomFilterMinSize, BloomFilterExtraCoef, BloomFilterFalsePositiveRate) + if err := bf.Load(db, logger); err != nil { + return nil, fmt.Errorf("load bloom filter: %w", err) + } + db.AddSet(bf) + return bf, nil +} + // SetMalicious records identity as malicious.
func SetMalicious(db sql.Executor, nodeID types.NodeID, proof []byte, received time.Time) error { _, err := db.Exec(`insert into identities (pubkey, proof, received) @@ -26,11 +51,25 @@ func SetMalicious(db sql.Executor, nodeID types.NodeID, proof []byte, received t if err != nil { return fmt.Errorf("set malicious %v: %w", nodeID, err) } + ids, err := EquivocationSet(db, nodeID) + if err != nil { + return fmt.Errorf("get equivocation set for %v: %w", nodeID, err) + } + for _, id := range ids { + sql.AddToSet(db, "malicious", id[:]) + } return nil } // IsMalicious returns true if identity is known to be malicious. func IsMalicious(db sql.Executor, nodeID types.NodeID) (bool, error) { + has, err := sql.Contains(db, "malicious", nodeID[:]) + if err == nil { + return has, nil + } else if !errors.Is(err, sql.ErrNoSet) { + return false, fmt.Errorf("check if node %s is malicious: %w", nodeID, err) + } + rows, err := db.Exec(` SELECT 1 FROM identities WHERE ( @@ -161,7 +200,15 @@ func Marriage(db sql.Executor, id types.NodeID) (*MarriageData, error) { // Set marriage inserts marriage data for given identity. // If identity doesn't exist - create it. 
func SetMarriage(db sql.Executor, id types.NodeID, m *MarriageData) error { - _, err := db.Exec(` + isMalicious1, err := IsMalicious(db, id) + if err != nil { + return fmt.Errorf("checking if the node is malicious: %w", err) + } + isMalicious2, err := IsMalicious(db, m.Target) + if err != nil { + return fmt.Errorf("checking if the target is malicious: %w", err) + } + _, err = db.Exec(` INSERT INTO identities (pubkey, marriage_atx, marriage_idx, marriage_target, marriage_signature) values (?1, ?2, ?3, ?4, ?5) ON CONFLICT(pubkey) DO UPDATE SET @@ -181,6 +228,10 @@ func SetMarriage(db sql.Executor, id types.NodeID, m *MarriageData) error { if err != nil { return fmt.Errorf("setting marriage %v: %w", id, err) } + if isMalicious1 || isMalicious2 { + sql.AddToSet(db, "malicious", id[:]) + sql.AddToSet(db, "malicious", m.Target[:]) + } return nil } diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 555a9be270..8149cdce7b 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" @@ -16,7 +17,7 @@ import ( ) func TestMalicious(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) nodeID := types.NodeID{1, 1, 1, 1} mal, err := identities.IsMalicious(db, nodeID) @@ -59,7 +60,7 @@ func TestMalicious(t *testing.T) { } func Test_GetMalicious(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) got, err := identities.GetMalicious(db) require.NoError(t, err) require.Nil(t, got) @@ -77,7 +78,7 @@ func Test_GetMalicious(t *testing.T) { } func TestLoadMalfeasanceBlob(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ctx := context.Background() nid1 := types.RandomNodeID() @@ -122,7 +123,7 @@ func TestMarriageATX(t *testing.T) { t.Parallel() t.Run("not 
married", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) id := types.RandomNodeID() _, err := identities.MarriageATX(db, id) @@ -130,7 +131,7 @@ func TestMarriageATX(t *testing.T) { }) t.Run("married", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) id := types.RandomNodeID() marriage := identities.MarriageData{ @@ -149,7 +150,7 @@ func TestMarriageATX(t *testing.T) { func TestMarriage(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) id := types.RandomNodeID() marriage := identities.MarriageData{ @@ -168,7 +169,7 @@ func TestEquivocationSet(t *testing.T) { t.Parallel() t.Run("equivocation set of married IDs", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atx := types.RandomATXID() ids := []types.NodeID{ @@ -195,7 +196,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("equivocation set for unmarried ID contains itself only", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) id := types.RandomNodeID() set, err := identities.EquivocationSet(db, id) require.NoError(t, err) @@ -203,7 +204,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("can't escape the marriage", func(t *testing.T) { t.Parallel() - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), @@ -236,7 +237,7 @@ func TestEquivocationSet(t *testing.T) { } }) t.Run("married doesn't become malicious immediately", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) atx := types.RandomATXID() id := types.RandomNodeID() require.NoError(t, identities.SetMarriage(db, id, &identities.MarriageData{ATX: atx})) @@ -256,7 +257,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("all IDs in equivocation set are malicious if one is", func(t *testing.T) { t.Parallel() - db := 
statesql.InMemory() + db := statesql.InMemoryTest(t) atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), @@ -280,7 +281,7 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { t.Parallel() t.Run("married IDs", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) ids := []types.NodeID{ types.RandomNodeID(), types.RandomNodeID(), @@ -296,9 +297,92 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { require.Equal(t, ids, set) }) t.Run("empty set", func(t *testing.T) { - db := statesql.InMemory() + db := statesql.InMemoryTest(t) set, err := identities.EquivocationSetByMarriageATX(db, types.RandomATXID()) require.NoError(t, err) require.Empty(t, set) }) } + +func Test_BloomFilter(t *testing.T) { + t.Run("not married", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + got, err := identities.GetMalicious(db) + require.NoError(t, err) + require.Nil(t, got) + + const numBad = 11 + bad := make([]types.NodeID, 0, numBad*2) + addSome := func() { + for i := 0; i < numBad; i++ { + nid := types.NodeID{byte(i + 1)} + bad = append(bad, nid) + require.NoError(t, identities.SetMalicious( + db, nid, types.RandomBytes(11), time.Now().Local())) + } + } + + check := func() { + for _, nodeID := range bad { + has, err := identities.IsMalicious(db, nodeID) + require.NoError(t, err) + require.True(t, has) + } + n := db.QueryCount() + for range 5 { + has, err := identities.IsMalicious(db, types.RandomNodeID()) + require.NoError(t, err) + require.False(t, has) + } + require.Equal(t, n, db.QueryCount()) + } + + addSome() + + bf, err := identities.LoadBloomFilter(db, zaptest.NewLogger(t)) + require.NoError(t, err) + check() + require.Equal(t, sql.BloomStats{ + Loaded: numBad, + NumPositive: numBad, + NumNegative: 5, + }, bf.Stats()) + + addSome() + check() + require.Equal(t, sql.BloomStats{ + Loaded: numBad, + Added: numBad, + NumPositive: numBad * 3, // 2nd pass rechecks the initial set of IDs + NumNegative: 10, + }, 
bf.Stats()) + }) + + t.Run("married", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + _, err := identities.LoadBloomFilter(db, zaptest.NewLogger(t)) + require.NoError(t, err) + atx := types.RandomATXID() + ids := []types.NodeID{ + types.RandomNodeID(), + types.RandomNodeID(), + } + for i, id := range ids { + require.NoError(t, identities.SetMarriage( + db, id, &identities.MarriageData{ATX: atx, Index: i})) + } + + // Each member of the equivocation set needs to be added to the Bloom + // filter, as it has no false negatives and if an ID is absent from the + // filter, it's considered not to be malicious. + require.NoError(t, identities.SetMalicious(db, ids[0], []byte("proof"), time.Now())) + + for _, id := range ids { + malicious, err := identities.IsMalicious(db, id) + require.NoError(t, err) + require.True(t, malicious) + } + }) +}