From aebab472c1d5b4e972993d239976448d423f7c2f Mon Sep 17 00:00:00 2001 From: afk <84330705+afkbyte@users.noreply.github.com> Date: Wed, 29 May 2024 00:58:19 -0400 Subject: [PATCH] minimal diff to be able to return the task response bytes instead of just the task response digest (#252) * minimal diff to be able to return the task response bytes * add taskReponse type * only add to taskResponseMap if doesn't already exist * add interface for hash function * fix lint * make local * add types * remove mapping --- services/bls_aggregation/blsagg.go | 36 +++-- services/bls_aggregation/blsagg_test.go | 198 +++++++++++++++++------- types/avs.go | 7 +- 3 files changed, 170 insertions(+), 71 deletions(-) diff --git a/services/bls_aggregation/blsagg.go b/services/bls_aggregation/blsagg.go index d470da18..ddd75cc5 100644 --- a/services/bls_aggregation/blsagg.go +++ b/services/bls_aggregation/blsagg.go @@ -45,6 +45,7 @@ var ( type BlsAggregationServiceResponse struct { Err error // if Err is not nil, the other fields are not valid TaskIndex types.TaskIndex // unique identifier of the task + TaskResponse types.TaskResponse // the task response that was signed TaskResponseDigest types.TaskResponseDigest // digest of the task response that was signed // The below 8 fields are the data needed to build the IBLSSignatureChecker.NonSignerStakesAndSignature struct // users of this service will need to build the struct themselves by converting the bls points @@ -97,7 +98,7 @@ type BlsAggregationService interface { ProcessNewSignature( ctx context.Context, taskIndex types.TaskIndex, - taskResponseDigest types.TaskResponseDigest, + taskResponse types.TaskResponse, blsSignature *bls.Signature, operatorId types.OperatorId, ) error @@ -134,17 +135,20 @@ type BlsAggregatorService struct { taskChansMutex sync.RWMutex avsRegistryService avsregistry.AvsRegistryService logger logging.Logger + + hashFunction types.TaskResponseHashFunction } var _ BlsAggregationService = (*BlsAggregatorService)(nil) -func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, logger logging.Logger) *BlsAggregatorService { +func NewBlsAggregatorService(avsRegistryService avsregistry.AvsRegistryService, hashFunction types.TaskResponseHashFunction, logger logging.Logger) *BlsAggregatorService { return &BlsAggregatorService{ aggregatedResponsesC: make(chan BlsAggregationServiceResponse), signedTaskRespsCs: make(map[types.TaskIndex]chan types.SignedTaskResponseDigest), taskChansMutex: sync.RWMutex{}, avsRegistryService: avsRegistryService, logger: logger, + hashFunction: hashFunction, } } @@ -179,7 +183,7 @@ func (a *BlsAggregatorService) InitializeNewTask( func (a *BlsAggregatorService) ProcessNewSignature( ctx context.Context, taskIndex types.TaskIndex, - taskResponseDigest types.TaskResponseDigest, + taskResponse types.TaskResponse, blsSignature *bls.Signature, operatorId types.OperatorId, ) error { @@ -189,14 +193,16 @@ func (a *BlsAggregatorService) ProcessNewSignature( if !taskInitialized { return TaskNotFoundErrorFn(taskIndex) } + signatureVerificationErrorC := make(chan error) // send the task to the goroutine processing this task // and return the error (if any) returned by the signature verification routine + select { // we need to send this as part of select because if the goroutine is processing another SignedTaskResponseDigest // and cannot receive this one, we want the context to be able to cancel the request case taskC <- types.SignedTaskResponseDigest{ - TaskResponseDigest: taskResponseDigest, + TaskResponse: taskResponse, BlsSignature: blsSignature, OperatorId: operatorId, SignatureVerificationErrorC: signatureVerificationErrorC, @@ -255,13 +261,16 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( select { case signedTaskResponseDigest := <-signedTaskRespsC: a.logger.Debug("Task goroutine received new signed task response digest", "taskIndex", taskIndex, "signedTaskResponseDigest", signedTaskResponseDigest) + // compute the taskResponseDigest using the hash function + taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse) + err := a.verifySignature(taskIndex, signedTaskResponseDigest, operatorsAvsStateDict) signedTaskResponseDigest.SignatureVerificationErrorC <- err if err != nil { continue } // after verifying signature we aggregate its sig and pubkey, and update the signed stake amount - digestAggregatedOperators, ok := aggregatedOperatorsDict[signedTaskResponseDigest.TaskResponseDigest] + digestAggregatedOperators, ok := aggregatedOperatorsDict[taskResponseDigest] if !ok { // first operator to sign on this digest digestAggregatedOperators = aggregatedOperators{ @@ -286,7 +295,7 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( } // update the aggregatedOperatorsDict. Note that we need to assign the whole struct value at once, // because of https://github.com/golang/go/issues/3117 - aggregatedOperatorsDict[signedTaskResponseDigest.TaskResponseDigest] = digestAggregatedOperators + aggregatedOperatorsDict[taskResponseDigest] = digestAggregatedOperators if checkIfStakeThresholdsMet(a.logger, digestAggregatedOperators.signersTotalStakePerQuorum, totalStakePerQuorum, quorumThresholdPercentagesMap) { nonSignersOperatorIds := []types.OperatorId{} @@ -316,10 +325,12 @@ func (a *BlsAggregatorService) singleTaskAggregatorGoroutineFunc( } return } + blsAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, - TaskResponseDigest: signedTaskResponseDigest.TaskResponseDigest, + TaskResponse: signedTaskResponseDigest.TaskResponse, + TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: nonSignersG1Pubkeys, QuorumApksG1: quorumApksG1, SignersApkG2: digestAggregatedOperators.signersApkG2, @@ -371,7 +382,9 @@ func (a *BlsAggregatorService) verifySignature( return OperatorNotPartOfTaskQuorumErrorFn(signedTaskResponseDigest.OperatorId, taskIndex) } - // 0. verify that the msg actually came from the correct operator + taskResponseDigest := a.hashFunction(signedTaskResponseDigest.TaskResponse) + + // verify that the msg actually came from the correct operator operatorG2Pubkey := operatorsAvsStateDict[signedTaskResponseDigest.OperatorId].OperatorInfo.Pubkeys.G2Pubkey if operatorG2Pubkey == nil { a.logger.Error("Operator G2 pubkey not found", "operatorId", signedTaskResponseDigest.OperatorId, "taskId", taskIndex) @@ -379,10 +392,13 @@ func (a *BlsAggregatorService) verifySignature( } a.logger.Debug("Verifying signed task response digest signature", "operatorG2Pubkey", operatorG2Pubkey, - "taskResponseDigest", signedTaskResponseDigest.TaskResponseDigest, + "taskResponseDigest", taskResponseDigest, "blsSignature", signedTaskResponseDigest.BlsSignature, ) - signatureVerified, err := signedTaskResponseDigest.BlsSignature.Verify(operatorG2Pubkey, signedTaskResponseDigest.TaskResponseDigest) + + // if the operator signs a digest that is not the digest of the TaskResponse submitted in ProcessNewTask + // then the signature will not be verified + signatureVerified, err := signedTaskResponseDigest.BlsSignature.Verify(operatorG2Pubkey, taskResponseDigest) if err != nil { return SignatureVerificationError(err) } diff --git a/services/bls_aggregation/blsagg_test.go b/services/bls_aggregation/blsagg_test.go index 27ad9704..616c6754 100644 --- a/services/bls_aggregation/blsagg_test.go +++ b/services/bls_aggregation/blsagg_test.go @@ -2,6 +2,8 @@ package blsagg import ( "context" + "crypto/sha256" + "encoding/json" "math/big" "testing" "time" @@ -23,6 +25,28 @@ func TestBlsAgg(t *testing.T) { // 1 second seems to be enough for tests to pass. Currently takes 5s to run all tests tasksTimeToExpiry := 1 * time.Second + hashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest { + taskResponseBytes, err := json.Marshal(taskResponse) + if err != nil { + panic(err) + } + return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)) + } + + wrongHashFunction := func(taskResponse types.TaskResponse) types.TaskResponseDigest { + taskResponseBytes, err := json.Marshal(taskResponse) + if err != nil { + panic(err) + } + // append something to the taskResponseBytes to make it different + taskResponseBytes = append(taskResponseBytes, []byte("something")...) + return types.TaskResponseDigest(sha256.Sum256(taskResponseBytes)) + } + + type mockTaskResponse struct { + Value int + } + t.Run("1 quorum 1 operator 1 correct signature", func(t *testing.T) { testOperator1 := types.TestOperator{ OperatorId: types.OperatorId{1}, @@ -33,20 +57,25 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} // Initialize with appropriate data + + // Compute the TaskResponseDigest as the SHA-256 sum of the TaskResponse + taskResponseDigest := hashFunction(taskResponse) + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1()}, @@ -73,31 +102,34 @@ func TestBlsAgg(t *testing.T) { StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(300), 1: big.NewInt(100)}, BlsKeypair: newBlsKeyPairPanics("0x3"), } + blockNum := uint32(1) taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} - taskResponseDigest := types.TaskResponseDigest{123} - blockNum := uint32(1) + taskResponse := mockTaskResponse{123} + + taskResponseDigest := hashFunction(taskResponse) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) blsSigOp3 := testOperator3.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp3, testOperator3.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp3, testOperator3.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1(). @@ -126,28 +158,30 @@ func TestBlsAgg(t *testing.T) { StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, BlsKeypair: newBlsKeyPairPanics("0x2"), } + blockNum := uint32(1) taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} - blockNum := uint32(1) + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -178,34 +212,37 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) // initialize 2 concurrent tasks task1Index := types.TaskIndex(1) - task1ResponseDigest := types.TaskResponseDigest{123} + task1Response := mockTaskResponse{123} + task1ResponseDigest := hashFunction(task1Response) err := blsAggServ.InitializeNewTask(task1Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) task2Index := types.TaskIndex(2) - task2ResponseDigest := types.TaskResponseDigest{230} + task2Response := mockTaskResponse{234} + task2ResponseDigest := hashFunction(task2Response) err = blsAggServ.InitializeNewTask(task2Index, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigTask1Op1 := testOperator1.BlsKeypair.SignMessage(task1ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1ResponseDigest, blsSigTask1Op1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1Response, blsSigTask1Op1, testOperator1.OperatorId) require.Nil(t, err) blsSigTask2Op1 := testOperator1.BlsKeypair.SignMessage(task2ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2ResponseDigest, blsSigTask2Op1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2Response, blsSigTask2Op1, testOperator1.OperatorId) require.Nil(t, err) blsSigTask1Op2 := testOperator2.BlsKeypair.SignMessage(task1ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1ResponseDigest, blsSigTask1Op2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task1Index, task1Response, blsSigTask1Op2, testOperator2.OperatorId) require.Nil(t, err) blsSigTask2Op2 := testOperator2.BlsKeypair.SignMessage(task2ResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2ResponseDigest, blsSigTask2Op2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), task2Index, task2Response, blsSigTask2Op2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponseTask1 := BlsAggregationServiceResponse{ Err: nil, TaskIndex: task1Index, + TaskResponse: task1Response, TaskResponseDigest: task1ResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -218,6 +255,7 @@ func TestBlsAgg(t *testing.T) { wantAggregationServiceResponseTask2 := BlsAggregationServiceResponse{ Err: nil, TaskIndex: task2Index, + TaskResponse: task2Response, TaskResponseDigest: task2ResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -254,7 +292,7 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) @@ -279,21 +317,23 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{testOperator2.BlsKeypair.GetPubKeyG1()}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1().Add(testOperator2.BlsKeypair.GetPubKeyG1())}, @@ -319,16 +359,17 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: TaskExpiredErrorFn(taskIndex), @@ -353,25 +394,27 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{ @@ -406,25 +449,27 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{50, 50} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse, TaskResponseDigest: taskResponseDigest, NonSignersPubkeysG1: []*bls.G1Point{ testOperator3.BlsKeypair.GetPubKeyG1(), @@ -461,20 +506,21 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{60, 60} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2, testOperator3}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ @@ -494,17 +540,18 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ @@ -529,17 +576,18 @@ func TestBlsAgg(t *testing.T) { taskIndex := types.TaskIndex(0) quorumNumbers := types.QuorumNums{0, 1} quorumThresholdPercentages := []types.QuorumThresholdPercentage{100, 100} - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blockNum := uint32(1) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ @@ -557,14 +605,15 @@ func TestBlsAgg(t *testing.T) { } blockNum := uint32(1) taskIndex := types.TaskIndex(0) - taskResponseDigest := types.TaskResponseDigest{123} + taskResponse := mockTaskResponse{123} + taskResponseDigest := hashFunction(taskResponse) blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) - err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest, blsSig, testOperator1.OperatorId) + err := blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) require.Equal(t, TaskNotFoundErrorFn(taskIndex), err) }) @@ -588,26 +637,29 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - taskResponseDigest1 := types.TaskResponseDigest{1} + taskResponse1 := mockTaskResponse{1} + taskResponseDigest1 := hashFunction(taskResponse1) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest1, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) - taskResponseDigest2 := types.TaskResponseDigest{2} + taskResponse2 := mockTaskResponse{2} + taskResponseDigest2 := hashFunction(taskResponse2) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - err = blsAggServ.ProcessNewSignature(ctx, taskIndex, taskResponseDigest2, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(ctx, taskIndex, taskResponse2, blsSigOp2, testOperator2.OperatorId) // this should timeout because the task goroutine is blocked on the response channel (since we only listen for it below) require.Equal(t, context.DeadlineExceeded, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: nil, TaskIndex: taskIndex, + TaskResponse: taskResponse1, TaskResponseDigest: taskResponseDigest1, NonSignersPubkeysG1: []*bls.G1Point{}, QuorumApksG1: []*bls.G1Point{testOperator1.BlsKeypair.GetPubKeyG1()}, @@ -636,17 +688,19 @@ func TestBlsAgg(t *testing.T) { fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1, testOperator2}) noopLogger := logging.NewNoopLogger() - blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, noopLogger) + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) require.Nil(t, err) - taskResponseDigest1 := types.TaskResponseDigest{1} + taskResponse1 := mockTaskResponse{1} + taskResponseDigest1 := hashFunction(taskResponse1) blsSigOp1 := testOperator1.BlsKeypair.SignMessage(taskResponseDigest1) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest1, blsSigOp1, testOperator1.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse1, blsSigOp1, testOperator1.OperatorId) require.Nil(t, err) - taskResponseDigest2 := types.TaskResponseDigest{2} + taskResponse2 := mockTaskResponse{2} + taskResponseDigest2 := hashFunction(taskResponse2) blsSigOp2 := testOperator2.BlsKeypair.SignMessage(taskResponseDigest2) - err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponseDigest2, blsSigOp2, testOperator2.OperatorId) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse2, blsSigOp2, testOperator2.OperatorId) require.Nil(t, err) wantAggregationServiceResponse := BlsAggregationServiceResponse{ Err: TaskExpiredErrorFn(taskIndex), @@ -654,6 +708,32 @@ func TestBlsAgg(t *testing.T) { gotAggregationServiceResponse := <-blsAggServ.aggregatedResponsesC require.Equal(t, wantAggregationServiceResponse, gotAggregationServiceResponse) }) + + t.Run("1 quorum 1 operator 1 invalid signature (TaskResponseDigest does not match TaskResponse)", func(t *testing.T) { + testOperator1 := types.TestOperator{ + OperatorId: types.OperatorId{1}, + StakePerQuorum: map[types.QuorumNum]types.StakeAmount{0: big.NewInt(100), 1: big.NewInt(200)}, + BlsKeypair: newBlsKeyPairPanics("0x1"), + } + blockNum := uint32(1) + taskIndex := types.TaskIndex(0) + quorumNumbers := types.QuorumNums{0} + quorumThresholdPercentages := []types.QuorumThresholdPercentage{100} + taskResponse := mockTaskResponse{123} // Initialize with appropriate data + + taskResponseDigest := wrongHashFunction(taskResponse) + + blsSig := testOperator1.BlsKeypair.SignMessage(taskResponseDigest) + + fakeAvsRegistryService := avsregistry.NewFakeAvsRegistryService(blockNum, []types.TestOperator{testOperator1}) + noopLogger := logging.NewNoopLogger() + blsAggServ := NewBlsAggregatorService(fakeAvsRegistryService, hashFunction, noopLogger) + + err := blsAggServ.InitializeNewTask(taskIndex, blockNum, quorumNumbers, quorumThresholdPercentages, tasksTimeToExpiry) + require.Nil(t, err) + err = blsAggServ.ProcessNewSignature(context.Background(), taskIndex, taskResponse, blsSig, testOperator1.OperatorId) + require.EqualError(t, err, "Signature verification failed. Incorrect Signature.") + }) } func newBlsKeyPairPanics(hexKey string) *bls.KeyPair { diff --git a/types/avs.go b/types/avs.go index 9a10c679..22787013 100644 --- a/types/avs.go +++ b/types/avs.go @@ -8,9 +8,12 @@ import ( type TaskIndex = uint32 type TaskResponseDigest = Bytes32 +type TaskResponse = interface{} + +type TaskResponseHashFunction func(taskResponse TaskResponse) TaskResponseDigest type SignedTaskResponseDigest struct { - TaskResponseDigest TaskResponseDigest + TaskResponse TaskResponse BlsSignature *bls.Signature OperatorId OperatorId SignatureVerificationErrorC chan error `json:"-"` // removed from json because channels are not marshallable @@ -18,7 +21,7 @@ type SignedTaskResponseDigest struct { func (strd SignedTaskResponseDigest) LogValue() slog.Value { return slog.GroupValue( - slog.Any("taskResponseDigest", strd.TaskResponseDigest), + slog.Any("taskResponse", strd.TaskResponse), slog.Any("blsSignature", strd.BlsSignature), slog.Any("operatorId", strd.OperatorId), )