Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: type parametrized header #22

Merged
merged 3 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fraudserv/requester.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const (
readDeadline = time.Minute
)

func (f *ProofService) requestProofs(
func (f *ProofService[H]) requestProofs(
ctx context.Context,
id protocol.ID,
pid peer.ID,
Expand Down
53 changes: 29 additions & 24 deletions fraudserv/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"

"github.com/celestiaorg/go-header"

"github.com/celestiaorg/go-fraud"
)

Expand All @@ -32,7 +34,7 @@ const fraudRequests = 5

// ProofService is responsible for validating and propagating Fraud Proofs.
// It implements the Service interface.
type ProofService struct {
type ProofService[H header.Header[H]] struct {
networkID string

ctx context.Context
Expand All @@ -45,28 +47,31 @@ type ProofService struct {
stores map[fraud.ProofType]datastore.Datastore

verifiersLk sync.RWMutex
verifiers map[fraud.ProofType]fraud.Verifier
verifiers map[fraud.ProofType]fraud.Verifier[H]

pubsub *pubsub.PubSub
host host.Host
getter fraud.HeaderFetcher
getter fraud.HeaderFetcher[H]
unmarshal fraud.ProofUnmarshaler[H]
ds datastore.Datastore
syncerEnabled bool
}

func NewProofService(
func NewProofService[H header.Header[H]](
p *pubsub.PubSub,
host host.Host,
getter fraud.HeaderFetcher,
getter fraud.HeaderFetcher[H],
unmarshal fraud.ProofUnmarshaler[H],
ds datastore.Datastore,
syncerEnabled bool,
networkID string,
) *ProofService {
return &ProofService{
) *ProofService[H] {
return &ProofService[H]{
pubsub: p,
host: host,
getter: getter,
verifiers: make(map[fraud.ProofType]fraud.Verifier),
unmarshal: unmarshal,
verifiers: make(map[fraud.ProofType]fraud.Verifier[H]),
topics: make(map[fraud.ProofType]*pubsub.Topic),
stores: make(map[fraud.ProofType]datastore.Datastore),
ds: ds,
Expand All @@ -75,9 +80,9 @@ func NewProofService(
}
}

// registerProofTopics registers proofTypes as pubsub topics to be joined.
func (f *ProofService) registerProofTopics(proofTypes ...fraud.ProofType) error {
for _, proofType := range proofTypes {
// registerProofTopics registers as pubsub topics to be joined.
func (f *ProofService[H]) registerProofTopics() error {
for _, proofType := range f.unmarshal.List() {
t, err := join(f.pubsub, proofType, f.networkID, f.processIncoming)
if err != nil {
return err
Expand All @@ -91,9 +96,9 @@ func (f *ProofService) registerProofTopics(proofTypes ...fraud.ProofType) error

// Start joins fraud proofs topics, sets the stream handler for fraudProtocolID and starts syncing
// if syncer is enabled.
func (f *ProofService) Start(context.Context) error {
func (f *ProofService[H]) Start(context.Context) error {
f.ctx, f.cancel = context.WithCancel(context.Background())
if err := f.registerProofTopics(fraud.Registered()...); err != nil {
if err := f.registerProofTopics(); err != nil {
return err
}
id := protocolID(f.networkID)
Expand All @@ -107,7 +112,7 @@ func (f *ProofService) Start(context.Context) error {
}

// Stop removes the stream handler and cancels the underlying ProofService
func (f *ProofService) Stop(context.Context) (err error) {
func (f *ProofService[H]) Stop(context.Context) (err error) {
f.host.RemoveStreamHandler(protocolID(f.networkID))
f.topicsLk.Lock()
for tp, topic := range f.topics {
Expand All @@ -119,7 +124,7 @@ func (f *ProofService) Stop(context.Context) (err error) {
return
}

func (f *ProofService) Subscribe(proofType fraud.ProofType) (_ fraud.Subscription, err error) {
func (f *ProofService[H]) Subscribe(proofType fraud.ProofType) (_ fraud.Subscription[H], err error) {
f.topicsLk.Lock()
defer f.topicsLk.Unlock()
t, ok := f.topics[proofType]
Expand All @@ -130,10 +135,10 @@ func (f *ProofService) Subscribe(proofType fraud.ProofType) (_ fraud.Subscriptio
if err != nil {
return nil, err
}
return &subscription{subs}, nil
return &subscription[H]{subs}, nil
}

func (f *ProofService) Broadcast(ctx context.Context, p fraud.Proof) error {
func (f *ProofService[H]) Broadcast(ctx context.Context, p fraud.Proof[H]) error {
bin, err := p.MarshalBinary()
if err != nil {
return err
Expand All @@ -147,7 +152,7 @@ func (f *ProofService) Broadcast(ctx context.Context, p fraud.Proof) error {
return t.Publish(ctx, bin)
}

func (f *ProofService) AddVerifier(proofType fraud.ProofType, verifier fraud.Verifier) error {
func (f *ProofService[H]) AddVerifier(proofType fraud.ProofType, verifier fraud.Verifier[H]) error {
f.verifiersLk.Lock()
defer f.verifiersLk.Unlock()
if _, ok := f.verifiers[proofType]; ok {
Expand All @@ -158,7 +163,7 @@ func (f *ProofService) AddVerifier(proofType fraud.ProofType, verifier fraud.Ver
}

// processIncoming encompasses the logic for validating fraud proofs.
func (f *ProofService) processIncoming(
func (f *ProofService[H]) processIncoming(
ctx context.Context,
proofType fraud.ProofType,
from peer.ID,
Expand All @@ -171,7 +176,7 @@ func (f *ProofService) processIncoming(

// unmarshal message to the Proof.
// Peer will be added to black list if unmarshalling fails.
proof, err := fraud.Unmarshal(proofType, msg.Data)
proof, err := f.unmarshal.Unmarshal(proofType, msg.Data)
if err != nil {
log.Errorw("unmarshalling failed", "err", err)
if !errors.Is(err, &fraud.ErrNoUnmarshaler{}) {
Expand Down Expand Up @@ -246,7 +251,7 @@ func (f *ProofService) processIncoming(
return pubsub.ValidationAccept
}

func (f *ProofService) Get(ctx context.Context, proofType fraud.ProofType) ([]fraud.Proof, error) {
func (f *ProofService[H]) Get(ctx context.Context, proofType fraud.ProofType) ([]fraud.Proof[H], error) {
f.storesLk.Lock()
store, ok := f.stores[proofType]
if !ok {
Expand All @@ -255,11 +260,11 @@ func (f *ProofService) Get(ctx context.Context, proofType fraud.ProofType) ([]fr
}
f.storesLk.Unlock()

return getAll(ctx, store, proofType)
return getAll(ctx, store, proofType, f.unmarshal)
}

// put adds a fraud proof to the local storage.
func (f *ProofService) put(ctx context.Context, proofType fraud.ProofType, hash string, data []byte) error {
func (f *ProofService[H]) put(ctx context.Context, proofType fraud.ProofType, hash string, data []byte) error {
f.storesLk.Lock()
store, ok := f.stores[proofType]
if !ok {
Expand All @@ -271,7 +276,7 @@ func (f *ProofService) put(ctx context.Context, proofType fraud.ProofType, hash
}

// verifyLocal checks if a fraud proof has been stored locally.
func (f *ProofService) verifyLocal(ctx context.Context, proofType fraud.ProofType, hash string, data []byte) bool {
func (f *ProofService[H]) verifyLocal(ctx context.Context, proofType fraud.ProofType, hash string, data []byte) bool {
f.storesLk.RLock()
storage, ok := f.stores[proofType]
f.storesLk.RUnlock()
Expand Down
54 changes: 34 additions & 20 deletions fraudserv/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ import (
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/stretchr/testify/require"

"github.com/celestiaorg/go-header"
"github.com/celestiaorg/go-header/headertest"

gofraud "github.com/celestiaorg/go-fraud"
"github.com/celestiaorg/go-fraud"
"github.com/celestiaorg/go-fraud/fraudtest"
)

Expand All @@ -27,7 +26,7 @@ func TestService_SubscribeBroadcastValid(t *testing.T) {
serv := newTestService(ctx, t, false)
require.NoError(t, serv.Start(ctx))

fraud := fraudtest.NewValidProof()
fraud := fraudtest.NewValidProof[*headertest.DummyHeader]()
sub, err := serv.Subscribe(fraud.Type())
require.NoError(t, err)
defer sub.Cancel()
Expand All @@ -44,38 +43,38 @@ func TestService_SubscribeBroadcastWithVerifiers(t *testing.T) {
serv := newTestService(ctx, t, false)
require.NoError(t, serv.Start(ctx))

fraud := fraudtest.NewValidProof()
require.NoError(t, serv.AddVerifier(fraud.Type(), func(fraudProof gofraud.Proof) (bool, error) {
frd := fraudtest.NewValidProof[*headertest.DummyHeader]()
require.NoError(t, serv.AddVerifier(frd.Type(), func(fraudProof fraud.Proof[*headertest.DummyHeader]) (bool, error) {
return true, nil
}))

// test for error while adding the verifier for the second time
require.Error(t, serv.AddVerifier(fraud.Type(), func(fraudProof gofraud.Proof) (bool, error) {
require.Error(t, serv.AddVerifier(frd.Type(), func(fraudProof fraud.Proof[*headertest.DummyHeader]) (bool, error) {
return true, nil
}))
sub, err := serv.Subscribe(fraud.Type())
sub, err := serv.Subscribe(frd.Type())
require.NoError(t, err)
defer sub.Cancel()

require.NoError(t, serv.Broadcast(ctx, fraud))
require.NoError(t, serv.Broadcast(ctx, frd))
_, err = sub.Proof(ctx)
require.NoError(t, err)

// test for invalid fraud proof verifier
serv = newTestService(ctx, t, false)
require.NoError(t, serv.Start(ctx))
require.NoError(t, serv.AddVerifier(fraud.Type(), func(fraudProof gofraud.Proof) (bool, error) {
require.NoError(t, serv.AddVerifier(frd.Type(), func(fraudProof fraud.Proof[*headertest.DummyHeader]) (bool, error) {
return false, nil
}))
require.Error(t, serv.Broadcast(ctx, fraud))
require.Error(t, serv.Broadcast(ctx, frd))

// test for error case of fraud proof verifier
serv = newTestService(ctx, t, false)
require.NoError(t, serv.Start(ctx))
require.NoError(t, serv.AddVerifier(fraud.Type(), func(fraudProof gofraud.Proof) (bool, error) {
require.NoError(t, serv.AddVerifier(frd.Type(), func(fraudProof fraud.Proof[*headertest.DummyHeader]) (bool, error) {
return true, errors.New("throws error")
}))
require.Error(t, serv.Broadcast(ctx, fraud))
require.Error(t, serv.Broadcast(ctx, frd))
}

func TestService_SubscribeBroadcastInvalid(t *testing.T) {
Expand All @@ -85,7 +84,7 @@ func TestService_SubscribeBroadcastInvalid(t *testing.T) {
serv := newTestService(ctx, t, false)
require.NoError(t, serv.Start(ctx))

fraud := fraudtest.NewInvalidProof()
fraud := fraudtest.NewInvalidProof[*headertest.DummyHeader]()
sub, err := serv.Subscribe(fraud.Type())
require.NoError(t, err)
defer sub.Cancel()
Expand Down Expand Up @@ -123,7 +122,7 @@ func TestService_ReGossiping(t *testing.T) {
require.NoError(t, servB.Start(ctx))
require.NoError(t, servC.Start(ctx))

fraud := fraudtest.NewValidProof()
fraud := fraudtest.NewValidProof[*headertest.DummyHeader]()
subsA, err := servA.Subscribe(fraud.Type())
require.NoError(t, err)
defer subsA.Cancel()
Expand Down Expand Up @@ -161,7 +160,7 @@ func TestService_Get(t *testing.T) {
serv := newTestService(ctx, t, false)
require.NoError(t, serv.Start(ctx))

fraud := fraudtest.NewValidProof()
fraud := fraudtest.NewValidProof[*headertest.DummyHeader]()
_, err := serv.Get(ctx, fraud.Type()) // try to fetch proof
require.Error(t, err) // storage is empty so should error

Expand Down Expand Up @@ -189,7 +188,7 @@ func TestService_Sync(t *testing.T) {
servA := newTestServiceWithHost(ctx, t, net.Hosts()[0], false)
require.NoError(t, servA.Start(ctx))

fraud := fraudtest.NewValidProof()
fraud := fraudtest.NewValidProof[*headertest.DummyHeader]()
err = servA.Broadcast(ctx, fraud) // broadcasting ensures the fraud gets stored on servA
require.NoError(t, err)

Expand All @@ -207,23 +206,29 @@ func TestService_Sync(t *testing.T) {
require.NoError(t, err)
}

func newTestService(ctx context.Context, t *testing.T, enabledSyncer bool) *ProofService {
func newTestService(ctx context.Context, t *testing.T, enabledSyncer bool) *ProofService[*headertest.DummyHeader] {
net, err := mocknet.FullMeshLinked(1)
require.NoError(t, err)
return newTestServiceWithHost(ctx, t, net.Hosts()[0], enabledSyncer)
}

func newTestServiceWithHost(ctx context.Context, t *testing.T, host host.Host, enabledSyncer bool) *ProofService {
func newTestServiceWithHost(
ctx context.Context,
t *testing.T,
host host.Host,
enabledSyncer bool,
) *ProofService[*headertest.DummyHeader] {
ps, err := pubsub.NewFloodSub(ctx, host, pubsub.WithMessageSignaturePolicy(pubsub.StrictNoSign))
require.NoError(t, err)

store := headertest.NewDummyStore(t)
serv := NewProofService(
serv := NewProofService[*headertest.DummyHeader](
ps,
host,
func(ctx context.Context, u uint64) (header.Header, error) {
func(ctx context.Context, u uint64) (*headertest.DummyHeader, error) {
return store.GetByHeight(ctx, u)
},
unmarshaler,
sync.MutexWrap(datastore.NewMapDatastore()),
enabledSyncer,
"private",
Expand All @@ -237,3 +242,12 @@ func newTestServiceWithHost(ctx context.Context, t *testing.T, host host.Host, e
})
return serv
}

var unmarshaler = &fraud.MultiUnmarshaler[*headertest.DummyHeader]{
Unmarshalers: map[fraud.ProofType]func([]byte) (fraud.Proof[*headertest.DummyHeader], error){
fraudtest.DummyProofType: func(data []byte) (fraud.Proof[*headertest.DummyHeader], error) {
proof := &fraudtest.DummyProof[*headertest.DummyHeader]{}
return proof, proof.UnmarshalBinary(data)
},
},
}
13 changes: 10 additions & 3 deletions fraudserv/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/ipfs/go-datastore/namespace"
q "github.com/ipfs/go-datastore/query"

"github.com/celestiaorg/go-header"

"github.com/celestiaorg/go-fraud"
)

Expand Down Expand Up @@ -38,17 +40,22 @@ func getByHash(ctx context.Context, ds datastore.Datastore, hash string) ([]byte
}

// getAll queries all Fraud Proofs by their type.
func getAll(ctx context.Context, ds datastore.Datastore, proofType fraud.ProofType) ([]fraud.Proof, error) {
func getAll[H header.Header[H]](
ctx context.Context,
ds datastore.Datastore,
proofType fraud.ProofType,
registry fraud.ProofUnmarshaler[H],
) ([]fraud.Proof[H], error) {
entries, err := query(ctx, ds, q.Query{})
if err != nil {
return nil, err
}
if len(entries) == 0 {
return nil, datastore.ErrNotFound
}
proofs := make([]fraud.Proof, 0)
proofs := make([]fraud.Proof[H], 0)
for _, data := range entries {
proof, err := fraud.Unmarshal(proofType, data.Value)
proof, err := registry.Unmarshal(proofType, data.Value)
if err != nil {
if errors.Is(err, &fraud.ErrNoUnmarshaler{}) {
return nil, err
Expand Down
Loading