diff --git a/.gitignore b/.gitignore
index de840d3..1f4126d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ aggregator.log
 /performance-test
 bft-config
 /data/
+.cache/
 
 # Environment files with secrets
 .env
diff --git a/README.md b/README.md
index 72f72cd..dc3138e 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,14 @@ The service will start on `http://localhost:3000` by default.
 
 The service is configured via environment variables:
 
+### Chain Configuration
+
+| Variable        | Description     | Default   |
+|-----------------|-----------------|-----------|
+| `CHAIN_ID`      | Chain ID        | `unicity` |
+| `CHAIN_VERSION` | Chain version   | `1.0`     |
+| `CHAIN_FORK_ID` | Chain's Fork ID | `testnet` |
+
 ### Server Configuration
 | Variable | Description | Default |
 |----------|-------------|---------|
@@ -220,6 +228,7 @@ Submit a state transition request to the aggregation layer with cryptographic va
 - `REQUEST_ID_MISMATCH` - RequestID doesn't match SHA256(publicKey || stateHash)
 - `INVALID_STATE_HASH_FORMAT` - StateHash not in proper DataHash imprint format
 - `INVALID_TRANSACTION_HASH_FORMAT` - TransactionHash not in proper DataHash imprint format
+- `INVALID_SHARD` - The commitment was sent to the wrong shard
 - `UNSUPPORTED_ALGORITHM` - Algorithm other than secp256k1
 
 #### `get_inclusion_proof`
@@ -551,6 +560,92 @@ The service implements a MongoDB-based leader election system:
 - **`follower`** - Processing API requests, monitoring for leadership
 - **`standalone`** - Single server mode (HA disabled)
 
+## Sharding
+
+To support horizontal scaling, the aggregators can be run in a sharded configuration consisting of one parent aggregator
+and multiple child aggregators. In this mode, the global Sparse Merkle Tree (SMT) is split across the child nodes, and
+agents must submit their commitments to the correct child node.
+
+For a more detailed technical explanation of the sharded SMT structure, please refer to the official specification:
+[https://github.com/unicitynetwork/specs/blob/main/smt.md](https://github.com/unicitynetwork/specs/blob/main/smt.md)
+
+### Commitment Routing
+
+Commitments are assigned to a shard based on the least significant bits of their commitment identifier. The number of
+bits used to determine the shard is defined by the `SHARD_ID_LENGTH` configuration.
+
+For example, `SHARD_ID_LENGTH: 1` means that the rightmost `1` bit of the commitment identifier determines the correct
+shard. In this case there would be 2 shards: commitments ending with bit `0` would go to `shard-1`, and
+commitments ending with bit `1` would go to `shard-2`.
+
+In a sharded setup, only the parent aggregator talks to the BFT node.
+
+### Shard ID Encoding
+
+The `shardID` is a unique identifier for each shard that includes a `1` as its most significant bit (MSB). This prefix
+bit ensures that leading zeros are preserved during bit manipulation.
+
+Examples:
+- For `SHARD_ID_LENGTH: 1` the valid `shardID`s are `0b10` (2) and `0b11` (3), for a total of two shards.
+- For `SHARD_ID_LENGTH: 2` the valid `shardID`s are `0b100` (4), `0b101` (5), `0b110` (6) and `0b111` (7), for a total of four shards.
+
+A child aggregator validates incoming commitments to ensure they belong to its shard. If a commitment is sent to the
+wrong shard, the aggregator rejects it with the `INVALID_SHARD` error.
+
+### Example Sharded Setup
+
+The following sketch and diagram illustrate a sharded setup with one parent and two child aggregators for `SHARD_ID_LENGTH: 1`.
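+
+As a minimal, non-normative illustration (not code from this repository), the routing rule described above can be
+sketched in Go. The helper name `shardIDFor` and the byte-order assumption (the least significant bits of the
+identifier live in the low bits of its last byte, as the performance test in this PR also assumes) are illustrative:
+
+```go
+package main
+
+import "fmt"
+
+// shardIDFor maps an identifier onto the shardID encoding described above:
+// a leading 1 marker bit followed by the shardIDLength least significant
+// bits of the identifier. Illustrative sketch only, not part of the API.
+func shardIDFor(requestID []byte, shardIDLength uint) uint64 {
+    lsb := uint64(requestID[len(requestID)-1]) // low-order bits sit in the last byte
+    suffix := lsb & ((1 << shardIDLength) - 1) // keep only the routing bits
+    return (1 << shardIDLength) | suffix       // prepend the 1 marker bit
+}
+
+func main() {
+    // With SHARD_ID_LENGTH: 1, an identifier ending in bit 0 maps to shardID 0b10 (2),
+    // and one ending in bit 1 maps to shardID 0b11 (3).
+    fmt.Println(shardIDFor([]byte{0xAB, 0xC0}, 1)) // 2
+    fmt.Println(shardIDFor([]byte{0xAB, 0xC1}, 1)) // 3
+}
+```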
+ +```text + +--------------------+ + | Parent Aggregator | + | (2-leaf SMT) | + +--------------------+ + / \ + / \ ++----------------+ +----------------+ +| Child Agg. #1 | | Child Agg. #2 | +| ShardID = 0b10 | | ShardID = 0b11 | +| (handles *...0)| | (handles *...1)| ++----------------+ +----------------+ + ^ ^ + | | ++----------------+ +----------------+ +| Agent sends | | Agent sends | +| commitment | | commitment | +| ID = ...xxx0 | | ID = ...xxx1 | ++----------------+ +----------------+ +``` + +### Configuration + +The sharded setup is configured via environment variables, as seen in `sharding-compose.yml`. + +A **parent** aggregator is configured with: +```yaml +environment: + SHARDING_MODE: "parent" + SHARD_ID_LENGTH: 1 +``` + +A **child** aggregator is configured with its unique `shardID` and the address of the parent, for example: + +Shard-1: +```yaml +environment: + SHARDING_MODE: "child" + SHARDING_CHILD_SHARD_ID: 2 # (binary 0b10) + SHARDING_CHILD_PARENT_RPC_ADDR: http://aggregator-root:3000 +``` + +Shard-2: +```yaml +environment: + SHARDING_MODE: "child" + SHARDING_CHILD_SHARD_ID: 3 # (binary 0b11) + SHARDING_CHILD_PARENT_RPC_ADDR: http://aggregator-root:3000 +``` + ## Error Handling The service implements comprehensive JSON-RPC 2.0 error codes: diff --git a/cmd/aggregator/main.go b/cmd/aggregator/main.go index fdd00c7..fe0c696 100644 --- a/cmd/aggregator/main.go +++ b/cmd/aggregator/main.go @@ -102,41 +102,50 @@ func main() { // Create the shared state tracker for block sync height stateTracker := state.NewSyncStateTracker() - // Create the Round Manager - roundManager, err := round.NewRoundManager(ctx, cfg, log, commitmentQueue, storageInstance, stateTracker) + // Create round manager based on sharding mode + roundManager, err := round.NewManager(ctx, cfg, log, commitmentQueue, storageInstance, stateTracker) if err != nil { log.WithComponent("main").Error("Failed to create round manager", "error", err.Error()) gracefulExit(asyncLogger, 1) } - // Perform initial SMT restoration. This is required in all modes. + // Initialize round manager (SMT restoration, etc.) 
if err := roundManager.Start(ctx); err != nil { log.WithComponent("main").Error("Failed to start round manager", "error", err.Error()) gracefulExit(asyncLogger, 1) } - // Initialize HA Manager if enabled + // Initialize leader selector and HA Manager if enabled + var ls leaderSelector var haManager *ha.HAManager - var leaderSelector *ha.LeaderElection if cfg.HA.Enabled { log.WithComponent("main").Info("High availability mode enabled") - leaderSelector = ha.NewLeaderElection(log, cfg.HA, storageInstance.LeadershipStorage()) - leaderSelector.Start(ctx) + ls = ha.NewLeaderElection(log, cfg.HA, storageInstance.LeadershipStorage()) + ls.Start(ctx) + + // Disable block syncing for parent aggregator mode + // Parent mode uses state-based SMT (current shard roots) rather than history-based (commitment leaves) + disableBlockSync := cfg.Sharding.Mode == config.ShardingModeParent + if disableBlockSync { + log.WithComponent("main").Info("Block syncing disabled for parent aggregator mode - SMT will be reconstructed on leadership transition") + } - haManager = ha.NewHAManager(log, roundManager, leaderSelector, storageInstance, roundManager.GetSMT(), stateTracker, cfg.Processing.RoundDuration) + haManager = ha.NewHAManager(log, roundManager, ls, storageInstance, roundManager.GetSMT(), cfg.Sharding.Child.ShardID, stateTracker, cfg.Processing.RoundDuration, disableBlockSync) haManager.Start(ctx) - } else { log.WithComponent("main").Info("High availability mode is disabled, running as standalone leader") - // In non-HA mode, the node is always the leader, so activate it directly. + // In non-HA mode, activate the round manager directly if err := roundManager.Activate(ctx); err != nil { log.WithComponent("main").Error("Failed to activate round manager", "error", err.Error()) gracefulExit(asyncLogger, 1) } } - // Initialize service - aggregatorService := service.NewAggregatorService(cfg, log, roundManager, commitmentQueue, storageInstance, leaderSelector) + aggregatorService, err := service.NewService(ctx, cfg, log, roundManager, commitmentQueue, storageInstance, ls) + if err != nil { + log.WithComponent("main").Error("Failed to create service", "error", err.Error()) + gracefulExit(asyncLogger, 1) + } // Initialize gateway server server := gateway.NewServer(cfg, log, aggregatorService) @@ -166,17 +175,19 @@ func main() { log.WithComponent("main").Error("Failed to stop server gracefully", "error", err.Error()) } - // Stop round manager - roundManager.Stop(shutdownCtx) + // Stop HA Manager if it was started + if haManager != nil { + haManager.Stop() + } // Stop leader selector if it was started - if leaderSelector != nil { - leaderSelector.Stop(shutdownCtx) + if ls != nil { + ls.Stop(shutdownCtx) } - // Stop HA Manager if it was started - if haManager != nil { - haManager.Stop() + // Stop round manager + if err := roundManager.Stop(shutdownCtx); err != nil { + log.WithComponent("main").Error("Failed to stop round manager gracefully", "error", err.Error()) } // Close storage backends @@ -194,3 +205,9 @@ func main() { asyncLogger.Stop() } } + +type leaderSelector interface { + IsLeader(ctx context.Context) (bool, error) + Start(ctx context.Context) + Stop(ctx context.Context) +} diff --git a/cmd/sharding-perf-test/main.go b/cmd/sharding-perf-test/main.go new file mode 100644 index 0000000..2d445c5 --- /dev/null +++ b/cmd/sharding-perf-test/main.go @@ -0,0 +1,754 @@ +package main + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + "log" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + 
"github.com/btcsuite/btcd/btcec/v2" + + "github.com/unicitynetwork/aggregator-go/internal/signing" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// Test configuration +const ( + testDuration = 30 * time.Second + workerCount = 30 // Number of concurrent workers + requestsPerSec = 5000 // Target requests per second +) + +// Generate a cryptographically valid commitment request +func generateCommitmentRequest() *api.SubmitCommitmentRequest { + // Generate a real secp256k1 key pair + privateKey, err := btcec.NewPrivateKey() + if err != nil { + panic(fmt.Sprintf("Failed to generate private key: %v", err)) + } + publicKeyBytes := privateKey.PubKey().SerializeCompressed() + + // Generate random state data and create DataHash imprint + stateData := make([]byte, 32) + rand.Read(stateData) + stateHashImprint := signing.CreateDataHashImprint(stateData) + + // Create RequestID deterministically + requestID, err := api.CreateRequestID(publicKeyBytes, stateHashImprint) + if err != nil { + panic(fmt.Sprintf("Failed to create request ID: %v", err)) + } + + // Generate random transaction data and create DataHash imprint + transactionData := make([]byte, 32) + rand.Read(transactionData) + transactionHashImprint := signing.CreateDataHashImprint(transactionData) + + // Extract transaction hash bytes for signing + transactionHashBytes, err := transactionHashImprint.DataBytes() + if err != nil { + panic(fmt.Sprintf("Failed to extract transaction hash: %v", err)) + } + + // Sign the transaction hash bytes + signingService := signing.NewSigningService() + signatureBytes, err := signingService.SignHash(transactionHashBytes, privateKey.Serialize()) + if err != nil { + panic(fmt.Sprintf("Failed to sign transaction: %v", err)) + } + + // Create receipt flag + receipt := false + + return &api.SubmitCommitmentRequest{ + RequestID: requestID, + TransactionHash: transactionHashImprint, + Authenticator: api.Authenticator{ + Algorithm: "secp256k1", + PublicKey: publicKeyBytes, + Signature: signatureBytes, + StateHash: stateHashImprint, + }, + Receipt: &receipt, + } +} + +// Worker function that continuously submits commitments +func commitmentWorker(ctx context.Context, clients []*JSONRPCClient, metrics *Metrics) { + var wg sync.WaitGroup + ticker := time.NewTicker(time.Second / time.Duration(requestsPerSec/workerCount)) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + wg.Wait() // Wait for in-flight requests to complete + return + case <-ticker.C: + wg.Add(1) + // Generate and submit commitment asynchronously + go func() { + defer wg.Done() + req := generateCommitmentRequest() + + // choose correct client based on generated request ID + reqBytes, _ := req.RequestID.Bytes() + lsb := reqBytes[len(reqBytes)-1] + var client *JSONRPCClient + if lsb&1 == 0 { + client = clients[1] + } else { + client = clients[0] + } + + shardM := metrics.shardMetrics[client.url] + atomic.AddInt64(&shardM.totalRequests, 1) + + resp, err := client.call("submit_commitment", req) + if err != nil { + atomic.AddInt64(&shardM.failedRequests, 1) + // Don't print network errors - too noisy + return + } + + requestIDStr := strings.ToLower(req.RequestID.String()) + if resp.Error != nil { + atomic.AddInt64(&shardM.failedRequests, 1) + if resp.Error.Message == "REQUEST_ID_EXISTS" { + atomic.AddInt64(&shardM.requestIdExistsErr, 1) + // Track this ID - it exists so it will be in blocks! 
+ metrics.submittedRequestIDs.Store(requestIDStr, &requestInfo{URL: client.url, Found: 0}) + } + return + } + + // Parse response + var submitResp api.SubmitCommitmentResponse + respBytes, _ := json.Marshal(resp.Result) + if err := json.Unmarshal(respBytes, &submitResp); err != nil { + atomic.AddInt64(&shardM.failedRequests, 1) + return + } + + if submitResp.Status == "SUCCESS" { + atomic.AddInt64(&shardM.successfulRequests, 1) + // Track this request ID as submitted by us (normalized to lowercase) + metrics.submittedRequestIDs.Store(requestIDStr, &requestInfo{URL: client.url, Found: 0}) + } else if submitResp.Status == "REQUEST_ID_EXISTS" { + atomic.AddInt64(&shardM.requestIdExistsErr, 1) + atomic.AddInt64(&shardM.successfulRequests, 1) // Count as successful - it will be in blocks! + // Also track this ID - it exists so it will be in blocks! (normalized to lowercase) + metrics.submittedRequestIDs.Store(requestIDStr, &requestInfo{URL: client.url, Found: 0}) + } else { + atomic.AddInt64(&shardM.failedRequests, 1) + // Log unexpected status + if submitResp.Status != "" { + fmt.Printf("Unexpected status '%s' for request %s\n", submitResp.Status, req.RequestID) + } + } + }() + } + } +} + +func main() { + // Get URL and auth header from environment variables + authHeader := os.Getenv("AUTH_HEADER") + + fmt.Printf("Starting sharded aggregator performance test...\n") + fmt.Printf("Duration: %v\n", testDuration) + fmt.Printf("Workers: %d\n", workerCount) + fmt.Printf("Target RPS: %d\n", requestsPerSec) + fmt.Printf("----------------------------------------\n") + + // Initialize metrics + metrics := &Metrics{ + startTime: time.Now(), + startingBlockNumbers: make(map[string]int64), + shardMetrics: make(map[string]*ShardMetrics), + } + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), testDuration) + defer cancel() + + // Create JSON-RPC clients + s1Client := NewJSONRPCClient("http://localhost:3001", authHeader) + s2Client := NewJSONRPCClient("http://localhost:3002", authHeader) + clients := []*JSONRPCClient{s1Client, s2Client} + + // Test connectivity and get starting block number for both shards + for _, client := range clients { + metrics.shardMetrics[client.url] = &ShardMetrics{} + fmt.Printf("Testing connectivity to %s...\n", client.url) + resp, err := client.call("get_block_height", nil) + if err != nil { + log.Fatalf("Failed to connect to aggregator at %s: %v", client.url, err) + } + + if resp.Error != nil { + log.Fatalf("Error getting block height from %s: %v", client.url, resp.Error.Message) + } + + var heightResp GetBlockHeightResponse + respBytes, _ := json.Marshal(resp.Result) + if err := json.Unmarshal(respBytes, &heightResp); err != nil { + log.Fatalf("Failed to parse block height from %s: %v", client.url, err) + } + + var startingBlockNumber int64 + if _, err := fmt.Sscanf(heightResp.BlockNumber, "%d", &startingBlockNumber); err != nil { + log.Fatalf("Failed to parse starting block number from %s: %v", client.url, err) + } + + fmt.Printf("✓ Connected successfully to %s\n", client.url) + fmt.Printf("✓ Starting block number for %s: %d\n", client.url, startingBlockNumber) + metrics.startingBlockNumbers[client.url] = startingBlockNumber + } + + var wg sync.WaitGroup + + // Start commitment workers + wg.Add(workerCount) + for i := 0; i < workerCount; i++ { + go func() { + defer wg.Done() + commitmentWorker(ctx, clients, metrics) + }() + } + + // Progress reporting + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { 
+ select { + case <-ctx.Done(): + return + case <-ticker.C: + elapsed := time.Since(metrics.startTime) + var total, successful, failed, exists int64 + for _, sm := range metrics.shardMetrics { + total += atomic.LoadInt64(&sm.totalRequests) + successful += atomic.LoadInt64(&sm.successfulRequests) + failed += atomic.LoadInt64(&sm.failedRequests) + exists += atomic.LoadInt64(&sm.requestIdExistsErr) + } + + rps := float64(total) / elapsed.Seconds() + fmt.Printf("[%v] Total: %d, Success: %d, Failed: %d, Exists: %d, RPS: %.1f\n", elapsed.Truncate(time.Second), total, successful, failed, exists, rps) + } + } + }() + + // Wait for completion + wg.Wait() + + // Give a moment for any in-flight requests to complete + time.Sleep(1 * time.Second) + + // Stop submission phase and get counts + fmt.Printf("\n----------------------------------------\n") + fmt.Printf("Submission completed. Now checking blocks for all commitments...\n") + + var totalSuccessful int64 + for _, sm := range metrics.shardMetrics { + totalSuccessful += atomic.LoadInt64(&sm.successfulRequests) + } + fmt.Printf("Total successful submissions: %d\n", totalSuccessful) + + for _, waitClient := range clients { + shardM := metrics.shardMetrics[waitClient.url] + shardSuccessful := atomic.LoadInt64(&shardM.successfulRequests) + startingBlockNumber := metrics.startingBlockNumbers[waitClient.url] + fmt.Printf("\n--- Checking shard %s ---\n", waitClient.url) + fmt.Printf("Starting from block %d\n", startingBlockNumber+1) + + var latestBlockNumber int64 + var blockHeightResp *JSONRPCResponse + var blockHeightErr error + + for i := 0; i < 5; i++ { // Retry up to 5 times + blockHeightResp, blockHeightErr = waitClient.call("get_block_height", nil) + if blockHeightErr == nil && blockHeightResp.Error == nil { + break // Success + } + fmt.Printf("Retrying get_block_height for %s... (%d/5)\n", waitClient.url, i+1) + time.Sleep(1 * time.Second) + } + + if blockHeightErr == nil && blockHeightResp.Error == nil { + var heightResult GetBlockHeightResponse + respBytes, _ := json.Marshal(blockHeightResp.Result) + if err := json.Unmarshal(respBytes, &heightResult); err == nil { + fmt.Sscanf(heightResult.BlockNumber, "%d", &latestBlockNumber) + } + } else { + log.Printf("Could not get block height for %s after retries. 
Proceeding with latestBlock=0.", waitClient.url) + } + + fmt.Printf("Latest block: %d\n", latestBlockNumber) + + currentCheckBlock := startingBlockNumber + 1 + + safeBlockNumber := latestBlockNumber - 1 + if safeBlockNumber < currentCheckBlock { + fmt.Printf("\nWaiting for more blocks to be created...\n") + safeBlockNumber = currentCheckBlock + } + + fmt.Printf("\nChecking blocks %d to %d...\n", currentCheckBlock, safeBlockNumber) + + for currentCheckBlock <= safeBlockNumber && atomic.LoadInt64(&shardM.totalBlockCommitments) < shardSuccessful { + commitReq := GetBlockCommitmentsRequest{ + BlockNumber: fmt.Sprintf("%d", currentCheckBlock), + } + + commitResp, err := waitClient.call("get_block_commitments", commitReq) + if err != nil { + currentCheckBlock++ + continue + } + + if commitResp.Error != nil { + currentCheckBlock++ + continue + } + + var commitsResp GetBlockCommitmentsResponse + commitRespBytes, err := json.Marshal(commitResp.Result) + if err != nil { + currentCheckBlock++ + continue + } + if err := json.Unmarshal(commitRespBytes, &commitsResp); err != nil { + currentCheckBlock++ + continue + } + + ourCommitmentCount := 0 + notOurs := 0 + for _, commitment := range commitsResp.Commitments { + requestIDStr := strings.ToLower(commitment.RequestID) + if val, exists := metrics.submittedRequestIDs.Load(requestIDStr); exists { + info := val.(*requestInfo) + if info.URL == waitClient.url { + if atomic.CompareAndSwapInt32(&info.Found, 0, 1) { + ourCommitmentCount++ + } + } + } else { + notOurs++ + } + } + + if notOurs > 100 { + fmt.Printf(" [DEBUG] Block %d has %d commitments not from our test\n", currentCheckBlock, notOurs) + } + + shardM.addBlockCommitmentCount(ourCommitmentCount) + if ourCommitmentCount > 0 { + atomic.AddInt64(&shardM.totalBlockCommitments, int64(ourCommitmentCount)) + fmt.Printf("Block %d: %d our commitments (total in block: %d, shard total: %d/%d)\n", currentCheckBlock, ourCommitmentCount, len(commitsResp.Commitments), atomic.LoadInt64(&shardM.totalBlockCommitments), shardSuccessful) + } else if len(commitsResp.Commitments) > 0 { + fmt.Printf("Block %d: %d commitments from other sources\n", currentCheckBlock, len(commitsResp.Commitments)) + } + + currentCheckBlock++ + time.Sleep(20 * time.Millisecond) + } + + if atomic.LoadInt64(&shardM.totalBlockCommitments) < shardSuccessful { + fmt.Printf("\nContinuing to check for remaining commitments on %s...\n", waitClient.url) + fmt.Printf("Will check for up to 3 minutes for new blocks...\n") + + timeoutTime := time.Now().Add(3 * time.Minute) + lastProgressTime := time.Now() + lastReportTime := time.Now() + + type blockRetryInfo struct { + blockNumber int64 + totalCommitments int + lastChecked time.Time + retryCount int + } + blocksToRetry := make(map[int64]*blockRetryInfo) + + checkBlock := func(blockNum int64) bool { + commitReq := GetBlockCommitmentsRequest{ + BlockNumber: fmt.Sprintf("%d", blockNum), + } + + commitResp, err := waitClient.call("get_block_commitments", commitReq) + if err != nil { + fmt.Printf("Block %d: network error: %v\n", blockNum, err) + return false + } + + if commitResp.Error != nil { + if commitResp.Error.Code != -32602 { + fmt.Printf("Block %d: error %d: %s\n", blockNum, commitResp.Error.Code, commitResp.Error.Message) + } + return false + } + + var commitsResp GetBlockCommitmentsResponse + commitRespBytes, _ := json.Marshal(commitResp.Result) + if err := json.Unmarshal(commitRespBytes, &commitsResp); err != nil { + fmt.Printf("Block %d: failed to parse response: %v\n", blockNum, err) + return 
false + } + + ourCommitmentCount := 0 + for _, commitment := range commitsResp.Commitments { + requestIDStr := strings.ToLower(commitment.RequestID) + if val, exists := metrics.submittedRequestIDs.Load(requestIDStr); exists { + info := val.(*requestInfo) + if info.URL == waitClient.url { + if atomic.CompareAndSwapInt32(&info.Found, 0, 1) { + ourCommitmentCount++ + } + } + } + } + + if ourCommitmentCount > 0 { + shardM.addBlockCommitmentCount(ourCommitmentCount) + atomic.AddInt64(&shardM.totalBlockCommitments, int64(ourCommitmentCount)) + fmt.Printf("Block %d: %d our commitments (total in block: %d, shard total: %d/%d)\n", blockNum, ourCommitmentCount, len(commitsResp.Commitments), atomic.LoadInt64(&shardM.totalBlockCommitments), shardSuccessful) + lastProgressTime = time.Now() + delete(blocksToRetry, blockNum) + return true + } else if len(commitsResp.Commitments) > 0 { + if _, exists := blocksToRetry[blockNum]; !exists { + blocksToRetry[blockNum] = &blockRetryInfo{ + blockNumber: blockNum, + totalCommitments: len(commitsResp.Commitments), + lastChecked: time.Now(), + retryCount: 0, + } + fmt.Printf("Block %d: 0 our commitments yet (will retry, total: %d)\n", blockNum, len(commitsResp.Commitments)) + } + } + return false + } + + for atomic.LoadInt64(&shardM.totalBlockCommitments) < shardSuccessful && time.Now().Before(timeoutTime) { + heightResp, err := waitClient.call("get_block_height", nil) + if err != nil { + time.Sleep(500 * time.Millisecond) + continue + } + + var heightResult GetBlockHeightResponse + heightRespBytes, _ := json.Marshal(heightResp.Result) + if err := json.Unmarshal(heightRespBytes, &heightResult); err != nil { + time.Sleep(500 * time.Millisecond) + continue + } + + var latestBlock int64 + if _, err := fmt.Sscanf(heightResult.BlockNumber, "%d", &latestBlock); err != nil { + time.Sleep(500 * time.Millisecond) + continue + } + + safeLatestBlock := latestBlock - 1 + for currentCheckBlock <= safeLatestBlock && atomic.LoadInt64(&shardM.totalBlockCommitments) < shardSuccessful { + checkBlock(currentCheckBlock) + currentCheckBlock++ + } + + for blockNum, info := range blocksToRetry { + if time.Since(info.lastChecked) > 500*time.Millisecond { + checkBlock(blockNum) + info.lastChecked = time.Now() + info.retryCount++ + } + } + + if time.Since(lastReportTime) > 5*time.Second { + fmt.Printf("Still checking %s... found %d/%d commitments, %d blocks pending retry...\n", waitClient.url, atomic.LoadInt64(&shardM.totalBlockCommitments), shardSuccessful, len(blocksToRetry)) + lastReportTime = time.Now() + } + + if time.Since(lastProgressTime) > 90*time.Second && len(blocksToRetry) == 0 { + fmt.Printf("\nNo new commitments found for 90 seconds on %s, stopping...\n", waitClient.url) + break + } + + time.Sleep(200 * time.Millisecond) + } + } + } + + // Debug: count tracked IDs and find missing ones + trackedCount := 0 + foundCount := 0 + var sampleMissingIDs []string + + metrics.submittedRequestIDs.Range(func(key, value interface{}) bool { + trackedCount++ + requestID := key.(string) + info := value.(*requestInfo) + if info.Found == 1 { + foundCount++ + } else if len(sampleMissingIDs) < 5 { + sampleMissingIDs = append(sampleMissingIDs, requestID) + } + return true + }) + + fmt.Printf("\nDebug: Tracked %d request IDs, found in blocks: %d, missing: %d\n", trackedCount, foundCount, trackedCount-foundCount) + if len(sampleMissingIDs) > 0 { + fmt.Printf("Sample missing IDs:\n") + for i, id := range sampleMissingIDs { + fmt.Printf(" %d. 
%s\n", i+1, id) + } + } + + if foundCount < trackedCount { + fmt.Printf("\nFinished checking. Found %d/%d commitments\n", foundCount, trackedCount) + } else { + fmt.Printf("\nAll %d commitments have been found in blocks!\n", trackedCount) + } + // Final metrics + elapsed := time.Since(metrics.startTime) + + fmt.Printf("\n\n========================================\n") + fmt.Printf("PERFORMANCE TEST RESULTS\n") + fmt.Printf("========================================\n") + fmt.Printf("Duration: %v\n", elapsed.Truncate(time.Millisecond)) + + var total, successful, failed, exists, processedInBlocks int64 + + // Per-shard results + for url, shardM := range metrics.shardMetrics { + fmt.Printf("\n--- SHARD: %s ---\n", url) + + shardTotal := atomic.LoadInt64(&shardM.totalRequests) + shardSuccessful := atomic.LoadInt64(&shardM.successfulRequests) + shardFailed := atomic.LoadInt64(&shardM.failedRequests) + shardExists := atomic.LoadInt64(&shardM.requestIdExistsErr) + shardProcessedInBlocks := atomic.LoadInt64(&shardM.totalBlockCommitments) + + total += shardTotal + successful += shardSuccessful + failed += shardFailed + exists += shardExists + processedInBlocks += shardProcessedInBlocks + + fmt.Printf("Total requests: %d\n", shardTotal) + fmt.Printf("Successful requests: %d\n", shardSuccessful) + fmt.Printf("Failed requests: %d\n", shardFailed) + fmt.Printf("REQUEST_ID_EXISTS: %d\n", shardExists) + if elapsed.Seconds() > 0 { + fmt.Printf("Average RPS: %.2f\n", float64(shardTotal)/elapsed.Seconds()) + } + if shardTotal > 0 { + fmt.Printf("Success rate: %.2f%%\n", float64(shardSuccessful)/float64(shardTotal)*100) + } + + fmt.Printf("\nBLOCK PROCESSING:\n") + fmt.Printf("Total commitments in blocks: %d\n", shardProcessedInBlocks) + + pendingCommitments := shardSuccessful - shardProcessedInBlocks + if pendingCommitments > 0 { + if shardSuccessful > 0 { + percentage := float64(pendingCommitments) / float64(shardSuccessful) * 100 + fmt.Printf("\n⚠️ WARNING: %d commitments (%.1f%%) not found in blocks!\n", pendingCommitments, percentage) + } + } else if shardSuccessful > 0 { + fmt.Printf("\n✅ SUCCESS: All %d commitments were found in blocks!\n", shardSuccessful) + } + + fmt.Printf("\nBLOCK THROUGHPUT:\n") + shardM.mutex.RLock() + fmt.Printf("Total blocks checked: %d\n", len(shardM.blockCommitmentCounts)) + + emptyBlocks := 0 + nonEmptyBlocks := 0 + for _, count := range shardM.blockCommitmentCounts { + if count == 0 { + emptyBlocks++ + } else { + nonEmptyBlocks++ + } + } + shardM.mutex.RUnlock() + + if nonEmptyBlocks > 0 { + fmt.Printf("Non-empty blocks: %d (average %.1f commitments/block)\n", nonEmptyBlocks, float64(shardProcessedInBlocks)/float64(nonEmptyBlocks)) + } + if emptyBlocks > 0 { + fmt.Printf("Empty blocks: %d\n", emptyBlocks) + } + } + + // Aggregate results + fmt.Printf("\n\n--- AGGREGATE RESULTS ---\n") + fmt.Printf("Total requests: %d\n", total) + fmt.Printf("Successful requests: %d\n", successful) + fmt.Printf("Failed requests: %d\n", failed) + fmt.Printf("REQUEST_ID_EXISTS: %d\n", exists) + if elapsed.Seconds() > 0 { + fmt.Printf("Average RPS: %.2f\n", float64(total)/elapsed.Seconds()) + } + if total > 0 { + fmt.Printf("Success rate: %.2f%%\n", float64(successful)/float64(total)*100) + } + + fmt.Printf("\nBLOCK PROCESSING:\n") + fmt.Printf("Total commitments in blocks: %d\n", processedInBlocks) + + pendingCommitments := successful - processedInBlocks + if pendingCommitments > 0 { + if successful > 0 { + percentage := float64(pendingCommitments) / float64(successful) * 100 + fmt.Printf("\n⚠️ 
WARNING: %d commitments (%.1f%%) not found in blocks!\n", pendingCommitments, percentage) + } + } else if successful > 0 { + fmt.Printf("\n✅ SUCCESS: All %d commitments were found in blocks!\n", successful) + } + + fmt.Printf("========================================\n") + + verifyInclusionProofs(metrics, clients) +} + +func verifyInclusionProofs(metrics *Metrics, clients []*JSONRPCClient) { + fmt.Printf("\n\n========================================\n") + fmt.Printf("INCLUSION PROOF VERIFICATION\n") + fmt.Printf("========================================\n") + + clientMap := make(map[string]*JSONRPCClient) + for _, c := range clients { + clientMap[c.url] = c + } + + var totalToVerify, successfulVerifications, failedVerifications int64 + + // A channel to collect verification results + results := make(chan bool) + + var wg sync.WaitGroup + + const maxConcurrentVerifications = 100 + semaphore := make(chan struct{}, maxConcurrentVerifications) + + metrics.submittedRequestIDs.Range(func(key, value interface{}) bool { + requestIDStr := key.(string) + info := value.(*requestInfo) + + if info.Found == 1 { + wg.Add(1) + atomic.AddInt64(&totalToVerify, 1) + + go func(rid string, shardURL string) { + semaphore <- struct{}{} // Acquire token + defer func() { + <-semaphore // Release token + wg.Done() + }() + + client, ok := clientMap[shardURL] + if !ok { + fmt.Printf("ERROR: No client found for shard URL %s\n", shardURL) + results <- false + return + } + + // 1. Get inclusion proof + params := map[string]string{"requestId": rid} + resp, err := client.call("get_inclusion_proof", params) + if err != nil { + fmt.Printf("ERROR for %s: failed to get inclusion proof: %v\n", rid, err) + results <- false + return + } + if resp.Error != nil { + fmt.Printf("ERROR for %s: API error getting inclusion proof: %s\n", rid, resp.Error.Message) + results <- false + return + } + + var proofResp api.GetInclusionProofResponse + respBytes, _ := json.Marshal(resp.Result) + if err := json.Unmarshal(respBytes, &proofResp); err != nil { + fmt.Printf("ERROR for %s: failed to parse inclusion proof response: %v\n", rid, err) + results <- false + return + } + + if proofResp.InclusionProof == nil || proofResp.InclusionProof.MerkleTreePath == nil { + fmt.Printf("ERROR for %s: Inclusion proof or MerkleTreePath is nil\n", rid) + results <- false + return + } + + // 2. Get path from request ID + reqID := api.RequestID(rid) + reqIDPath, err := reqID.GetPath() + if err != nil { + fmt.Printf("ERROR for %s: failed to get path from request ID: %v\n", rid, err) + results <- false + return + } + + // 3. 
Verify proof + verifyResult, err := proofResp.InclusionProof.MerkleTreePath.Verify(reqIDPath) + if err != nil { + fmt.Printf("ERROR for %s: proof verification returned an error: %v\n", rid, err) + results <- false + return + } + + if verifyResult == nil || !verifyResult.Result { + fmt.Printf("FAILURE for %s: Proof verification failed.\n", rid) + results <- false + return + } + + // Success + results <- true + }(requestIDStr, info.URL) + } + return true + }) + + // Closer goroutine + go func() { + wg.Wait() + close(results) + }() + + // Collect results + for result := range results { + if result { + atomic.AddInt64(&successfulVerifications, 1) + } else { + atomic.AddInt64(&failedVerifications, 1) + } + } + + fmt.Printf("\nVerification Summary:\n") + fmt.Printf("Total commitments to verify: %d\n", totalToVerify) + fmt.Printf("Successful verifications: %d\n", successfulVerifications) + fmt.Printf("Failed verifications: %d\n", failedVerifications) + + if failedVerifications > 0 { + fmt.Printf("\n⚠️ WARNING: %d inclusion proof verifications failed!\n", failedVerifications) + } else if totalToVerify > 0 { + fmt.Printf("\n✅ SUCCESS: All %d inclusion proofs verified successfully!\n", totalToVerify) + } else { + fmt.Printf("\nNo commitments were processed, nothing to verify.\n") + } + fmt.Printf("========================================\n") +} diff --git a/cmd/sharding-perf-test/types.go b/cmd/sharding-perf-test/types.go new file mode 100644 index 0000000..fd3a7de --- /dev/null +++ b/cmd/sharding-perf-test/types.go @@ -0,0 +1,156 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "sync" + "sync/atomic" + "time" +) + +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params interface{} `json:"params"` + ID int `json:"id"` +} + +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + Result interface{} `json:"result,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` + ID int `json:"id"` +} + +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data string `json:"data,omitempty"` +} + +// Use the public API types + +type GetBlockResponse struct { + Block Block `json:"block"` +} + +type Block struct { + Index string `json:"index"` + Timestamp string `json:"timestamp"` +} + +type GetBlockCommitmentsRequest struct { + BlockNumber string `json:"blockNumber"` +} + +type GetBlockCommitmentsResponse struct { + Commitments []AggregatorRecord `json:"commitments"` +} + +type AggregatorRecord struct { + RequestID string `json:"requestId"` + BlockNumber string `json:"blockNumber"` +} + +type GetBlockHeightResponse struct { + BlockNumber string `json:"blockNumber"` +} + +type requestInfo struct { + URL string + Found int32 // 0 for not found, 1 for found +} + +// ShardMetrics holds all the counters and data for a single shard. +type ShardMetrics struct { + totalRequests int64 + successfulRequests int64 + failedRequests int64 + requestIdExistsErr int64 + blockCommitmentCounts []int + totalBlockCommitments int64 + mutex sync.RWMutex +} + +// Metrics holds all the metrics for the performance test. 
+type Metrics struct { + startTime time.Time + startingBlockNumbers map[string]int64 + submittedRequestIDs sync.Map + shardMetrics map[string]*ShardMetrics + mutex sync.RWMutex +} + +func (sm *ShardMetrics) addBlockCommitmentCount(count int) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.blockCommitmentCounts = append(sm.blockCommitmentCounts, count) +} + +type JSONRPCClient struct { + httpClient *http.Client + url string + authHeader string + requestID int64 +} + +func NewJSONRPCClient(url string, authHeader string) *JSONRPCClient { + // Configure transport with higher connection limits + transport := &http.Transport{ + MaxIdleConns: 1000, + MaxIdleConnsPerHost: 1000, + MaxConnsPerHost: 1000, + IdleConnTimeout: 90 * time.Second, + } + + return &JSONRPCClient{ + httpClient: &http.Client{ + Timeout: 30 * time.Second, + Transport: transport, + }, + url: url, + authHeader: authHeader, + } +} + +func (c *JSONRPCClient) call(method string, params interface{}) (*JSONRPCResponse, error) { + id := atomic.AddInt64(&c.requestID, 1) + + request := JSONRPCRequest{ + JSONRPC: "2.0", + Method: method, + Params: params, + ID: int(id), + } + + reqBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + // Add auth header if provided + if c.authHeader != "" { + req.Header.Set("Authorization", c.authHeader) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + var response JSONRPCResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return &response, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 1150eb5..1c84ffa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "os" "strconv" @@ -13,6 +14,8 @@ import ( "github.com/unicitynetwork/bft-core/partition" "github.com/unicitynetwork/bft-go-base/types" "github.com/unicitynetwork/bft-go-base/util" + + "github.com/unicitynetwork/aggregator-go/pkg/api" ) // Config represents the application configuration @@ -25,6 +28,15 @@ type Config struct { Logging LoggingConfig `mapstructure:"logging"` BFT BFTConfig `mapstructure:"bft"` Processing ProcessingConfig `mapstructure:"processing"` + Sharding ShardingConfig `mapstructure:"sharding"` + Chain ChainConfig `mapstructure:"chain"` +} + +// ChainConfig holds metadata about the current chain configuration +type ChainConfig struct { + ID string `mapstructure:"id"` + Version string `mapstructure:"version"` + ForkID string `mapstructure:"fork_id"` } // ServerConfig holds HTTP server configuration @@ -104,6 +116,89 @@ type StorageConfig struct { RedisMaxStreamLength int64 `mapstructure:"redis_max_stream_length"` } +// ShardingMode represents the aggregator operating mode +type ShardingMode string + +const ( + ShardingModeStandalone ShardingMode = "standalone" + ShardingModeParent ShardingMode = "parent" + ShardingModeChild ShardingMode = "child" +) + +// String returns the string representation of the sharding mode +func (sm ShardingMode) String() string { + return string(sm) +} + +// IsValid returns true if the sharding mode is valid +func (sm 
ShardingMode) IsValid() bool { + switch sm { + case ShardingModeStandalone, ShardingModeParent, ShardingModeChild: + return true + default: + return false + } +} + +// IsStandalone returns true if this is standalone mode +func (sm ShardingMode) IsStandalone() bool { + return sm == ShardingModeStandalone +} + +// IsParent returns true if this is parent mode +func (sm ShardingMode) IsParent() bool { + return sm == ShardingModeParent +} + +// IsChild returns true if this is child mode +func (sm ShardingMode) IsChild() bool { + return sm == ShardingModeChild +} + +// ShardingConfig holds sharding configuration +type ShardingConfig struct { + Mode ShardingMode `mapstructure:"mode"` // Operating mode: standalone, parent, or child + ShardIDLength int `mapstructure:"shard_id_length"` // Bit length for shard IDs (e.g., 4 bits = 16 shards) + Child ChildConfig `mapstructure:"child"` // child aggregator config +} + +type ChildConfig struct { + ParentRpcAddr string `mapstructure:"parent_rpc_addr"` + ShardID api.ShardID `mapstructure:"shard_id"` + ParentPollTimeout time.Duration `mapstructure:"parent_poll_timeout"` + ParentPollInterval time.Duration `mapstructure:"parent_poll_interval"` +} + +func (c ShardingConfig) Validate() error { + if c.Mode == ShardingModeChild { + if err := c.Child.Validate(); err != nil { + return fmt.Errorf("invalid child mode configuration: %w", err) + } + } + if c.Mode == ShardingModeStandalone { + if c.Child.ShardID != 0 { + return errors.New("shard_id must be undefined in standalone mode") + } + } + return nil +} + +func (c ChildConfig) Validate() error { + if c.ParentRpcAddr == "" { + return errors.New("parent rpc addr is required") + } + if c.ShardID <= 1 { + return errors.New("shard ID must be positive and have at least 2 bits") + } + if c.ParentPollTimeout == 0 { + return errors.New("parent poll timeout is required") + } + if c.ParentPollInterval == 0 { + return errors.New("parent poll interval is required") + } + return nil +} + type BFTConfig struct { Enabled bool `mapstructure:"enabled"` KeyConf *partition.KeyConf `mapstructure:"key_conf"` @@ -120,6 +215,11 @@ type BFTConfig struct { // Load loads configuration from environment variables with defaults func Load() (*Config, error) { config := &Config{ + Chain: ChainConfig{ + ID: getEnvOrDefault("CHAIN_ID", "unicity"), + Version: getEnvOrDefault("CHAIN_VERSION", "1.0"), + ForkID: getEnvOrDefault("CHAIN_FORK_ID", "testnet"), + }, Server: ServerConfig{ Port: getEnvOrDefault("PORT", "3000"), Host: getEnvOrDefault("HOST", "0.0.0.0"), @@ -183,6 +283,16 @@ func Load() (*Config, error) { RedisCleanupInterval: getEnvDurationOrDefault("REDIS_CLEANUP_INTERVAL", "5m"), RedisMaxStreamLength: int64(getEnvIntOrDefault("REDIS_MAX_STREAM_LENGTH", 1000000)), }, + Sharding: ShardingConfig{ + Mode: ShardingMode(getEnvOrDefault("SHARDING_MODE", "standalone")), + ShardIDLength: getEnvIntOrDefault("SHARD_ID_LENGTH", 4), + Child: ChildConfig{ + ParentRpcAddr: getEnvOrDefault("SHARDING_CHILD_PARENT_RPC_ADDR", "http://localhost:3009"), + ShardID: getEnvIntOrDefault("SHARDING_CHILD_SHARD_ID", 0), + ParentPollTimeout: getEnvDurationOrDefault("SHARDING_CHILD_PARENT_POLL_TIMEOUT", "5s"), + ParentPollInterval: getEnvDurationOrDefault("SHARDING_CHILD_PARENT_POLL_INTERVAL", "100ms"), + }, + }, } config.BFT = BFTConfig{ Enabled: getEnvBoolOrDefault("BFT_ENABLED", true), @@ -248,6 +358,19 @@ func (c *Config) Validate() error { return fmt.Errorf("invalid log level: %s", c.Logging.Level) } + // Validate sharding configuration + if 
!c.Sharding.Mode.IsValid() { + return fmt.Errorf("invalid sharding mode: %s, must be one of: standalone, parent, child", c.Sharding.Mode) + } + + if c.Sharding.ShardIDLength < 1 || c.Sharding.ShardIDLength > 16 { + return fmt.Errorf("shard ID length must be between 1 and 16 bits, got: %d", c.Sharding.ShardIDLength) + } + + if err := c.Sharding.Validate(); err != nil { + return fmt.Errorf("invalid sharding configuration: %w", err) + } + return nil } diff --git a/internal/gateway/docs_test.go b/internal/gateway/docs_test.go index e62659b..f8e1889 100644 --- a/internal/gateway/docs_test.go +++ b/internal/gateway/docs_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/internal/signing" "github.com/unicitynetwork/aggregator-go/pkg/api" @@ -99,10 +101,10 @@ func TestDocumentationExamplePayload(t *testing.T) { err = json.Unmarshal([]byte(exampleJSON), &commitment) require.NoError(t, err, "Failed to unmarshal commitment") - validator := signing.NewCommitmentValidator() + validator := signing.NewCommitmentValidator(config.ShardingConfig{Mode: config.ShardingModeStandalone}) result := validator.ValidateCommitment(&commitment) - require.Equal(t, signing.ValidationStatusSuccess, result.Status, + require.Equal(t, signing.ValidationStatusSuccess, result.Status, "Commitment validation failed with status: %s, error: %v", result.Status.String(), result.Error) t.Logf("✅ Documentation example payload passes all validation checks!") diff --git a/internal/gateway/handlers.go b/internal/gateway/handlers.go index d08d2dd..6e0b476 100644 --- a/internal/gateway/handlers.go +++ b/internal/gateway/handlers.go @@ -119,3 +119,48 @@ func (s *Server) handleGetBlockCommitments(ctx context.Context, params json.RawM return response, nil } + +// Parent mode handlers + +// handleSubmitShardRoot handles the submit_shard_root method +func (s *Server) handleSubmitShardRoot(ctx context.Context, params json.RawMessage) (interface{}, *jsonrpc.Error) { + var req api.SubmitShardRootRequest + if err := json.Unmarshal(params, &req); err != nil { + return nil, jsonrpc.NewValidationError("Invalid parameters: " + err.Error()) + } + + if req.ShardID <= 1 { + return nil, jsonrpc.NewValidationError("shard ID must be positive and have at least 2 bits") + } + if len(req.RootHash) == 0 { + return nil, jsonrpc.NewValidationError("rootHash is required") + } + + response, err := s.service.SubmitShardRoot(ctx, &req) + if err != nil { + s.logger.WithContext(ctx).Error("Failed to submit shard root", "error", err.Error()) + return nil, jsonrpc.NewError(jsonrpc.InternalErrorCode, "Failed to submit shard root", err.Error()) + } + + return response, nil +} + +// handleGetShardProof handles the get_shard_proof method +func (s *Server) handleGetShardProof(ctx context.Context, params json.RawMessage) (interface{}, *jsonrpc.Error) { + var req api.GetShardProofRequest + if err := json.Unmarshal(params, &req); err != nil { + return nil, jsonrpc.NewValidationError("Invalid parameters: " + err.Error()) + } + + if req.ShardID <= 1 { + return nil, jsonrpc.NewValidationError("shard ID must be positive and have at least 2 bits") + } + + response, err := s.service.GetShardProof(ctx, &req) + if err != nil { + s.logger.WithContext(ctx).Error("Failed to get shard proof", "error", err.Error()) + return nil, jsonrpc.NewError(jsonrpc.InternalErrorCode, "Failed to get shard proof", 
err.Error()) + } + + return response, nil +} diff --git a/internal/gateway/server.go b/internal/gateway/server.go index 36ebcd7..327917d 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -33,6 +33,10 @@ type Service interface { GetBlock(ctx context.Context, req *api.GetBlockRequest) (*api.GetBlockResponse, error) GetBlockCommitments(ctx context.Context, req *api.GetBlockCommitmentsRequest) (*api.GetBlockCommitmentsResponse, error) GetHealthStatus(ctx context.Context) (*api.HealthStatus, error) + + // Parent mode specific methods (will return errors in standalone mode) + SubmitShardRoot(ctx context.Context, req *api.SubmitShardRootRequest) (*api.SubmitShardRootResponse, error) + GetShardProof(ctx context.Context, req *api.GetShardProofRequest) (*api.GetShardProofResponse, error) } // NewServer creates a new gateway server @@ -117,9 +121,18 @@ func (s *Server) setupJSONRPCHandlers() { s.rpcServer.AddMiddleware(jsonrpc.LoggingMiddleware(s.logger)) s.rpcServer.AddMiddleware(jsonrpc.TimeoutMiddleware(30 * time.Second)) - // Register handlers - s.rpcServer.RegisterMethod("submit_commitment", s.handleSubmitCommitment) - s.rpcServer.RegisterMethod("get_inclusion_proof", s.handleGetInclusionProof) + // Register handlers based on mode + if s.config.Sharding.Mode.IsParent() { + // Parent mode handlers + s.rpcServer.RegisterMethod("submit_shard_root", s.handleSubmitShardRoot) + s.rpcServer.RegisterMethod("get_shard_proof", s.handleGetShardProof) + } else { + // Standalone mode handlers (default) + s.rpcServer.RegisterMethod("submit_commitment", s.handleSubmitCommitment) + s.rpcServer.RegisterMethod("get_inclusion_proof", s.handleGetInclusionProof) + } + + // Common handlers for all modes s.rpcServer.RegisterMethod("get_no_deletion_proof", s.handleGetNoDeletionProof) s.rpcServer.RegisterMethod("get_block_height", s.handleGetBlockHeight) s.rpcServer.RegisterMethod("get_block", s.handleGetBlock) diff --git a/internal/ha/block_syncer.go b/internal/ha/block_syncer.go index adb8a4d..8252465 100644 --- a/internal/ha/block_syncer.go +++ b/internal/ha/block_syncer.go @@ -18,14 +18,16 @@ type blockSyncer struct { logger *logger.Logger storage interfaces.Storage smt *smt.ThreadSafeSMT + shardID api.ShardID stateTracker *state.Tracker } -func newBlockSyncer(logger *logger.Logger, storage interfaces.Storage, smt *smt.ThreadSafeSMT, stateTracker *state.Tracker) *blockSyncer { +func newBlockSyncer(logger *logger.Logger, storage interfaces.Storage, smt *smt.ThreadSafeSMT, shardID api.ShardID, stateTracker *state.Tracker) *blockSyncer { return &blockSyncer{ logger: logger, storage: storage, smt: smt, + shardID: shardID, stateTracker: stateTracker, } } diff --git a/internal/ha/block_syncer_test.go b/internal/ha/block_syncer_test.go index f695ab6..fadaf1d 100644 --- a/internal/ha/block_syncer_test.go +++ b/internal/ha/block_syncer_test.go @@ -19,7 +19,7 @@ import ( ) func TestBlockSync(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, config.Config{ + storage := testutil.SetupTestStorage(t, config.Config{ Database: config.DatabaseConfig{ Database: "test_block_sync", ConnectTimeout: 30 * time.Second, @@ -30,7 +30,6 @@ func TestBlockSync(t *testing.T) { MaxConnIdleTime: 5 * time.Minute, }, }) - defer cleanup() ctx := context.Background() testLogger, err := logger.New("info", "text", "stdout", false) @@ -40,7 +39,7 @@ func TestBlockSync(t *testing.T) { smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) stateTracker := 
state.NewSyncStateTracker() - syncer := newBlockSyncer(testLogger, storage, threadSafeSMT, stateTracker) + syncer := newBlockSyncer(testLogger, storage, threadSafeSMT, 0, stateTracker) // simulate leader creating a block rootHash := createBlock(ctx, t, storage) @@ -95,7 +94,7 @@ func createBlock(ctx context.Context, t *testing.T, storage *mongodb.Storage) ap rootHash := api.NewHexBytes(tmpSMT.GetRootHash()) // persist block - block := models.NewBlock(blockNumber, "unicity", "1.0", "mainnet", rootHash, nil) + block := models.NewBlock(blockNumber, "unicity", 0, "1.0", "mainnet", rootHash, nil, nil, nil) err = storage.BlockStorage().Store(ctx, block) require.NoError(t, err) diff --git a/internal/ha/ha_manager.go b/internal/ha/ha_manager.go index 69f741f..38598f8 100644 --- a/internal/ha/ha_manager.go +++ b/internal/ha/ha_manager.go @@ -10,6 +10,7 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" + "github.com/unicitynetwork/aggregator-go/pkg/api" ) type ( @@ -34,7 +35,7 @@ type ( HAManager struct { logger *logger.Logger leaderSelector LeaderSelector - blockSyncer *blockSyncer + blockSyncer *blockSyncer // Optional: nil when block syncing is disabled activatable Activatable syncInterval time.Duration @@ -48,13 +49,20 @@ func NewHAManager(logger *logger.Logger, leaderSelector LeaderSelector, storage interfaces.Storage, smt *smt.ThreadSafeSMT, + shardID api.ShardID, stateTracker *state.Tracker, syncInterval time.Duration, + disableBlockSync bool, // Set true for parent mode where block syncing is not needed ) *HAManager { + var syncer *blockSyncer + if !disableBlockSync { + syncer = newBlockSyncer(logger, storage, smt, shardID, stateTracker) + } + return &HAManager{ logger: logger, leaderSelector: leaderSelector, - blockSyncer: newBlockSyncer(logger, storage, smt, stateTracker), + blockSyncer: syncer, activatable: activatable, syncInterval: syncInterval, } @@ -109,10 +117,17 @@ func (ham *HAManager) onTick(ctx context.Context, wasLeader bool) (bool, error) ham.logger.WithContext(ctx).Debug("leader is already being synced") return isLeader, nil } - if err := ham.blockSyncer.syncToLatestBlock(ctx); err != nil { - // Log the error but continue, as we might still need to handle a leadership change. - ham.logger.Error("failed to sync smt to latest block", "error", err) + + // Only sync blocks if blockSyncer is enabled (regular aggregator mode) + if ham.blockSyncer != nil { + if err := ham.blockSyncer.syncToLatestBlock(ctx); err != nil { + // Log the error but continue, as we might still need to handle a leadership change. 
+ ham.logger.Error("failed to sync smt to latest block", "error", err) + } + } else { + ham.logger.WithContext(ctx).Debug("block syncing disabled (parent mode), skipping SMT sync") } + if !wasLeader && isLeader { ham.logger.Info("Transitioning to LEADER") if err := ham.activatable.Activate(ctx); err != nil { diff --git a/internal/ha/ha_manager_test.go b/internal/ha/ha_manager_test.go index 90d26f7..4573e50 100644 --- a/internal/ha/ha_manager_test.go +++ b/internal/ha/ha_manager_test.go @@ -44,7 +44,7 @@ func (m *mockActivatable) Deactivate(_ context.Context) error { } func TestHAManager(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, config.Config{ + storage := testutil.SetupTestStorage(t, config.Config{ Database: config.DatabaseConfig{ Database: "test_block_sync", ConnectTimeout: 30 * time.Second, @@ -55,7 +55,6 @@ func TestHAManager(t *testing.T) { MaxConnIdleTime: 5 * time.Minute, }, }) - defer cleanup() ctx := context.Background() cfg := &config.Config{ @@ -72,7 +71,8 @@ func TestHAManager(t *testing.T) { callback := newMockActivatable() smtInstance := smt.NewThreadSafeSMT(smt.NewSparseMerkleTree(api.SHA256, 16+256)) stateTracker := state.NewSyncStateTracker() - ham := NewHAManager(testLogger, callback, mockLeader, storage, smtInstance, stateTracker, cfg.Processing.RoundDuration) + disableBlockSync := false + ham := NewHAManager(testLogger, callback, mockLeader, storage, smtInstance, 0, stateTracker, cfg.Processing.RoundDuration, disableBlockSync) // verify Activate/Deactivate has not been called initially require.Equal(t, int32(0), callback.activateCalled.Load(), "Activate should not be called initially") diff --git a/internal/ha/leader_election_test.go b/internal/ha/leader_election_test.go index 0248dd4..b904378 100644 --- a/internal/ha/leader_election_test.go +++ b/internal/ha/leader_election_test.go @@ -33,8 +33,7 @@ var conf = config.Config{ } func TestLeaderElection_LockContention(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, conf) - defer cleanup() + storage := testutil.SetupTestStorage(t, conf) log, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) @@ -78,8 +77,7 @@ func TestLeaderElection_LockContention(t *testing.T) { } func TestLeaderElection_Failover(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, conf) - defer cleanup() + storage := testutil.SetupTestStorage(t, conf) log, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) diff --git a/internal/models/aggregator_record.go b/internal/models/aggregator_record.go index f2b5217..7c5166c 100644 --- a/internal/models/aggregator_record.go +++ b/internal/models/aggregator_record.go @@ -2,41 +2,35 @@ package models import ( "fmt" - "strconv" "time" + "go.mongodb.org/mongo-driver/bson/primitive" + "github.com/unicitynetwork/aggregator-go/pkg/api" ) // AggregatorRecord represents a finalized commitment with proof data type AggregatorRecord struct { - RequestID api.RequestID `json:"requestId" bson:"requestId"` - TransactionHash api.TransactionHash `json:"transactionHash" bson:"transactionHash"` - Authenticator Authenticator `json:"authenticator" bson:"authenticator"` - AggregateRequestCount uint64 `json:"aggregateRequestCount" bson:"aggregateRequestCount"` - BlockNumber *api.BigInt `json:"blockNumber" bson:"blockNumber"` - LeafIndex *api.BigInt `json:"leafIndex" bson:"leafIndex"` - CreatedAt *api.Timestamp `json:"createdAt" bson:"createdAt"` - FinalizedAt *api.Timestamp `json:"finalizedAt" bson:"finalizedAt"` + RequestID 
api.RequestID `json:"requestId"` + TransactionHash api.TransactionHash `json:"transactionHash"` + Authenticator Authenticator `json:"authenticator"` + AggregateRequestCount uint64 `json:"aggregateRequestCount"` + BlockNumber *api.BigInt `json:"blockNumber"` + LeafIndex *api.BigInt `json:"leafIndex"` + CreatedAt *api.Timestamp `json:"createdAt"` + FinalizedAt *api.Timestamp `json:"finalizedAt"` } // AggregatorRecordBSON represents the BSON version of AggregatorRecord for MongoDB storage type AggregatorRecordBSON struct { - RequestID string `bson:"requestId"` - TransactionHash string `bson:"transactionHash"` - Authenticator AuthenticatorBSON `bson:"authenticator"` - AggregateRequestCount uint64 `bson:"aggregateRequestCount"` - BlockNumber string `bson:"blockNumber"` - LeafIndex string `bson:"leafIndex"` - CreatedAt string `bson:"createdAt"` - FinalizedAt string `bson:"finalizedAt"` -} - -type AuthenticatorBSON struct { - Algorithm string `bson:"algorithm"` - PublicKey string `bson:"publicKey"` - Signature string `bson:"signature"` - StateHash string `bson:"stateHash"` + RequestID string `bson:"requestId"` + TransactionHash string `bson:"transactionHash"` + Authenticator AuthenticatorBSON `bson:"authenticator"` + AggregateRequestCount uint64 `bson:"aggregateRequestCount"` + BlockNumber primitive.Decimal128 `bson:"blockNumber"` + LeafIndex primitive.Decimal128 `bson:"leafIndex"` + CreatedAt time.Time `bson:"createdAt"` + FinalizedAt time.Time `bson:"finalizedAt"` } // NewAggregatorRecord creates a new aggregator record from a commitment @@ -54,92 +48,58 @@ func NewAggregatorRecord(commitment *Commitment, blockNumber, leafIndex *api.Big } // ToBSON converts AggregatorRecord to AggregatorRecordBSON for MongoDB storage -func (ar *AggregatorRecord) ToBSON() *AggregatorRecordBSON { +func (ar *AggregatorRecord) ToBSON() (*AggregatorRecordBSON, error) { + blockNumber, err := primitive.ParseDecimal128(ar.BlockNumber.String()) + if err != nil { + return nil, fmt.Errorf("error converting block number to decimal-128: %w", err) + } + leafIndex, err := primitive.ParseDecimal128(ar.LeafIndex.String()) + if err != nil { + return nil, fmt.Errorf("error converting leaf index to decimal-128: %w", err) + } return &AggregatorRecordBSON{ - RequestID: string(ar.RequestID), - TransactionHash: string(ar.TransactionHash), - Authenticator: AuthenticatorBSON{ - Algorithm: ar.Authenticator.Algorithm, - PublicKey: ar.Authenticator.PublicKey.String(), - Signature: ar.Authenticator.Signature.String(), - StateHash: ar.Authenticator.StateHash.String(), - }, + RequestID: ar.RequestID.String(), + TransactionHash: ar.TransactionHash.String(), + Authenticator: ar.Authenticator.ToBSON(), AggregateRequestCount: ar.AggregateRequestCount, - BlockNumber: ar.BlockNumber.String(), - LeafIndex: ar.LeafIndex.String(), - CreatedAt: strconv.FormatInt(ar.CreatedAt.UnixMilli(), 10), - FinalizedAt: strconv.FormatInt(ar.FinalizedAt.UnixMilli(), 10), - } + BlockNumber: blockNumber, + LeafIndex: leafIndex, + CreatedAt: ar.CreatedAt.Time, + FinalizedAt: ar.FinalizedAt.Time, + }, nil } // FromBSON converts AggregatorRecordBSON back to AggregatorRecord func (arb *AggregatorRecordBSON) FromBSON() (*AggregatorRecord, error) { - //requestID, err := NewRequestID(arb.RequestID) - //if err != nil { - // return nil, fmt.Errorf("failed to parse requestID: %w", err) - //} - // - //transactionHash, err := NewTransactionHash(arb.TransactionHash) - //if err != nil { - // return nil, fmt.Errorf("failed to parse transactionHash: %w", err) - //} - // - 
blockNumber, err := api.NewBigIntFromString(arb.BlockNumber) + blockNumber, err := api.NewBigIntFromString(arb.BlockNumber.String()) if err != nil { return nil, fmt.Errorf("failed to parse blockNumber: %w", err) } - leafIndex, err := api.NewBigIntFromString(arb.LeafIndex) + leafIndex, err := api.NewBigIntFromString(arb.LeafIndex.String()) if err != nil { return nil, fmt.Errorf("failed to parse leafIndex: %w", err) } - publicKey, err := api.NewHexBytesFromString(arb.Authenticator.PublicKey) - if err != nil { - return nil, fmt.Errorf("failed to parse publicKey: %w", err) - } - - signature, err := api.NewHexBytesFromString(arb.Authenticator.Signature) - if err != nil { - return nil, fmt.Errorf("failed to parse signature: %w", err) - } - - //stateHash, err := NewHexBytesFromString(arb.Authenticator.StateHash) - //if err != nil { - // return nil, fmt.Errorf("failed to parse stateHash: %w", err) - //} - - createdAtMillis, err := strconv.ParseInt(arb.CreatedAt, 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse createdAt: %w", err) - } - createdAt := &api.Timestamp{Time: time.UnixMilli(createdAtMillis)} - - finalizedAtMillis, err := strconv.ParseInt(arb.FinalizedAt, 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse finalizedAt: %w", err) - } - finalizedAt := &api.Timestamp{Time: time.UnixMilli(finalizedAtMillis)} - // Default AggregateRequestCount to 1 if not present (backward compatibility) aggregateRequestCount := arb.AggregateRequestCount if aggregateRequestCount == 0 { aggregateRequestCount = 1 } + authenticatorBSON, err := arb.Authenticator.FromBSON() + if err != nil { + return nil, fmt.Errorf("failed to parse authenticator: %w", err) + } + return &AggregatorRecord{ - RequestID: api.RequestID(arb.RequestID), - TransactionHash: api.TransactionHash(arb.TransactionHash), - Authenticator: Authenticator{ - Algorithm: arb.Authenticator.Algorithm, - PublicKey: publicKey, - Signature: signature, - StateHash: api.StateHash(arb.Authenticator.StateHash), - }, + RequestID: api.RequestID(arb.RequestID), + TransactionHash: api.TransactionHash(arb.TransactionHash), + Authenticator: *authenticatorBSON, AggregateRequestCount: aggregateRequestCount, BlockNumber: blockNumber, LeafIndex: leafIndex, - CreatedAt: createdAt, - FinalizedAt: finalizedAt, + CreatedAt: api.NewTimestamp(arb.CreatedAt), + FinalizedAt: api.NewTimestamp(arb.FinalizedAt), }, nil } diff --git a/internal/models/aggregator_record_test.go b/internal/models/aggregator_record_test.go index 6d883b0..4d99749 100644 --- a/internal/models/aggregator_record_test.go +++ b/internal/models/aggregator_record_test.go @@ -1,30 +1,37 @@ -package models_test +package models import ( "testing" + "time" "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson/primitive" - "github.com/unicitynetwork/aggregator-go/internal/models" + "github.com/unicitynetwork/aggregator-go/pkg/api" ) func TestBackwardCompatibility(t *testing.T) { + blockNumber, err := primitive.ParseDecimal128("100") + require.NoError(t, err) + leafIndex, err := primitive.ParseDecimal128("4") + require.NoError(t, err) + t.Run("FromBSON defaults AggregateRequestCount to 1 when missing", func(t *testing.T) { // Simulate an old record without AggregateRequestCount - bsonRecord := &models.AggregatorRecordBSON{ + bsonRecord := &AggregatorRecordBSON{ RequestID: "0000a1b2c3d4e5f6789012345678901234567890123456789012345678901234567890", TransactionHash: "0000b1b2c3d4e5f6789012345678901234567890123456789012345678901234567890", - Authenticator: 
models.AuthenticatorBSON{ + Authenticator: AuthenticatorBSON{ Algorithm: "secp256k1", PublicKey: "02345678", Signature: "abcdef12", StateHash: "0000cd60", }, // AggregateRequestCount is intentionally not set (will be 0) - BlockNumber: "100", - LeafIndex: "5", - CreatedAt: "1700000000000", - FinalizedAt: "1700000001000", + BlockNumber: blockNumber, + LeafIndex: leafIndex, + CreatedAt: time.UnixMilli(1700000000000), + FinalizedAt: time.UnixMilli(1700000001000), } record, err := bsonRecord.FromBSON() @@ -37,20 +44,20 @@ func TestBackwardCompatibility(t *testing.T) { t.Run("FromBSON preserves AggregateRequestCount when present", func(t *testing.T) { // New record with AggregateRequestCount - bsonRecord := &models.AggregatorRecordBSON{ + bsonRecord := &AggregatorRecordBSON{ RequestID: "0000a1b2c3d4e5f6789012345678901234567890123456789012345678901234567890", TransactionHash: "0000b1b2c3d4e5f6789012345678901234567890123456789012345678901234567890", - Authenticator: models.AuthenticatorBSON{ + Authenticator: AuthenticatorBSON{ Algorithm: "secp256k1", PublicKey: "02345678", Signature: "abcdef12", StateHash: "0000cd60", }, AggregateRequestCount: 500, - BlockNumber: "100", - LeafIndex: "5", - CreatedAt: "1700000000000", - FinalizedAt: "1700000001000", + BlockNumber: blockNumber, + LeafIndex: leafIndex, + CreatedAt: time.UnixMilli(1700000000000), + FinalizedAt: time.UnixMilli(1700000001000), } record, err := bsonRecord.FromBSON() @@ -63,7 +70,7 @@ func TestBackwardCompatibility(t *testing.T) { t.Run("TotalCommitments calculation handles mixed old and new records", func(t *testing.T) { // Simulate a mix of old and new records - records := []*models.AggregatorRecord{ + records := []*AggregatorRecord{ // Old record (would have AggregateRequestCount = 0, treated as 1) {AggregateRequestCount: 1}, // New records with explicit counts @@ -84,3 +91,48 @@ func TestBackwardCompatibility(t *testing.T) { require.Equal(t, uint64(137), totalCommitments) }) } + +func TestAggregatorRecordSerialization(t *testing.T) { + // Create AggregatorRecord + originalRequestID := api.RequestID("0000a1b2c3d4e5f6789012345678901234567890123456789012345678901234567890") + originalTransactionHash := api.TransactionHash("0000b1b2c3d4e5f6789012345678901234567890123456789012345678901234567890") + originalBlockNumber, err := api.NewBigIntFromString("123") + require.NoError(t, err) + originalLeafIndex, err := api.NewBigIntFromString("456") + require.NoError(t, err) + record := &AggregatorRecord{ + RequestID: originalRequestID, + TransactionHash: originalTransactionHash, + Authenticator: Authenticator{ + Algorithm: "secp256k1", + PublicKey: api.HexBytes("02345678"), + Signature: api.HexBytes("abcdef12"), + StateHash: api.StateHash("0000cd60"), + }, + AggregateRequestCount: 1, + BlockNumber: originalBlockNumber, + LeafIndex: originalLeafIndex, + CreatedAt: api.Now(), + FinalizedAt: api.Now(), + } + + // Convert to BSON + bsonRecord, err := record.ToBSON() + require.NoError(t, err) + require.NotNil(t, bsonRecord) + + // Verify RequestID and TransactionHash in BSON format + require.Equal(t, string(originalRequestID), bsonRecord.RequestID) + require.Equal(t, string(originalTransactionHash), bsonRecord.TransactionHash) + + // Convert back from BSON + unmarshaledRecord, err := bsonRecord.FromBSON() + require.NoError(t, err) + require.NotNil(t, unmarshaledRecord) + + // Verify RequestID and TransactionHash are preserved + require.Equal(t, originalRequestID, unmarshaledRecord.RequestID) + require.Equal(t, originalTransactionHash, 
unmarshaledRecord.TransactionHash) + require.Equal(t, originalBlockNumber.String(), unmarshaledRecord.BlockNumber.String()) + require.Equal(t, originalLeafIndex.String(), unmarshaledRecord.LeafIndex.String()) +} diff --git a/internal/models/authenticator.go b/internal/models/authenticator.go new file mode 100644 index 0000000..3d5935c --- /dev/null +++ b/internal/models/authenticator.go @@ -0,0 +1,59 @@ +package models + +import ( + "fmt" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// Authenticator represents the authentication data for a commitment +type Authenticator struct { + Algorithm string `json:"algorithm" bson:"algorithm"` + PublicKey api.HexBytes `json:"publicKey" bson:"publicKey"` + Signature api.HexBytes `json:"signature" bson:"signature"` + StateHash api.StateHash `json:"stateHash" bson:"stateHash"` +} + +type AuthenticatorBSON struct { + Algorithm string `bson:"algorithm"` + PublicKey string `bson:"publicKey"` + Signature string `bson:"signature"` + StateHash string `bson:"stateHash"` +} + +func (a *Authenticator) ToAPI() *api.Authenticator { + return &api.Authenticator{ + Algorithm: a.Algorithm, + PublicKey: a.PublicKey, + Signature: a.Signature, + StateHash: a.StateHash, + } +} + +func (a *Authenticator) ToBSON() AuthenticatorBSON { + return AuthenticatorBSON{ + Algorithm: a.Algorithm, + PublicKey: a.PublicKey.String(), + Signature: a.Signature.String(), + StateHash: a.StateHash.String(), + } +} + +func (ab *AuthenticatorBSON) FromBSON() (*Authenticator, error) { + publicKey, err := api.NewHexBytesFromString(ab.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to parse publicKey: %w", err) + } + + signature, err := api.NewHexBytesFromString(ab.Signature) + if err != nil { + return nil, fmt.Errorf("failed to parse signature: %w", err) + } + + return &Authenticator{ + Algorithm: ab.Algorithm, + PublicKey: publicKey, + Signature: signature, + StateHash: api.StateHash(ab.StateHash), + }, nil +} diff --git a/internal/models/block.go b/internal/models/block.go index fcf153d..d81715c 100644 --- a/internal/models/block.go +++ b/internal/models/block.go @@ -1,8 +1,8 @@ package models import ( + "encoding/json" "fmt" - "strconv" "time" "go.mongodb.org/mongo-driver/bson/primitive" @@ -12,49 +12,61 @@ import ( // Block represents a blockchain block type Block struct { - Index *api.BigInt `json:"index" bson:"index"` - ChainID string `json:"chainId" bson:"chainId"` - Version string `json:"version" bson:"version"` - ForkID string `json:"forkId" bson:"forkId"` - RootHash api.HexBytes `json:"rootHash" bson:"rootHash"` - PreviousBlockHash api.HexBytes `json:"previousBlockHash" bson:"previousBlockHash"` - NoDeletionProofHash api.HexBytes `json:"noDeletionProofHash" bson:"noDeletionProofHash,omitempty"` - CreatedAt *api.Timestamp `json:"createdAt" bson:"createdAt"` - UnicityCertificate api.HexBytes `json:"unicityCertificate" bson:"unicityCertificate"` + Index *api.BigInt `json:"index"` + ChainID string `json:"chainId"` + ShardID api.ShardID `json:"shardId"` + Version string `json:"version"` + ForkID string `json:"forkId"` + RootHash api.HexBytes `json:"rootHash"` + PreviousBlockHash api.HexBytes `json:"previousBlockHash"` + NoDeletionProofHash api.HexBytes `json:"noDeletionProofHash"` + CreatedAt *api.Timestamp `json:"createdAt"` + UnicityCertificate api.HexBytes `json:"unicityCertificate"` + ParentMerkleTreePath *api.MerkleTreePath `json:"parentMerkleTreePath,omitempty"` // child mode only } // BlockBSON represents the BSON version of Block for MongoDB storage type 
BlockBSON struct { Index primitive.Decimal128 `bson:"index"` ChainID string `bson:"chainId"` + ShardID api.ShardID `bson:"shardId"` Version string `bson:"version"` ForkID string `bson:"forkId"` RootHash string `bson:"rootHash"` PreviousBlockHash string `bson:"previousBlockHash"` NoDeletionProofHash string `bson:"noDeletionProofHash,omitempty"` - CreatedAt string `bson:"createdAt"` + CreatedAt time.Time `bson:"createdAt"` UnicityCertificate string `bson:"unicityCertificate"` + MerkleTreePath string `bson:"merkleTreePath,omitempty"` // child mode only } // ToBSON converts Block to BlockBSON for MongoDB storage -func (b *Block) ToBSON() *BlockBSON { +func (b *Block) ToBSON() (*BlockBSON, error) { indexDecimal, err := primitive.ParseDecimal128(b.Index.String()) if err != nil { - // This should never happen with valid BigInt, but fallback to zero - indexDecimal = primitive.NewDecimal128(0, 0) + return nil, fmt.Errorf("error converting block index to decimal-128: %w", err) + } + var merkleTreePath string + if b.ParentMerkleTreePath != nil { + merkleTreePathJson, err := json.Marshal(b.ParentMerkleTreePath) + if err != nil { + return nil, fmt.Errorf("failed to marshal parent merkle tree path: %w", err) + } + merkleTreePath = api.NewHexBytes(merkleTreePathJson).String() } - return &BlockBSON{ Index: indexDecimal, ChainID: b.ChainID, + ShardID: b.ShardID, Version: b.Version, ForkID: b.ForkID, RootHash: b.RootHash.String(), PreviousBlockHash: b.PreviousBlockHash.String(), NoDeletionProofHash: b.NoDeletionProofHash.String(), - CreatedAt: strconv.FormatInt(b.CreatedAt.UnixMilli(), 10), + CreatedAt: b.CreatedAt.Time, UnicityCertificate: b.UnicityCertificate.String(), - } + MerkleTreePath: merkleTreePath, + }, nil } // FromBSON converts BlockBSON back to Block @@ -64,12 +76,6 @@ func (bb *BlockBSON) FromBSON() (*Block, error) { return nil, fmt.Errorf("failed to parse index: %w", err) } - createdAtMillis, err := strconv.ParseInt(bb.CreatedAt, 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse createdAt: %w", err) - } - createdAt := &api.Timestamp{Time: time.UnixMilli(createdAtMillis)} - rootHash, err := api.NewHexBytesFromString(bb.RootHash) if err != nil { return nil, fmt.Errorf("failed to parse rootHash: %w", err) @@ -85,51 +91,50 @@ func (bb *BlockBSON) FromBSON() (*Block, error) { return nil, fmt.Errorf("failed to parse unicityCertificate: %w", err) } + var parentMerkleTreePath *api.MerkleTreePath + if bb.MerkleTreePath != "" { + hexBytes, err := api.NewHexBytesFromString(bb.MerkleTreePath) + if err != nil { + return nil, fmt.Errorf("failed to parse parentMerkleTreePath: %w", err) + } + parentMerkleTreePath = &api.MerkleTreePath{} + if err := json.Unmarshal(hexBytes, parentMerkleTreePath); err != nil { + return nil, fmt.Errorf("failed to parse parentMerkleTreePath: %w", err) + } + } + noDeletionProofHash, err := api.NewHexBytesFromString(bb.NoDeletionProofHash) if err != nil { return nil, fmt.Errorf("failed to parse noDeletionProofHash: %w", err) } - block := &Block{ - Index: index, - ChainID: bb.ChainID, - Version: bb.Version, - ForkID: bb.ForkID, - RootHash: rootHash, - PreviousBlockHash: previousBlockHash, - CreatedAt: createdAt, - UnicityCertificate: unicityCertificate, - NoDeletionProofHash: noDeletionProofHash, - } - - return block, nil + return &Block{ + Index: index, + ChainID: bb.ChainID, + ShardID: bb.ShardID, + Version: bb.Version, + ForkID: bb.ForkID, + RootHash: rootHash, + PreviousBlockHash: previousBlockHash, + NoDeletionProofHash: noDeletionProofHash, + CreatedAt: 
api.NewTimestamp(bb.CreatedAt), + UnicityCertificate: unicityCertificate, + ParentMerkleTreePath: parentMerkleTreePath, + }, nil } // NewBlock creates a new block -func NewBlock(index *api.BigInt, chainID, version, forkID string, rootHash, previousBlockHash api.HexBytes) *Block { +func NewBlock(index *api.BigInt, chainID string, shardID api.ShardID, version, forkID string, rootHash, previousBlockHash, uc api.HexBytes, parentMerkleTreePath *api.MerkleTreePath) *Block { return &Block{ - Index: index, - ChainID: chainID, - Version: version, - ForkID: forkID, - RootHash: rootHash, - PreviousBlockHash: previousBlockHash, - CreatedAt: api.Now(), - } -} - -// SmtNode represents a Sparse Merkle Tree node -type SmtNode struct { - Key api.HexBytes `json:"key" bson:"key"` - Value api.HexBytes `json:"value" bson:"value"` - CreatedAt *api.Timestamp `json:"createdAt" bson:"createdAt"` -} - -// NewSmtNode creates a new SMT node -func NewSmtNode(key, value api.HexBytes) *SmtNode { - return &SmtNode{ - Key: key, - Value: value, - CreatedAt: api.Now(), + Index: index, + ChainID: chainID, + ShardID: shardID, + Version: version, + ForkID: forkID, + RootHash: rootHash, + PreviousBlockHash: previousBlockHash, + CreatedAt: api.Now(), + UnicityCertificate: uc, + ParentMerkleTreePath: parentMerkleTreePath, } } diff --git a/internal/models/block_record.go b/internal/models/block_record.go index 4eb19c4..980ee59 100644 --- a/internal/models/block_record.go +++ b/internal/models/block_record.go @@ -2,7 +2,6 @@ package models import ( "fmt" - "strconv" "time" "go.mongodb.org/mongo-driver/bson/primitive" @@ -20,8 +19,8 @@ type BlockRecords struct { // BlockRecordsBSON is the MongoDB representation of BlockRecords type BlockRecordsBSON struct { BlockNumber primitive.Decimal128 `bson:"blockNumber"` - RequestIDs []api.RequestID `bson:"requestIds"` - CreatedAt string `bson:"createdAt"` + RequestIDs []string `bson:"requestIds"` + CreatedAt time.Time `bson:"createdAt"` } // NewBlockRecords creates a new block records entry @@ -34,36 +33,39 @@ func NewBlockRecords(blockNumber *api.BigInt, requestIDs []api.RequestID) *Block } // ToBSON converts BlockRecords to BlockRecordsBSON -func (br *BlockRecords) ToBSON() *BlockRecordsBSON { +func (br *BlockRecords) ToBSON() (*BlockRecordsBSON, error) { blockNumberDecimal, err := primitive.ParseDecimal128(br.BlockNumber.String()) if err != nil { - // This should never happen with valid BigInt, but fallback to zero - blockNumberDecimal = primitive.NewDecimal128(0, 0) + return nil, fmt.Errorf("error converting block number to decimal: %w", err) + } + + requestIDs := make([]string, len(br.RequestIDs)) + for i, r := range br.RequestIDs { + requestIDs[i] = r.String() } return &BlockRecordsBSON{ BlockNumber: blockNumberDecimal, - RequestIDs: br.RequestIDs, - CreatedAt: strconv.FormatInt(br.CreatedAt.UnixMilli(), 10), - } + RequestIDs: requestIDs, + CreatedAt: br.CreatedAt.Time, + }, nil } // FromBSON converts BlockRecordsBSON to BlockRecords func (brb *BlockRecordsBSON) FromBSON() (*BlockRecords, error) { - blockNumber, err := api.NewBigIntFromString(brb.BlockNumber.String()) + blockNumber, _, err := brb.BlockNumber.BigInt() if err != nil { return nil, fmt.Errorf("failed to parse blockNumber: %w", err) } - createdAtMillis, err := strconv.ParseInt(brb.CreatedAt, 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse createdAt: %w", err) + requestIDs := make([]api.RequestID, len(brb.RequestIDs)) + for i, r := range brb.RequestIDs { + requestIDs[i] = api.RequestID(r) } - createdAt := 
&api.Timestamp{Time: time.UnixMilli(createdAtMillis)} return &BlockRecords{ - BlockNumber: blockNumber, - RequestIDs: brb.RequestIDs, - CreatedAt: createdAt, + BlockNumber: api.NewBigInt(blockNumber), + RequestIDs: requestIDs, + CreatedAt: api.NewTimestamp(brb.CreatedAt), }, nil } diff --git a/internal/models/block_test.go b/internal/models/block_test.go new file mode 100644 index 0000000..777962b --- /dev/null +++ b/internal/models/block_test.go @@ -0,0 +1,42 @@ +package models + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +func TestBlock_ToBSONFromBSON(t *testing.T) { + block := createTestBlock() + + blockBSON, err := block.ToBSON() + require.NoError(t, err) + + blockFromBSON, err := blockBSON.FromBSON() + require.NoError(t, err) + + require.Equal(t, block, blockFromBSON) +} + +func createTestBlock() *Block { + randomHash, _ := api.NewHexBytesFromString("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + + return &Block{ + Index: api.NewBigInt(big.NewInt(1)), + ChainID: "test-chain-id", + ShardID: 0b11, + Version: "1.0.0", + ForkID: "test-fork", + RootHash: randomHash, + PreviousBlockHash: randomHash, + NoDeletionProofHash: randomHash, + CreatedAt: api.Now(), + UnicityCertificate: randomHash, + ParentMerkleTreePath: &api.MerkleTreePath{ + Root: randomHash.String(), + }, + } +} diff --git a/internal/models/commitment.go b/internal/models/commitment.go index 0540525..f47c940 100644 --- a/internal/models/commitment.go +++ b/internal/models/commitment.go @@ -3,6 +3,7 @@ package models import ( "crypto/sha256" "fmt" + "time" "github.com/fxamacker/cbor/v2" "go.mongodb.org/mongo-driver/bson/primitive" @@ -22,12 +23,15 @@ type Commitment struct { StreamID string `json:"-" bson:"-"` // Redis stream ID used for stream acknowledgements } -// Authenticator represents the authentication data for a commitment -type Authenticator struct { - Algorithm string `json:"algorithm" bson:"algorithm"` - PublicKey api.HexBytes `json:"publicKey" bson:"publicKey"` - Signature api.HexBytes `json:"signature" bson:"signature"` - StateHash api.StateHash `json:"stateHash" bson:"stateHash"` +// CommitmentBSON represents the BSON version of Commitment for MongoDB storage +type CommitmentBSON struct { + ID primitive.ObjectID `bson:"_id,omitempty"` + RequestID string `bson:"requestId"` + TransactionHash string `bson:"transactionHash"` + Authenticator AuthenticatorBSON `bson:"authenticator"` + AggregateRequestCount uint64 `bson:"aggregateRequestCount"` + CreatedAt time.Time `bson:"createdAt"` + ProcessedAt *time.Time `bson:"processedAt,omitempty"` } // NewCommitment creates a new commitment @@ -52,6 +56,44 @@ func NewCommitmentWithAggregate(requestID api.RequestID, transactionHash api.Tra } } +// ToBSON converts Commitment to CommitmentBSON for MongoDB storage +func (c *Commitment) ToBSON() *CommitmentBSON { + var processedAt *time.Time + if c.ProcessedAt != nil { + processedAt = &c.ProcessedAt.Time + } + return &CommitmentBSON{ + ID: c.ID, + RequestID: c.RequestID.String(), + TransactionHash: c.TransactionHash.String(), + Authenticator: c.Authenticator.ToBSON(), + AggregateRequestCount: c.AggregateRequestCount, + CreatedAt: c.CreatedAt.Time, + ProcessedAt: processedAt, + } +} + +// FromBSON converts CommitmentBSON back to Commitment +func (cb *CommitmentBSON) FromBSON() (*Commitment, error) { + var processedAt *api.Timestamp + if cb.ProcessedAt != nil { + processedAt = api.NewTimestamp(*cb.ProcessedAt) + } + 
authenticator, err := cb.Authenticator.FromBSON() + if err != nil { + return nil, err + } + return &Commitment{ + ID: cb.ID, + RequestID: api.RequestID(cb.RequestID), + TransactionHash: api.TransactionHash(cb.TransactionHash), + Authenticator: *authenticator, + AggregateRequestCount: cb.AggregateRequestCount, + CreatedAt: api.NewTimestamp(cb.CreatedAt), + ProcessedAt: processedAt, + }, nil +} + // CreateLeafValue creates the value to store in the SMT leaf for a commitment // This matches the TypeScript LeafValue.create() method exactly: // - CBOR encode the authenticator as an array [algorithm, publicKey, signature, stateHashImprint] diff --git a/internal/models/health_status.go b/internal/models/health_status.go index a5f6fb1..ad4e252 100644 --- a/internal/models/health_status.go +++ b/internal/models/health_status.go @@ -1,19 +1,23 @@ package models +import "github.com/unicitynetwork/aggregator-go/pkg/api" + // HealthStatus represents the health status of the service type HealthStatus struct { Status string `json:"status"` Role string `json:"role"` ServerID string `json:"serverId"` + Sharding api.Sharding `json:"sharding"` Details map[string]string `json:"details,omitempty"` } // NewHealthStatus creates a new health status -func NewHealthStatus(role, serverID string) *HealthStatus { +func NewHealthStatus(role, serverID string, sharding api.Sharding) *HealthStatus { return &HealthStatus{ Status: "ok", Role: role, ServerID: serverID, + Sharding: sharding, Details: make(map[string]string), } } diff --git a/internal/models/shard.go b/internal/models/shard.go new file mode 100644 index 0000000..b0c4b6d --- /dev/null +++ b/internal/models/shard.go @@ -0,0 +1,49 @@ +package models + +import ( + "fmt" + "math/big" + "time" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// ShardRootUpdate represents an incoming shard root submission from a child aggregator. 
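+// The parent round manager keys these updates by shard ID and keeps only the
+// latest update per shard within a round, so a newer submission from the same
+// shard replaces the earlier one.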
+type ShardRootUpdate struct { + ShardID api.ShardID + RootHash api.HexBytes // Raw root hash from child SMT + Timestamp time.Time +} + +// NewShardRootUpdate creates a new shard root update +func NewShardRootUpdate(shardID api.ShardID, rootHash api.HexBytes) *ShardRootUpdate { + return &ShardRootUpdate{ + ShardID: shardID, + RootHash: rootHash, + Timestamp: time.Now(), + } +} + +// GetPath returns the shard ID as a big.Int for SMT operations +func (sru *ShardRootUpdate) GetPath() *big.Int { + return new(big.Int).SetInt64(int64(sru.ShardID)) +} + +// Validate validates the shard root update +func (sru *ShardRootUpdate) Validate() error { + if sru.ShardID <= 1 { + return fmt.Errorf("shard ID must be positive and have at least 2 bits") + } + + if len(sru.RootHash) == 0 { + return fmt.Errorf("root hash cannot be empty") + } + + return nil +} + +// String returns a string representation of the shard root update +func (sru *ShardRootUpdate) String() string { + return fmt.Sprintf("ShardRootUpdate{ShardID: %d, RootHash: %s, Timestamp: %v}", + sru.ShardID, sru.RootHash.String(), sru.Timestamp) +} diff --git a/internal/models/smt_node.go b/internal/models/smt_node.go new file mode 100644 index 0000000..b780c1a --- /dev/null +++ b/internal/models/smt_node.go @@ -0,0 +1,54 @@ +package models + +import ( + "fmt" + "time" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// SmtNode represents a Sparse Merkle Tree node +type SmtNode struct { + Key api.HexBytes `json:"key"` + Value api.HexBytes `json:"value"` + CreatedAt *api.Timestamp `json:"createdAt"` +} + +type SmtNodeBSON struct { + Key string `bson:"key"` + Value string `bson:"value"` + CreatedAt time.Time `bson:"createdAt"` +} + +func NewSmtNode(key, value api.HexBytes) *SmtNode { + return &SmtNode{ + Key: key, + Value: value, + CreatedAt: api.Now(), + } +} + +func (n *SmtNode) ToBSON() *SmtNodeBSON { + return &SmtNodeBSON{ + Key: n.Key.String(), + Value: n.Value.String(), + CreatedAt: n.CreatedAt.Time, + } +} + +func (nb *SmtNodeBSON) FromBSON() (*SmtNode, error) { + key, err := api.NewHexBytesFromString(nb.Key) + if err != nil { + return nil, fmt.Errorf("failed to parse key: %w", err) + } + + val, err := api.NewHexBytesFromString(nb.Value) + if err != nil { + return nil, fmt.Errorf("failed to parse value: %w", err) + } + return &SmtNode{ + Key: key, + Value: val, + CreatedAt: api.NewTimestamp(nb.CreatedAt), + }, nil +} diff --git a/internal/round/adaptive_timing_test.go b/internal/round/adaptive_timing_test.go index 0133798..294d6e3 100644 --- a/internal/round/adaptive_timing_test.go +++ b/internal/round/adaptive_timing_test.go @@ -15,7 +15,7 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/models" - "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" + "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/pkg/api" ) @@ -36,7 +36,7 @@ func TestAdaptiveProcessingRatio(t *testing.T) { require.NoError(t, err) // Create round manager - rm, err := NewRoundManager(context.Background(), cfg, testLogger, nil, nil, state.NewSyncStateTracker()) + rm, err := NewRoundManager(context.Background(), cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), nil, nil, nil, state.NewSyncStateTracker()) require.NoError(t, err) // Test initial values @@ -131,7 +131,7 @@ func TestAdaptiveDeadlineCalculation(t *testing.T) { testLogger, err := logger.New("info", 
"text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(context.Background(), cfg, testLogger, nil, nil, state.NewSyncStateTracker()) + rm, err := NewRoundManager(context.Background(), cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), nil, nil, nil, state.NewSyncStateTracker()) require.NoError(t, err) tests := []struct { @@ -198,7 +198,7 @@ func TestSMTUpdateTimeTracking(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(context.Background(), cfg, testLogger, nil, nil, state.NewSyncStateTracker()) + rm, err := NewRoundManager(context.Background(), cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), nil, nil, nil, state.NewSyncStateTracker()) require.NoError(t, err) ctx := context.Background() @@ -227,8 +227,8 @@ func TestSMTUpdateTimeTracking(t *testing.T) { TransactionHash: api.ImprintHexString("0000" + hex.EncodeToString(txHashBytes)), Authenticator: models.Authenticator{ Algorithm: "secp256k1", - PublicKey: api.HexBytes(append([]byte{0x02}, make([]byte, 32)...)), - Signature: api.HexBytes(make([]byte, 65)), + PublicKey: append([]byte{0x02}, make([]byte, 32)...), + Signature: make([]byte, 65), StateHash: api.ImprintHexString("0000" + hex.EncodeToString(make([]byte, 32))), }, } @@ -261,7 +261,7 @@ func TestStreamingMetrics(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(context.Background(), cfg, testLogger, nil, nil, state.NewSyncStateTracker()) + rm, err := NewRoundManager(context.Background(), cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), nil, nil, nil, state.NewSyncStateTracker()) require.NoError(t, err) // Set some test values @@ -314,7 +314,7 @@ func TestAdaptiveTimingIntegration(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(context.Background(), cfg, testLogger, nil, nil, state.NewSyncStateTracker()) + rm, err := NewRoundManager(context.Background(), cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), nil, nil, nil, state.NewSyncStateTracker()) require.NoError(t, err) ctx := context.Background() @@ -356,10 +356,3 @@ func TestAdaptiveTimingIntegration(t *testing.T) { }) } } - -// createMockStorage creates a mock storage for testing -func createMockStorage(t *testing.T) interfaces.Storage { - // This is a simplified mock - in production you'd use a proper mock - // For now, we'll return nil which is fine since BFT is disabled - return nil -} diff --git a/internal/round/batch_processor.go b/internal/round/batch_processor.go index 977a227..9b488d9 100644 --- a/internal/round/batch_processor.go +++ b/internal/round/batch_processor.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" @@ -103,30 +104,107 @@ func (rm *RoundManager) proposeBlock(ctx context.Context, blockNumber *api.BigIn return fmt.Errorf("failed to parse root hash %s: %w", rootHash, err) } - block := models.NewBlock( - blockNumber, - "unicity", - "1.0", - "mainnet", - rootHashBytes, - parentHash, - ) - - rm.logger.WithContext(ctx).Info("Sending certification request to BFT client", - "blockNumber", blockNumber.String(), - "bftClientType", fmt.Sprintf("%T", rm.bftClient)) - - if 
err := rm.bftClient.CertificationRequest(ctx, block); err != nil { - rm.logger.WithContext(ctx).Error("Failed to send certification request", + switch rm.config.Sharding.Mode { + case config.ShardingModeStandalone: + block := models.NewBlock( + blockNumber, + rm.config.Chain.ID, + 0, + rm.config.Chain.Version, + rm.config.Chain.ForkID, + rootHashBytes, + parentHash, + nil, + nil, + ) + rm.logger.WithContext(ctx).Info("Sending certification request to BFT client", "blockNumber", blockNumber.String(), - "error", err.Error()) - return fmt.Errorf("failed to send certification request: %w", err) - } + "bftClientType", fmt.Sprintf("%T", rm.bftClient)) + if err := rm.bftClient.CertificationRequest(ctx, block); err != nil { + rm.logger.WithContext(ctx).Error("Failed to send certification request", + "blockNumber", blockNumber.String(), + "error", err.Error()) + return fmt.Errorf("failed to send certification request: %w", err) + } + rm.logger.WithContext(ctx).Info("Certification request sent successfully", + "blockNumber", blockNumber.String()) + return nil + case config.ShardingModeChild: + rm.logger.WithContext(ctx).Info("Submitting root hash to parent shard", "rootHash", rootHash) + + // Strip algorithm prefix (first 2 bytes) before sending to parent + // Parent SMT stores raw 32-byte hashes, not the full 34-byte format with algorithm ID + // This is required for JoinPaths to work correctly when combining child and parent proofs + if len(rootHashBytes) < 2 { + return fmt.Errorf("root hash too short: expected at least 2 bytes for algorithm prefix, got %d", len(rootHashBytes)) + } + rootHashRaw := rootHashBytes[2:] // Remove algorithm identifier + if len(rootHashRaw) != 32 { + return fmt.Errorf("child root hash has invalid length after stripping prefix: expected 32 bytes, got %d", len(rootHashRaw)) + } - rm.logger.WithContext(ctx).Info("Certification request sent successfully", - "blockNumber", blockNumber.String()) + request := &api.SubmitShardRootRequest{ + ShardID: rm.config.Sharding.Child.ShardID, + RootHash: rootHashRaw, + } + if err := rm.rootClient.SubmitShardRoot(ctx, request); err != nil { + return fmt.Errorf("failed to submit root hash to parent shard: %w", err) + } + rm.logger.Info("Root hash submitted to parent, polling for inclusion proof...", "rootHash", rootHashRaw.String()) + proof, err := rm.pollInclusionProof(ctx, rootHashRaw.String()) + if err != nil { + return fmt.Errorf("failed to poll for parent shard inclusion proof: %w", err) + } + block := models.NewBlock( + blockNumber, + rm.config.Chain.ID, + request.ShardID, + rm.config.Chain.Version, + rm.config.Chain.ForkID, + rootHashBytes, + parentHash, + proof.UnicityCertificate, + proof.MerkleTreePath, + ) + if err := rm.FinalizeBlock(ctx, block); err != nil { + return fmt.Errorf("failed to finalize block: %w", err) + } + nextRoundNumber := big.NewInt(0).Add(blockNumber.Int, big.NewInt(1)) + if err := rm.StartNewRound(ctx, api.NewBigInt(nextRoundNumber)); err != nil { + rm.logger.WithContext(ctx).Error("Failed to start new round after finalization.", "error", err.Error()) + } + rm.logger.WithContext(ctx).Info("Block finalized and new round started", "blockNumber", blockNumber.String()) + return nil + default: + return fmt.Errorf("invalid sharding mode: %s", rm.config.Sharding.Mode) + } +} - return nil +func (rm *RoundManager) pollInclusionProof(ctx context.Context, rootHash string) (*api.RootShardInclusionProof, error) { + pollingCtx, cancel := context.WithTimeout(ctx, rm.config.Sharding.Child.ParentPollTimeout) + defer cancel() 
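+	// Poll the parent aggregator for a shard inclusion proof every ParentPollInterval,
+	// giving up once ParentPollTimeout elapses (enforced via pollingCtx above).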
+ + ticker := time.NewTicker(rm.config.Sharding.Child.ParentPollInterval) + defer ticker.Stop() + + for { + select { + case <-pollingCtx.Done(): + return nil, fmt.Errorf("timed out waiting for parent shard inclusion proof %s", rootHash) + case <-ticker.C: + request := &api.GetShardProofRequest{ShardID: rm.config.Sharding.Child.ShardID} + proof, err := rm.rootClient.GetShardProof(pollingCtx, request) + if err != nil { + return nil, fmt.Errorf("failed to fetch parent shard inclusion proof: %w", err) + } + if proof == nil || !proof.IsValid(rootHash) { + rm.logger.WithContext(ctx).Debug("Parent shard inclusion proof not found, retrying...") + continue + } + rm.logger.WithContext(ctx).Info("Successfully received shard proof from parent") + return proof, nil + } + } } // FinalizeBlock creates and persists a new block with the given data diff --git a/internal/round/factory.go b/internal/round/factory.go new file mode 100644 index 0000000..2cfdd78 --- /dev/null +++ b/internal/round/factory.go @@ -0,0 +1,40 @@ +package round + +import ( + "context" + "fmt" + + "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/ha/state" + "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/sharding" + "github.com/unicitynetwork/aggregator-go/internal/smt" + "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// Manager interface for both standalone and parent round managers +type Manager interface { + Start(ctx context.Context) error + Stop(ctx context.Context) error + Activate(ctx context.Context) error + Deactivate(ctx context.Context) error + GetSMT() *smt.ThreadSafeSMT +} + +// NewManager creates the appropriate round manager based on sharding mode +func NewManager(ctx context.Context, cfg *config.Config, logger *logger.Logger, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, stateTracker *state.Tracker) (Manager, error) { + switch cfg.Sharding.Mode { + case config.ShardingModeStandalone: + smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) + return NewRoundManager(ctx, cfg, logger, smtInstance, commitmentQueue, storage, nil, stateTracker) + case config.ShardingModeParent: + return NewParentRoundManager(ctx, cfg, logger, storage) + case config.ShardingModeChild: + smtInstance := smt.NewChildSparseMerkleTree(api.SHA256, 16+256, cfg.Sharding.Child.ShardID) + rootAggregatorClient := sharding.NewRootAggregatorClient(cfg.Sharding.Child.ParentRpcAddr) + return NewRoundManager(ctx, cfg, logger, smtInstance, commitmentQueue, storage, rootAggregatorClient, stateTracker) + default: + return nil, fmt.Errorf("unsupported sharding mode: %s", cfg.Sharding.Mode) + } +} diff --git a/internal/round/parent_round_manager.go b/internal/round/parent_round_manager.go new file mode 100644 index 0000000..8f8fe27 --- /dev/null +++ b/internal/round/parent_round_manager.go @@ -0,0 +1,504 @@ +package round + +import ( + "context" + "encoding/binary" + "fmt" + "math/big" + "sync" + "time" + + "github.com/unicitynetwork/aggregator-go/internal/bft" + "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/models" + "github.com/unicitynetwork/aggregator-go/internal/smt" + "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// 
ParentRound represents a single aggregation round for parent aggregator +type ParentRound struct { + Number *api.BigInt + StartTime time.Time + State RoundState + ShardUpdates map[int]*models.ShardRootUpdate // Latest update per shard path (key is shard ID) + Block *models.Block + + // SMT snapshot for this round - allows accumulating shard changes before committing + Snapshot *smt.ThreadSafeSmtSnapshot + + // Store processed data for persistence during FinalizeBlock + ProcessedShardUpdates map[int]*models.ShardRootUpdate // Shard updates that were actually processed into the parent SMT + + // Timing metrics for this round + ProcessingTime time.Duration + ProposalTime time.Time // When block was proposed to BFT + FinalizationTime time.Time // When block was actually finalized +} + +// ParentRoundManager handles round processing for parent aggregator mode +type ParentRoundManager struct { + config *config.Config + logger *logger.Logger + storage interfaces.Storage + parentSMT *smt.ThreadSafeSMT + bftClient bft.BFTClient + + // Round management + currentRound *ParentRound + roundMutex sync.RWMutex + roundTimer *time.Timer + stopChan chan struct{} + wg sync.WaitGroup + roundDuration time.Duration + + // Metrics + totalRounds int64 + totalShardUpdates int64 +} + +// NewParentRoundManager creates a new parent round manager +func NewParentRoundManager(ctx context.Context, cfg *config.Config, logger *logger.Logger, storage interfaces.Storage) (*ParentRoundManager, error) { + // Initialize parent SMT in parent mode with support for mutable leaves + smtInstance := smt.NewParentSparseMerkleTree(api.SHA256, cfg.Sharding.ShardIDLength) + parentSMT := smt.NewThreadSafeSMT(smtInstance) + + prm := &ParentRoundManager{ + config: cfg, + logger: logger, + storage: storage, + parentSMT: parentSMT, + stopChan: make(chan struct{}), + roundDuration: cfg.Processing.RoundDuration, + } + + // Create BFT client (same logic as regular RoundManager) + if cfg.BFT.Enabled { + var err error + prm.bftClient, err = bft.NewBFTClient(ctx, &cfg.BFT, prm, logger) + if err != nil { + return nil, fmt.Errorf("failed to create BFT client: %w", err) + } + } else { + // Calculate initial block number like regular round manager + nextBlockNumber := api.NewBigInt(nil) + lastBlockNumber := api.NewBigInt(big.NewInt(0)) + if storage != nil && storage.BlockStorage() != nil { + var err error + lastBlockNumber, err = storage.BlockStorage().GetLatestNumber(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch latest block number: %w", err) + } + if lastBlockNumber == nil { + lastBlockNumber = api.NewBigInt(big.NewInt(0)) + } + } + nextBlockNumber.Add(lastBlockNumber.Int, big.NewInt(1)) + prm.bftClient = bft.NewBFTClientStub(logger, prm, nextBlockNumber) + } + + return prm, nil +} + +// Start performs initialization (called once at startup) +// Note: SMT reconstruction is done in Activate() when the node becomes leader +func (prm *ParentRoundManager) Start(ctx context.Context) error { + prm.logger.WithContext(ctx).Info("Starting Parent Round Manager", + "roundDuration", prm.roundDuration.String()) + + prm.logger.WithContext(ctx).Info("Parent Round Manager started successfully") + return nil +} + +// Stop gracefully stops the parent round manager (called once at shutdown) +func (prm *ParentRoundManager) Stop(ctx context.Context) error { + prm.logger.WithContext(ctx).Info("Stopping Parent Round Manager") + + // Deactivate first + if err := prm.Deactivate(ctx); err != nil { + prm.logger.WithContext(ctx).Error("Failed to deactivate 
Parent Round Manager", "error", err.Error()) + } + + // Wait for goroutines to finish + prm.wg.Wait() + + prm.logger.WithContext(ctx).Info("Parent Round Manager stopped successfully") + return nil +} + +// SubmitShardRoot accepts shard root updates during the current round +func (prm *ParentRoundManager) SubmitShardRoot(ctx context.Context, update *models.ShardRootUpdate) error { + prm.roundMutex.Lock() + defer prm.roundMutex.Unlock() + + if prm.currentRound == nil { + return fmt.Errorf("no active round to accept shard root update") + } + + shardKey := update.ShardID + prm.currentRound.ShardUpdates[shardKey] = update + prm.totalShardUpdates++ + + prm.logger.WithContext(ctx).Debug("Shard root update accepted", + "roundNumber", prm.currentRound.Number.String(), + "shardID", update.ShardID, + "rootHash", update.RootHash.String(), + "totalShards", len(prm.currentRound.ShardUpdates)) + + return nil +} + +// StartNewRound begins a new aggregation round (public method for BFT interface) +func (prm *ParentRoundManager) StartNewRound(ctx context.Context, roundNumber *api.BigInt) error { + return prm.startNewRound(ctx, roundNumber) +} + +// startNewRound is the internal implementation +func (prm *ParentRoundManager) startNewRound(ctx context.Context, roundNumber *api.BigInt) error { + prm.logger.WithContext(ctx).Info("Starting new parent round", + "roundNumber", roundNumber.String()) + + prm.roundMutex.Lock() + defer prm.roundMutex.Unlock() + + // Stop existing timer if any + if prm.roundTimer != nil { + prm.roundTimer.Stop() + } + + var shardUpdates map[int]*models.ShardRootUpdate + if prm.currentRound != nil { + shardUpdates = prm.currentRound.ShardUpdates + } else { + shardUpdates = make(map[int]*models.ShardRootUpdate) + } + + // Create new round + prm.currentRound = &ParentRound{ + Number: roundNumber, + StartTime: time.Now(), + State: RoundStateCollecting, + ShardUpdates: shardUpdates, // Reuse the same map + Snapshot: prm.parentSMT.CreateSnapshot(), // Create SMT snapshot for this round + } + + // Set timer for round processing + prm.roundTimer = time.AfterFunc(prm.roundDuration, func() { + prm.logger.WithContext(ctx).Info("Parent round timer fired", + "roundNumber", roundNumber.String(), + "elapsed", prm.roundDuration.String()) + if err := prm.processCurrentRound(ctx); err != nil { + prm.logger.WithContext(ctx).Error("Failed to process parent round", + "roundNumber", roundNumber.String(), + "error", err.Error()) + } + }) + + prm.logger.WithContext(ctx).Info("Parent round started", + "roundNumber", roundNumber.String(), + "duration", prm.roundDuration.String()) + + return nil +} + +// processCurrentRound processes the current round and creates a block +func (prm *ParentRoundManager) processCurrentRound(ctx context.Context) error { + prm.roundMutex.Lock() + + if prm.currentRound == nil { + prm.roundMutex.Unlock() + return fmt.Errorf("no current round to process") + } + + // Capture current round for processing + round := prm.currentRound + round.State = RoundStateProcessing + + round.ProcessedShardUpdates = make(map[int]*models.ShardRootUpdate, len(round.ShardUpdates)) + for shardKey, update := range round.ShardUpdates { + round.ProcessedShardUpdates[shardKey] = update + } + clear(round.ShardUpdates) + + prm.roundMutex.Unlock() + + return prm.processRound(ctx, round) +} + +// processRound processes a specific round +func (prm *ParentRoundManager) processRound(ctx context.Context, round *ParentRound) error { + startTime := time.Now() + + prm.logger.WithContext(ctx).Info("Processing parent 
round", + "roundNumber", round.Number.String(), + "shardCount", len(round.ProcessedShardUpdates)) + + var parentRootHash api.HexBytes + if len(round.ProcessedShardUpdates) == 0 { + rootHashHex := round.Snapshot.GetRootHash() + parsedRoot, err := api.NewHexBytesFromString(rootHashHex) + if err != nil { + return fmt.Errorf("failed to parse parent SMT root hash %q: %w", rootHashHex, err) + } + parentRootHash = parsedRoot + prm.logger.WithContext(ctx).Info("Empty parent round, using current SMT root hash", + "rootHash", parentRootHash.String()) + } else { + leaves := make([]*smt.Leaf, 0, len(round.ProcessedShardUpdates)) + for _, update := range round.ProcessedShardUpdates { + path := update.GetPath() + + leaf := &smt.Leaf{ + Path: path, + Value: update.RootHash, + } + leaves = append(leaves, leaf) + } + + rootHashStr, err := round.Snapshot.AddLeaves(leaves) + if err != nil { + return fmt.Errorf("failed to add shard leaves to parent SMT snapshot: %w", err) + } + + parsedRoot, err := api.NewHexBytesFromString(rootHashStr) + if err != nil { + return fmt.Errorf("failed to parse updated parent SMT root hash %q: %w", rootHashStr, err) + } + + parentRootHash = parsedRoot + prm.logger.WithContext(ctx).Info("Added shard updates to parent SMT snapshot", + "shardCount", len(round.ProcessedShardUpdates), + "newRootHash", parentRootHash.String()) + } + + var previousBlockHash api.HexBytes + if round.Number.Cmp(big.NewInt(1)) > 0 { + prevBlockNumber := api.NewBigInt(nil) + prevBlockNumber.Set(round.Number.Int) + prevBlockNumber.Sub(prevBlockNumber.Int, big.NewInt(1)) + + prevBlock, err := prm.storage.BlockStorage().GetByNumber(ctx, prevBlockNumber) + if err != nil { + return fmt.Errorf("failed to get previous block %s: %w", prevBlockNumber.String(), err) + } + if prevBlock != nil { + previousBlockHash = prevBlock.RootHash + } + } + + block := models.NewBlock( + round.Number, + prm.config.Chain.ID, + 0, + prm.config.Chain.Version, + prm.config.Chain.ForkID, + parentRootHash, + previousBlockHash, + nil, + nil, + ) + + round.Block = block + round.State = RoundStateFinalizing + round.ProposalTime = time.Now() + + if err := prm.bftClient.CertificationRequest(ctx, block); err != nil { + prm.logger.WithContext(ctx).Error("Failed to send certification request", + "roundNumber", round.Number.String(), + "error", err.Error()) + return fmt.Errorf("failed to send certification request: %w", err) + } + + prm.logger.WithContext(ctx).Info("Certification request sent successfully", + "roundNumber", round.Number.String()) + + round.ProcessingTime = time.Since(startTime) + prm.totalRounds++ + + prm.logger.WithContext(ctx).Info("Parent round processed successfully", + "roundNumber", round.Number.String(), + "processingTime", round.ProcessingTime.String(), + "shardCount", len(round.ProcessedShardUpdates)) + + return nil +} + +// GetCurrentRound returns the current round (for BFT client callback compatibility) +func (prm *ParentRoundManager) GetCurrentRound() interface{} { + prm.roundMutex.RLock() + defer prm.roundMutex.RUnlock() + return prm.currentRound +} + +// GetSMT returns the parent SMT instance +func (prm *ParentRoundManager) GetSMT() *smt.ThreadSafeSMT { + return prm.parentSMT +} + +// Activate starts active round processing (called when node becomes leader in HA mode) +func (prm *ParentRoundManager) Activate(ctx context.Context) error { + prm.logger.WithContext(ctx).Info("Activating parent round manager") + + // Reconstruct parent SMT from current shard states in storage + // This ensures the follower-turned-leader 
has the latest state + prm.logger.WithContext(ctx).Info("Reconstructing parent SMT from storage on leadership transition") + if err := prm.reconstructParentSMT(ctx); err != nil { + return fmt.Errorf("failed to reconstruct parent SMT on activation: %w", err) + } + + // Start BFT client + if err := prm.bftClient.Start(ctx); err != nil { + return fmt.Errorf("failed to start BFT client: %w", err) + } + + // Get latest block number to determine starting round + latestBlockNumber, err := prm.storage.BlockStorage().GetLatestNumber(ctx) + if err != nil { + return fmt.Errorf("failed to get latest block number: %w", err) + } + + // Calculate next round number + var nextRoundNumber *api.BigInt + if latestBlockNumber == nil { + nextRoundNumber = api.NewBigInt(nil) + nextRoundNumber.SetInt64(1) + } else { + nextRoundNumber = api.NewBigInt(nil) + nextRoundNumber.Set(latestBlockNumber.Int) + nextRoundNumber.Add(nextRoundNumber.Int, big.NewInt(1)) + } + + // Start first round + if err := prm.startNewRound(ctx, nextRoundNumber); err != nil { + return fmt.Errorf("failed to start initial round: %w", err) + } + + prm.logger.WithContext(ctx).Info("Parent round manager activated successfully", + "initialRound", nextRoundNumber.String()) + + return nil +} + +// Deactivate stops active round processing (called when node loses leadership in HA mode) +func (prm *ParentRoundManager) Deactivate(ctx context.Context) error { + prm.logger.WithContext(ctx).Info("Deactivating parent round manager") + + // Stop BFT client + prm.bftClient.Stop() + + // Stop current round timer + prm.roundMutex.Lock() + if prm.roundTimer != nil { + prm.roundTimer.Stop() + prm.roundTimer = nil + } + prm.currentRound = nil + prm.roundMutex.Unlock() + + prm.logger.WithContext(ctx).Info("Parent round manager deactivated successfully") + return nil +} + +// FinalizeBlock is called by BFT client when block is finalized +func (prm *ParentRoundManager) FinalizeBlock(ctx context.Context, block *models.Block) error { + prm.logger.WithContext(ctx).Info("Finalizing parent block", + "blockNumber", block.Index.String()) + + prm.roundMutex.RLock() + var processedShardUpdates map[int]*models.ShardRootUpdate + var snapshot *smt.ThreadSafeSmtSnapshot + if prm.currentRound != nil && prm.currentRound.ProcessedShardUpdates != nil { + processedShardUpdates = make(map[int]*models.ShardRootUpdate) + for shardKey, update := range prm.currentRound.ProcessedShardUpdates { + processedShardUpdates[shardKey] = update + } + snapshot = prm.currentRound.Snapshot + } + prm.roundMutex.RUnlock() + + if len(processedShardUpdates) > 0 { + smtNodes := make([]*models.SmtNode, 0, len(processedShardUpdates)) + for _, update := range processedShardUpdates { + bytes := make([]byte, 4) + binary.BigEndian.PutUint32(bytes, uint32(update.ShardID)) + smtNode := models.NewSmtNode(api.NewHexBytes(bytes), update.RootHash) + smtNodes = append(smtNodes, smtNode) + } + + if err := prm.storage.SmtStorage().UpsertBatch(ctx, smtNodes); err != nil { + return fmt.Errorf("failed to upsert shard states: %w", err) + } + + prm.logger.WithContext(ctx).Info("Updated current shard states in storage", + "blockNumber", block.Index.String(), + "shardCount", len(processedShardUpdates), + "shardIDs", getShardIDs(processedShardUpdates)) + } + + if err := prm.storage.BlockStorage().Store(ctx, block); err != nil { + return fmt.Errorf("failed to store parent block: %w", err) + } + + if snapshot != nil { + prm.logger.WithContext(ctx).Info("Committing parent SMT snapshot to main tree after successful block storage", + 
"blockNumber", block.Index.String()) + + snapshot.Commit(prm.parentSMT) + + prm.logger.WithContext(ctx).Info("Successfully committed parent SMT snapshot to main tree", + "blockNumber", block.Index.String()) + } + + prm.logger.WithContext(ctx).Info("Parent block finalized successfully", + "blockNumber", block.Index.String()) + + return nil +} + +func (prm *ParentRoundManager) reconstructParentSMT(ctx context.Context) error { + smtNodes, err := prm.storage.SmtStorage().GetAll(ctx) + if err != nil { + return fmt.Errorf("failed to get SMT nodes for reconstruction: %w", err) + } + + if len(smtNodes) == 0 { + prm.logger.WithContext(ctx).Info("No existing SMT nodes found, starting with empty parent SMT") + return nil + } + + prm.logger.WithContext(ctx).Info("Reconstructing parent SMT from current shard states", + "nodeCount", len(smtNodes)) + + leaves := make([]*smt.Leaf, 0, len(smtNodes)) + for _, node := range smtNodes { + path := new(big.Int).SetBytes(node.Key) + + leaf := &smt.Leaf{ + Path: path, + Value: node.Value, + } + leaves = append(leaves, leaf) + } + + if len(leaves) > 0 { + _, err := prm.parentSMT.AddLeaves(leaves) + if err != nil { + return fmt.Errorf("failed to add leaves to parent SMT during reconstruction: %w", err) + } + + prm.logger.WithContext(ctx).Info("Successfully reconstructed parent SMT", + "leafCount", len(leaves), + "rootHash", prm.parentSMT.GetRootHash()) + } + + return nil +} + +func getShardIDs(shardUpdates map[int]*models.ShardRootUpdate) []int { + shardIDs := make([]int, 0, len(shardUpdates)) + for shardKey := range shardUpdates { + shardIDs = append(shardIDs, shardKey) + } + return shardIDs +} diff --git a/internal/round/parent_round_manager_test.go b/internal/round/parent_round_manager_test.go new file mode 100644 index 0000000..715271b --- /dev/null +++ b/internal/round/parent_round_manager_test.go @@ -0,0 +1,441 @@ +package round + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/models" + "github.com/unicitynetwork/aggregator-go/internal/smt" + "github.com/unicitynetwork/aggregator-go/internal/storage/mongodb" + "github.com/unicitynetwork/aggregator-go/internal/testutil" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// ParentRoundManagerTestSuite is the test suite for parent round manager +type ParentRoundManagerTestSuite struct { + suite.Suite + cfg *config.Config + logger *logger.Logger + storage *mongodb.Storage + cleanup func() +} + +// SetupSuite runs once before all tests - creates one MongoDB container for all tests +func (suite *ParentRoundManagerTestSuite) SetupSuite() { + var err error + suite.logger, err = logger.New("info", "text", "stdout", false) + require.NoError(suite.T(), err, "Should create logger") + + suite.cfg = &config.Config{ + Sharding: config.ShardingConfig{ + Mode: config.ShardingModeParent, + ShardIDLength: 4, // 4 bits = 16 possible shards (realistic for testing) + }, + Database: config.DatabaseConfig{ + Database: "test_parent_aggregator", + ConnectTimeout: 5 * time.Second, + }, + BFT: config.BFTConfig{ + Enabled: false, // Will use BFT stub + }, + Processing: config.ProcessingConfig{ + RoundDuration: 100 * time.Millisecond, // Short duration for fast tests + }, + } + + // Create storage once for all tests (reuses same MongoDB container) + suite.storage = 
testutil.SetupTestStorage(suite.T(), *suite.cfg) +} + +// TearDownSuite runs once after all tests +func (suite *ParentRoundManagerTestSuite) TearDownSuite() { + if suite.cleanup != nil { + suite.cleanup() + } +} + +// TearDownTest runs after each test to clean all collections +func (suite *ParentRoundManagerTestSuite) TearDownTest() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Clean all collections to ensure clean state for next test + if err := suite.storage.CleanAllCollections(ctx); err != nil { + suite.T().Logf("Warning: failed to clean collections: %v", err) + } +} + +// Test helpers + +func makeTestHash(value byte) []byte { + hash := make([]byte, 32) + hash[0] = value + return hash +} + +// Test 1: Initialization +func (suite *ParentRoundManagerTestSuite) TestInitialization() { + ctx := context.Background() + + // Create parent round manager (BFT stub will be created automatically when BFT.Enabled = false) + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err, "Should create parent round manager successfully") + suite.Require().NotNil(prm, "ParentRoundManager should not be nil") + + // Verify initial state + suite.Assert().NotNil(prm.parentSMT, "Parent SMT should be initialized") + suite.Assert().NotNil(prm.storage, "Storage should be set") + suite.Assert().NotNil(prm.logger, "Logger should be set") + + // Since BFT is disabled in config, it should use BFT stub + suite.Assert().NotNil(prm.bftClient, "BFT client stub should be initialized") + + suite.T().Log("✓ ParentRoundManager initialized successfully with BFT stub") +} + +// Test 2: Basic Round Lifecycle +func (suite *ParentRoundManagerTestSuite) TestBasicRoundLifecycle() { + ctx := context.Background() + + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err) + defer prm.Stop(ctx) // Stop round manager before cleanup to avoid disconnection errors + + // Start the parent round manager (initialization and SMT restoration) + err = prm.Start(ctx) + suite.Require().NoError(err) + + // Activate the round manager to start round processing + err = prm.Activate(ctx) + suite.Require().NoError(err) + + // Create 2 shard updates (with ShardIDLength=4, valid IDs are 16-31) + shard0ID := 16 // 0b10000 + shard1ID := 17 // 0b10001 + + shard0Root := makeTestHash(0xAA) + shard1Root := makeTestHash(0xBB) + + // Submit shard updates + update0 := models.NewShardRootUpdate(shard0ID, shard0Root) + update1 := models.NewShardRootUpdate(shard1ID, shard1Root) + + err = prm.SubmitShardRoot(ctx, update0) + suite.Require().NoError(err, "Should submit shard 0 update") + + err = prm.SubmitShardRoot(ctx, update1) + suite.Require().NoError(err, "Should submit shard 1 update") + + // Process the round + // The BFT stub will automatically process the round after roundDuration (100ms) + // and start the next round + time.Sleep(150 * time.Millisecond) // Wait for round to process + + // Get the parent SMT root after processing + parentRoot := prm.parentSMT.GetRootHash() + suite.Require().NotNil(parentRoot, "Parent root should be calculated") + + suite.T().Logf("Parent root hash after 2 shard updates: %x", parentRoot[:16]) + suite.T().Log("✓ Basic round lifecycle completed successfully") +} + +// Test 3: Multi-Round Updates +// TODO: This test will fail until SMT supports updating existing leaves +func (suite *ParentRoundManagerTestSuite) TestMultiRoundUpdates() { + suite.T().Skip("TODO(SMT): enable once 
sparse Merkle tree supports updating existing leaves") + ctx := context.Background() + + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err) + defer prm.Stop(ctx) // Stop round manager before cleanup to avoid disconnection errors + + err = prm.Start(ctx) + suite.Require().NoError(err) + + // Activate the round manager to start round processing + err = prm.Activate(ctx) + suite.Require().NoError(err) + + // With ShardIDLength=4, valid shard IDs are 16-31 + shard0ID := 16 + shard1ID := 17 + + // Round 1: Both shards submit initial roots + suite.T().Log("=== Round 1: Initial shard roots ===") + update0_r1 := models.NewShardRootUpdate(shard0ID, makeTestHash(0x11)) + update1_r1 := models.NewShardRootUpdate(shard1ID, makeTestHash(0x22)) + + err = prm.SubmitShardRoot(ctx, update0_r1) + suite.Require().NoError(err) + err = prm.SubmitShardRoot(ctx, update1_r1) + suite.Require().NoError(err) + + time.Sleep(150 * time.Millisecond) // Wait for round 1 to process (100ms timer + buffer) + root1 := prm.parentSMT.GetRootHash() + suite.T().Logf("Parent root after round 1: %x", root1[:16]) + + // Round 2: Shard 0 submits UPDATED root (Shard 1 unchanged) + suite.T().Log("=== Round 2: Shard 0 updates ===") + update0_r2 := models.NewShardRootUpdate(shard0ID, makeTestHash(0x33)) + + err = prm.SubmitShardRoot(ctx, update0_r2) + suite.Require().NoError(err) + + time.Sleep(150 * time.Millisecond) // Wait for round 2 to process + root2 := prm.parentSMT.GetRootHash() + suite.T().Logf("Parent root after round 2: %x", root2[:16]) + + // Root should have changed because shard 0 updated + suite.Assert().NotEqual(root1, root2, "Parent root should change when shard 0 updates") + + // Round 3: Shard 1 also updates + suite.T().Log("=== Round 3: Shard 1 updates ===") + update1_r3 := models.NewShardRootUpdate(shard1ID, makeTestHash(0x44)) + + err = prm.SubmitShardRoot(ctx, update1_r3) + suite.Require().NoError(err) + + time.Sleep(150 * time.Millisecond) // Wait for round 3 to process + root3 := prm.parentSMT.GetRootHash() + suite.T().Logf("Parent root after round 3: %x", root3[:16]) + + // Root should have changed again + suite.Assert().NotEqual(root2, root3, "Parent root should change when shard 1 updates") + suite.Assert().NotEqual(root1, root3, "Parent root should be different from round 1") + + suite.T().Log("✓ Multi-round updates work correctly") +} + +// Test 4: Multiple Shards in One Round +func (suite *ParentRoundManagerTestSuite) TestMultipleShards() { + ctx := context.Background() + + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err) + defer prm.Stop(ctx) + + err = prm.Start(ctx) + suite.Require().NoError(err) + + // Activate the round manager to start round processing + err = prm.Activate(ctx) + suite.Require().NoError(err) + + // Submit 4 different shard updates (with ShardIDLength=4, valid IDs are 16-31) + shard0ID := 16 + shard1ID := 17 + shard2ID := 18 + shard3ID := 19 + + update0 := models.NewShardRootUpdate(shard0ID, makeTestHash(0x10)) + update1 := models.NewShardRootUpdate(shard1ID, makeTestHash(0x20)) + update2 := models.NewShardRootUpdate(shard2ID, makeTestHash(0x30)) + update3 := models.NewShardRootUpdate(shard3ID, makeTestHash(0x40)) + + err = prm.SubmitShardRoot(ctx, update0) + suite.Require().NoError(err, "Should submit shard 0") + + err = prm.SubmitShardRoot(ctx, update1) + suite.Require().NoError(err, "Should submit shard 1") + + err = prm.SubmitShardRoot(ctx, update2) + 
suite.Require().NoError(err, "Should submit shard 2") + + err = prm.SubmitShardRoot(ctx, update3) + suite.Require().NoError(err, "Should submit shard 3") + + // Wait for round to process + time.Sleep(150 * time.Millisecond) + + // Get the parent SMT root - if it's not empty, the round processed successfully + parentRoot := prm.parentSMT.GetRootHash() + suite.Require().NotNil(parentRoot, "Parent root should be calculated") + suite.Assert().NotEqual([]byte{0}, parentRoot, "Parent root should not be empty") + + suite.T().Logf("Parent root with 4 shards: %x", parentRoot[:16]) + suite.T().Log("✓ Multiple shards processed correctly in one round") +} + +// Test 5: Empty Round (no shard updates) +func (suite *ParentRoundManagerTestSuite) TestEmptyRound() { + ctx := context.Background() + + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err) + defer prm.Stop(ctx) + + err = prm.Start(ctx) + suite.Require().NoError(err) + + // Activate the round manager to start round processing + err = prm.Activate(ctx) + suite.Require().NoError(err) + + // Get the initial root (should be empty tree root) + initialRoot := prm.parentSMT.GetRootHash() + suite.Require().NotNil(initialRoot, "Initial root should exist") + + // Don't submit any shard updates - just wait for the round to process + time.Sleep(150 * time.Millisecond) + + // The root should remain the same (no changes) + currentRoot := prm.parentSMT.GetRootHash() + suite.Assert().Equal(initialRoot, currentRoot, "Root should not change in empty round") + + // Verify that a block was still created (empty rounds still create blocks) + block, err := suite.storage.BlockStorage().GetLatest(ctx) + suite.Require().NoError(err, "Should get latest block") + suite.Require().NotNil(block, "Block should exist even for empty round") + + suite.T().Logf("Empty round processed with root: %x", currentRoot[:16]) + suite.T().Log("✓ Empty round processed correctly") +} + +// Test 6: Duplicate Shard Update (same shard, same value) +func (suite *ParentRoundManagerTestSuite) TestDuplicateShardUpdate() { + ctx := context.Background() + + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err) + defer prm.Stop(ctx) + + err = prm.Start(ctx) + suite.Require().NoError(err) + + // Activate the round manager to start round processing + err = prm.Activate(ctx) + suite.Require().NoError(err) + + // With ShardIDLength=4, valid shard IDs are 16-31 + shard0ID := 16 + shard0Root := makeTestHash(0xAA) + + // Submit the same shard update twice with the same value + update1 := models.NewShardRootUpdate(shard0ID, shard0Root) + update2 := models.NewShardRootUpdate(shard0ID, shard0Root) + + err = prm.SubmitShardRoot(ctx, update1) + suite.Require().NoError(err, "First submission should succeed") + + err = prm.SubmitShardRoot(ctx, update2) + suite.Require().NoError(err, "Second submission (duplicate) should succeed") + + // Wait for round to process + time.Sleep(150 * time.Millisecond) + + // Get the parent SMT root - should be calculated correctly + parentRoot := prm.parentSMT.GetRootHash() + suite.Require().NotNil(parentRoot, "Parent root should be calculated") + + // The duplicate submission should not cause any issues + // The SMT should have exactly one leaf for shard 0 + suite.T().Logf("Parent root with duplicate shard update: %x", parentRoot[:16]) + suite.T().Log("✓ Duplicate shard update handled correctly") +} + +// Test 7: Multiple Updates from Same Shard (different values - latest 
should win) +// TODO: This test will fail until SMT supports updating existing leaves +func (suite *ParentRoundManagerTestSuite) TestSameShardMultipleValues() { + suite.T().Skip("TODO(SMT): enable once sparse Merkle tree supports updating existing leaves") + ctx := context.Background() + + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err) + defer prm.Stop(ctx) + + err = prm.Start(ctx) + suite.Require().NoError(err) + + // Activate the round manager to start round processing + err = prm.Activate(ctx) + suite.Require().NoError(err) + + // With ShardIDLength=4, valid shard IDs are 16-31 + shard0ID := 16 + latestValue := makeTestHash(0xCC) + + // Submit 3 different updates for the same shard + update1 := models.NewShardRootUpdate(shard0ID, makeTestHash(0xAA)) + update2 := models.NewShardRootUpdate(shard0ID, makeTestHash(0xBB)) + update3 := models.NewShardRootUpdate(shard0ID, latestValue) // This should be the final value + + err = prm.SubmitShardRoot(ctx, update1) + suite.Require().NoError(err, "First submission should succeed") + + err = prm.SubmitShardRoot(ctx, update2) + suite.Require().NoError(err, "Second submission should succeed") + + err = prm.SubmitShardRoot(ctx, update3) + suite.Require().NoError(err, "Third submission should succeed") + + // Wait for round to process + time.Sleep(150 * time.Millisecond) + + // Get the parent SMT root + parentRoot := prm.parentSMT.GetRootHash() + suite.Require().NotNil(parentRoot, "Parent root should be calculated") + + // Create a reference SMT with only the latest value to verify + smtInstance := smt.NewParentSparseMerkleTree(api.SHA256, suite.cfg.Sharding.ShardIDLength) + referenceSMT := smt.NewThreadSafeSMT(smtInstance) + + // Add the latest value as a pre-hashed leaf (same way ParentRoundManager does it) + err = referenceSMT.AddPreHashedLeaf(update3.GetPath(), latestValue) + suite.Require().NoError(err, "Should add leaf to reference SMT") + + expectedRoot := referenceSMT.GetRootHash() + suite.Assert().Equal(expectedRoot, parentRoot, "Parent root should match reference SMT with only latest value") + + suite.T().Logf("Parent root with latest value (0xCC): %s", parentRoot) + suite.T().Logf("Expected root (reference SMT): %s", expectedRoot) + suite.T().Log("✓ Latest shard update value was used correctly") +} + +// Test 8: Block root persisted in storage matches SMT root after round finalization +func (suite *ParentRoundManagerTestSuite) TestBlockRootMatchesSMTRoot() { + ctx := context.Background() + + prm, err := NewParentRoundManager(ctx, suite.cfg, suite.logger, suite.storage) + suite.Require().NoError(err) + defer prm.Stop(ctx) + + err = prm.Start(ctx) + suite.Require().NoError(err) + + err = prm.Activate(ctx) + suite.Require().NoError(err) + + // With ShardIDLength=4, valid shard IDs are 16-31 (binary 10000-11111) + shardID := 16 + update := models.NewShardRootUpdate(shardID, makeTestHash(0xAB)) + err = prm.SubmitShardRoot(ctx, update) + suite.Require().NoError(err) + + var latestBlock *models.Block + suite.Require().Eventually(func() bool { + block, err := suite.storage.BlockStorage().GetLatest(ctx) + if err != nil || block == nil { + return false + } + latestBlock = block + return true + }, 5*time.Second, 50*time.Millisecond, "expected finalized block to be available") + + currentRootHex := prm.GetSMT().GetRootHash() + expectedRoot, err := api.NewHexBytesFromString(currentRootHex) + suite.Require().NoError(err) + suite.Require().NotNil(latestBlock, "latest block should be available") 
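+	// Final check: the root hash persisted in the finalized block must equal the
+	// current SMT root (both sides compared in their canonical string form).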
+ + suite.Assert().Equal(expectedRoot.String(), latestBlock.RootHash.String(), "stored block root should match SMT root") +} + +// TestParentRoundManagerSuite runs the test suite +func TestParentRoundManagerSuite(t *testing.T) { + suite.Run(t, new(ParentRoundManagerTestSuite)) +} diff --git a/internal/round/round_manager.go b/internal/round/round_manager.go index 07ca015..ea34154 100644 --- a/internal/round/round_manager.go +++ b/internal/round/round_manager.go @@ -27,6 +27,8 @@ const ( RoundStateFinalizing // Finalizing block ) +const miniBatchSize = 100 // Number of commitments to process per SMT mini-batch + func (rs RoundState) String() string { switch rs { case RoundStateCollecting: @@ -79,6 +81,7 @@ type RoundManager struct { commitmentQueue interfaces.CommitmentQueue storage interfaces.Storage smt *smt.ThreadSafeSMT + rootClient RootAggregatorClient bftClient bft.BFTClient stateTracker *state.Tracker @@ -111,6 +114,11 @@ type RoundManager struct { totalCommitments int64 } +type RootAggregatorClient interface { + SubmitShardRoot(ctx context.Context, req *api.SubmitShardRootRequest) error + GetShardProof(ctx context.Context, req *api.GetShardProofRequest) (*api.RootShardInclusionProof, error) +} + // RoundMetrics tracks performance metrics for a round type RoundMetrics struct { CommitmentsProcessed int @@ -120,17 +128,14 @@ type RoundMetrics struct { } // NewRoundManager creates a new round manager -func NewRoundManager(ctx context.Context, cfg *config.Config, logger *logger.Logger, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, stateTracker *state.Tracker) (*RoundManager, error) { - // Initialize SMT with empty tree - will be replaced with restored tree in Start() - smtInstance := smt.NewSparseMerkleTree(api.SHA256, 16+256) - threadSafeSMT := smt.NewThreadSafeSMT(smtInstance) - +func NewRoundManager(ctx context.Context, cfg *config.Config, logger *logger.Logger, smtInstance *smt.SparseMerkleTree, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, rootAggregatorClient RootAggregatorClient, stateTracker *state.Tracker) (*RoundManager, error) { rm := &RoundManager{ config: cfg, logger: logger, commitmentQueue: commitmentQueue, storage: storage, - smt: threadSafeSMT, + smt: smt.NewThreadSafeSMT(smtInstance), + rootClient: rootAggregatorClient, stateTracker: stateTracker, roundDuration: cfg.Processing.RoundDuration, // Configurable round duration (default 1s) commitmentStream: make(chan *models.Commitment, 10000), // Reasonable buffer for streaming @@ -140,27 +145,30 @@ func NewRoundManager(ctx context.Context, cfg *config.Config, logger *logger.Log avgSMTUpdateTime: 5 * time.Millisecond, // Initial estimate per batch } - if cfg.BFT.Enabled { - var err error - rm.bftClient, err = bft.NewBFTClient(ctx, &cfg.BFT, rm, logger) - if err != nil { - return nil, fmt.Errorf("failed to create BFT client: %w", err) - } - } else { - nextBlockNumber := api.NewBigInt(nil) - lastBlockNumber := api.NewBigInt(big.NewInt(0)) - if rm.storage != nil && rm.storage.BlockStorage() != nil { + // create BFT client for standalone mode + if cfg.Sharding.Mode == config.ShardingModeStandalone { + if cfg.BFT.Enabled { var err error - lastBlockNumber, err = rm.storage.BlockStorage().GetLatestNumber(ctx) + rm.bftClient, err = bft.NewBFTClient(ctx, &cfg.BFT, rm, logger) if err != nil { - return nil, fmt.Errorf("failed to fetch latest block number: %w", err) + return nil, fmt.Errorf("failed to create BFT client: %w", err) } - if lastBlockNumber == nil { - lastBlockNumber = 
api.NewBigInt(big.NewInt(0)) + } else { + nextBlockNumber := api.NewBigInt(nil) + lastBlockNumber := api.NewBigInt(big.NewInt(0)) + if rm.storage != nil && rm.storage.BlockStorage() != nil { + var err error + lastBlockNumber, err = rm.storage.BlockStorage().GetLatestNumber(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch latest block number: %w", err) + } + if lastBlockNumber == nil { + lastBlockNumber = api.NewBigInt(big.NewInt(0)) + } } + nextBlockNumber.Add(lastBlockNumber.Int, big.NewInt(1)) + rm.bftClient = bft.NewBFTClientStub(logger, rm, nextBlockNumber) } - nextBlockNumber.Add(lastBlockNumber.Int, big.NewInt(1)) - rm.bftClient = bft.NewBFTClientStub(logger, rm, nextBlockNumber) } return rm, nil } @@ -182,7 +190,7 @@ func (rm *RoundManager) Start(ctx context.Context) error { } // Stop gracefully stops the round manager -func (rm *RoundManager) Stop(ctx context.Context) { +func (rm *RoundManager) Stop(ctx context.Context) error { rm.logger.Info("Stopping Round Manager") if err := rm.Deactivate(ctx); err != nil { @@ -200,6 +208,7 @@ func (rm *RoundManager) Stop(ctx context.Context) { rm.wg.Wait() rm.logger.Info("Round Manager stopped") + return nil } // GetCurrentRound returns information about the current round @@ -328,6 +337,18 @@ func (rm *RoundManager) StartNewRound(ctx context.Context, roundNumber *api.BigI rm.logger.WithContext(ctx).Error("Failed to process round", "roundNumber", roundNumber.String(), "error", err.Error()) + + if rm.config.Sharding.Mode.IsChild() { + nextRoundNumber := big.NewInt(0).Add(roundNumber.Int, big.NewInt(1)) + rm.logger.WithContext(ctx).Info("Attempting to start new round after processing failure.", + "failedRound", roundNumber.String(), + "nextRound", nextRoundNumber.String()) + if startErr := rm.StartNewRound(ctx, api.NewBigInt(nextRoundNumber)); startErr != nil { + rm.logger.WithContext(ctx).Error("Failed to start new round after processing error", + "nextRound", nextRoundNumber.String(), + "error", startErr) + } + } } }) @@ -373,7 +394,11 @@ func (rm *RoundManager) processCurrentRound(ctx context.Context) error { // Initialize commitments slice for this round // Note: Any commitments consumed from the channel MUST be processed in this round - rm.currentRound.Commitments = make([]*models.Commitment, 0, 10000) // Larger pre-allocation for high throughput + capacity := 10000 // Default capacity + if rm.config.Processing.MaxCommitmentsPerRound > 0 { + capacity = rm.config.Processing.MaxCommitmentsPerRound + } + rm.currentRound.Commitments = make([]*models.Commitment, 0, capacity) // Calculate adaptive processing deadline based on historical data processingDuration := time.Duration(float64(rm.roundDuration) * rm.processingRatio) @@ -400,6 +425,19 @@ ProcessLoop: rm.roundMutex.Lock() rm.currentRound.Commitments = append(rm.currentRound.Commitments, commitment) commitmentsProcessed++ + currentLen := len(rm.currentRound.Commitments) + + // Process in mini-batches for SMT efficiency + if currentLen%miniBatchSize == 0 { + batchStart := time.Now() + batchSlice := rm.currentRound.Commitments[len(rm.currentRound.Commitments)-miniBatchSize:] + if err := rm.processMiniBatch(ctx, batchSlice); err != nil { + rm.logger.WithContext(ctx).Error("Failed to process mini-batch", + "error", err.Error(), + "roundNumber", roundNumber.String()) + } + smtUpdateTime += time.Since(batchStart) + } // Check if we've reached the configured cap if rm.config.Processing.MaxCommitmentsPerRound > 0 && commitmentsProcessed >= rm.config.Processing.MaxCommitmentsPerRound { 
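
The hunk above moves mini-batch SMT processing ahead of the per-round commitment cap check and replaces the hard-coded `100` with the new `miniBatchSize` constant. As a rough illustration of the batching arithmetic in this change (how many full batches a round produces, where the trailing partial batch starts, and how the 80/20 exponential moving average of batch time is updated), here is a minimal, self-contained Go sketch; the `commitment` type, the commitment count, and the timings are placeholders, not the repository's real types or numbers.

```go
package main

import (
	"fmt"
	"time"
)

const miniBatchSize = 100 // mirrors the constant introduced in round_manager.go

type commitment struct{ id int } // placeholder; the real type lives in internal/models

func main() {
	// Pretend a round collected 342 commitments.
	var round []commitment
	for i := 0; i < 342; i++ {
		round = append(round, commitment{id: i})
	}

	// A full mini-batch is processed each time len(round) reaches a multiple of
	// miniBatchSize; each batch covers the most recent miniBatchSize entries.
	fullBatches := len(round) / miniBatchSize // 3
	// The trailing partial batch starts at the last full-batch boundary.
	lastBatchStart := fullBatches * miniBatchSize // 300
	fmt.Println("full batches:", fullBatches, "tail starts at:", lastBatchStart)
	fmt.Println("tail size:", len(round[lastBatchStart:])) // 42

	// Exponential moving average of per-batch SMT update time (80/20 weighting),
	// matching the avgSMTUpdateTime update in the diff.
	avg := 5 * time.Millisecond
	smtUpdateTime := 9 * time.Millisecond                          // total time spent on batches this round
	numBatches := (len(round) + miniBatchSize - 1) / miniBatchSize // round up: 4
	avgBatchTime := smtUpdateTime / time.Duration(numBatches)
	avg = (avg*4 + avgBatchTime) / 5
	fmt.Println("new avg batch time:", avg) // 4.45ms
}
```
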
@@ -411,17 +449,6 @@ ProcessLoop: break ProcessLoop } - // Process in mini-batches for SMT efficiency - if len(rm.currentRound.Commitments)%100 == 0 { - // Process this mini-batch into SMT - batchStart := time.Now() - if err := rm.processMiniBatch(ctx, rm.currentRound.Commitments[len(rm.currentRound.Commitments)-100:]); err != nil { - rm.logger.WithContext(ctx).Error("Failed to process mini-batch", - "error", err.Error(), - "roundNumber", roundNumber.String()) - } - smtUpdateTime += time.Since(batchStart) - } rm.roundMutex.Unlock() case <-ctx.Done(): @@ -441,7 +468,7 @@ ProcessLoop: // Process any remaining commitments not in a full mini-batch rm.roundMutex.Lock() - lastBatchStart := (commitmentsProcessed / 100) * 100 + lastBatchStart := (commitmentsProcessed / miniBatchSize) * miniBatchSize if lastBatchStart < len(rm.currentRound.Commitments) { batchStart := time.Now() if err := rm.processMiniBatch(ctx, rm.currentRound.Commitments[lastBatchStart:]); err != nil { @@ -460,8 +487,9 @@ ProcessLoop: // Update average SMT update time (exponential moving average) if commitmentsProcessed > 0 { - avgBatchTime := smtUpdateTime / time.Duration((commitmentsProcessed+99)/100) // Number of batches - rm.avgSMTUpdateTime = (rm.avgSMTUpdateTime*4 + avgBatchTime) / 5 // Weight towards recent: 80/20 + numBatches := (commitmentsProcessed + miniBatchSize - 1) / miniBatchSize // Round up + avgBatchTime := smtUpdateTime / time.Duration(numBatches) + rm.avgSMTUpdateTime = (rm.avgSMTUpdateTime*4 + avgBatchTime) / 5 // Weight towards recent: 80/20 } rm.lastRoundMetrics = RoundMetrics{ @@ -535,10 +563,12 @@ ProcessLoop: rm.totalRounds++ rm.totalCommitments += int64(len(rm.currentRound.Commitments)) + roundTotalTime := time.Since(rm.currentRound.StartTime) rm.logger.WithContext(ctx).Info("Round processing completed successfully", "roundNumber", roundNumber.String(), - "totalRounds", rm.totalRounds, - "totalCommitments", rm.totalCommitments) + "commitments", len(rm.currentRound.Commitments), + "roundTotalTime", roundTotalTime, + "processingTime", processingTime) return nil } @@ -802,16 +832,38 @@ func (rm *RoundManager) restoreSmtFromStorage(ctx context.Context) (*api.BigInt, } func (rm *RoundManager) Activate(ctx context.Context) error { - if err := rm.bftClient.Start(ctx); err != nil { - return fmt.Errorf("failed to start BFT client: %w", err) + switch rm.config.Sharding.Mode { + case config.ShardingModeStandalone: + if err := rm.bftClient.Start(ctx); err != nil { + return fmt.Errorf("failed to start BFT client: %w", err) + } + case config.ShardingModeChild: + roundNumber := rm.stateTracker.GetLastSyncedBlock() + roundNumber.Add(roundNumber, big.NewInt(1)) + if err := rm.StartNewRound(ctx, api.NewBigInt(roundNumber)); err != nil { + return fmt.Errorf("failed to start new round: %w", err) + } + default: + return fmt.Errorf("invalid shard mode: %s", rm.config.Sharding.Mode) } + rm.startCommitmentPrefetcher(ctx) return nil } func (rm *RoundManager) Deactivate(ctx context.Context) error { - rm.bftClient.Stop() rm.stopCommitmentPrefetcher() + if rm.bftClient != nil { + rm.bftClient.Stop() + } + + // stop creating blocks if we become follower + rm.roundMutex.Lock() + if rm.roundTimer != nil { + rm.roundTimer.Stop() + } + rm.roundMutex.Unlock() + return nil } diff --git a/internal/round/round_manager_test.go b/internal/round/round_manager_test.go new file mode 100644 index 0000000..5a0c305 --- /dev/null +++ b/internal/round/round_manager_test.go @@ -0,0 +1,119 @@ +package round + +import ( + "context" + "errors" + 
"math/big" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/ha/state" + "github.com/unicitynetwork/aggregator-go/internal/logger" + testsharding "github.com/unicitynetwork/aggregator-go/internal/sharding" + "github.com/unicitynetwork/aggregator-go/internal/smt" + "github.com/unicitynetwork/aggregator-go/internal/testutil" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// test the good case where blocks are created and stored successfully +func TestParentShardIntegration_GoodCase(t *testing.T) { + // setup dependencies + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := config.Config{ + Processing: config.ProcessingConfig{ + RoundDuration: 100 * time.Millisecond, + BatchLimit: 1000, + }, + Sharding: config.ShardingConfig{ + Mode: config.ShardingModeChild, + Child: config.ChildConfig{ + ShardID: 0b11, + ParentPollTimeout: 5 * time.Second, + ParentPollInterval: 100 * time.Millisecond, + }, + }, + } + storage := testutil.SetupTestStorage(t, cfg) + testLogger, err := logger.New("info", "text", "stdout", false) + require.NoError(t, err) + rootAggregatorClient := testsharding.NewRootAggregatorClientStub() + + // create round manager + rm, err := NewRoundManager(ctx, &cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker()) + require.NoError(t, err) + + // start round manager + require.NoError(t, rm.Start(ctx)) + require.NoError(t, rm.Activate(ctx)) + + // verify first 3 blocks + for i := 1; i <= 3; i++ { + require.Eventually(t, func() bool { + block, err := storage.BlockStorage().GetByNumber(ctx, api.NewBigInt(big.NewInt(int64(i)))) + if err != nil { + return false + } + return block != nil + }, 3*time.Second, 100*time.Millisecond, "block %d should have been created", i) + } + + // verify metrics + require.Equal(t, 3, rootAggregatorClient.SubmissionCount()) + require.Equal(t, 3, rootAggregatorClient.ProofCount()) +} + +// test that any error on block processing e.g. 
error on root aggregator communication does not hang the block processor +func TestParentShardIntegration_RoundProcessingError(t *testing.T) { + // setup dependencies + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := config.Config{ + Processing: config.ProcessingConfig{ + RoundDuration: 100 * time.Millisecond, + BatchLimit: 1000, + }, + Sharding: config.ShardingConfig{ + Mode: config.ShardingModeChild, + Child: config.ChildConfig{ + ShardID: 0b11, + ParentPollTimeout: 5 * time.Second, + ParentPollInterval: 100 * time.Millisecond, + }, + }, + } + storage := testutil.SetupTestStorage(t, cfg) + testLogger, err := logger.New("info", "text", "stdout", false) + require.NoError(t, err) + + // create root aggregator client where all submissions fail + rootAggregatorClient := testsharding.NewRootAggregatorClientStub() + rootAggregatorClient.SetSubmissionError(errors.New("some error")) + + // create round manager + rm, err := NewRoundManager(ctx, &cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, rootAggregatorClient, state.NewSyncStateTracker()) + require.NoError(t, err) + + // start round manager + require.NoError(t, rm.Start(ctx)) + require.NoError(t, rm.Activate(ctx)) + + // wait for a couple of rounds worth of time + time.Sleep(500 * time.Millisecond) + + // verify that no blocks were created + latestBlock, err := storage.BlockStorage().GetLatest(ctx) + require.NoError(t, err) + require.Nil(t, latestBlock) + + // verify that the round manager is NOT stuck on round 1 + currentRound := rm.GetCurrentRound() + require.NotNil(t, currentRound) + require.Greater(t, currentRound.Number.Int64(), int64(1)) + + // verify no submission requests were made successfully + require.Equal(t, 0, rootAggregatorClient.SubmissionCount()) +} diff --git a/internal/round/smt_persistence_integration_test.go b/internal/round/smt_persistence_integration_test.go index a608d98..6985be3 100644 --- a/internal/round/smt_persistence_integration_test.go +++ b/internal/round/smt_persistence_integration_test.go @@ -36,8 +36,7 @@ var conf = config.Config{ // TestSmtPersistenceAndRestoration tests SMT persistence and restoration with consistent root hashes func TestSmtPersistenceAndRestoration(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, conf) - defer cleanup() + storage := testutil.SetupTestStorage(t, conf) ctx := context.Background() @@ -61,7 +60,7 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create RoundManager") // Test persistence @@ -80,7 +79,7 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { freshHash := freshSmt.GetRootHashHex() // Create RoundManager and call Start() to trigger restoration - restoredRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + restoredRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create RoundManager") err = restoredRm.Start(ctx) @@ -92,7 +91,8 @@ func 
TestSmtPersistenceAndRestoration(t *testing.T) { // Verify inclusion proofs work for _, leaf := range testLeaves { - merkleTreePath := restoredRm.smt.GetPath(leaf.Path) + merkleTreePath, err := restoredRm.smt.GetPath(leaf.Path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Should be able to get Merkle path") assert.NotEmpty(t, merkleTreePath.Root, "Merkle path should have root hash") } @@ -100,8 +100,7 @@ func TestSmtPersistenceAndRestoration(t *testing.T) { // TestLargeSmtRestoration tests multi-chunk restoration with large dataset func TestLargeSmtRestoration(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, conf) - defer cleanup() + storage := testutil.SetupTestStorage(t, conf) ctx := context.Background() testLogger, err := logger.New("info", "text", "stdout", false) @@ -112,7 +111,7 @@ func TestLargeSmtRestoration(t *testing.T) { RoundDuration: time.Second, }, } - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create RoundManager") const testNodeCount = 2500 // Ensure multiple chunks (chunkSize = 1000 in round_manager.go) @@ -142,7 +141,7 @@ func TestLargeSmtRestoration(t *testing.T) { require.Equal(t, int64(testNodeCount), count, "Should have stored all nodes") // Create new RoundManager and call Start() to restore from storage (uses multiple chunks) - newRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + newRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create new RoundManager") err = newRm.Start(ctx) @@ -157,8 +156,7 @@ func TestLargeSmtRestoration(t *testing.T) { // TestCompleteWorkflowWithRestart tests end-to-end workflow including service restart simulation func TestCompleteWorkflowWithRestart(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, conf) - defer cleanup() + storage := testutil.SetupTestStorage(t, conf) ctx := context.Background() @@ -178,7 +176,7 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { testLogger, err := logger.New("info", "text", "stdout", false) require.NoError(t, err) - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create RoundManager") rm.currentRound = &Round{ @@ -211,10 +209,13 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { block := models.NewBlock( blockNumber, "unicity", + 0, "1.0", "mainnet", rootHashBytes, api.HexBytes{}, + nil, + nil, ) err = rm.FinalizeBlock(ctx, block) @@ -226,7 +227,8 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { assert.Equal(t, int64(len(testCommitments)), count, "Should have persisted SMT nodes for all commitments") // Simulate service restart with new round manager - newRm, err := NewRoundManager(ctx, &config.Config{Processing: config.ProcessingConfig{RoundDuration: time.Second}}, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + cfg = &config.Config{Processing: 
config.ProcessingConfig{RoundDuration: time.Second}} + newRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "NewRoundManager should succeed after restart") // Call Start() to trigger SMT restoration @@ -243,7 +245,8 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { path, err := commitment.RequestID.GetPath() require.NoError(t, err, "Should be able to get path from request ID") - merkleTreePath := newRm.smt.GetPath(path) + merkleTreePath, err := newRm.smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Should be able to get Merkle path") assert.NotEmpty(t, merkleTreePath.Root, "Merkle path should have root hash") assert.NotEmpty(t, merkleTreePath.Steps, "Merkle path should have steps") @@ -252,8 +255,7 @@ func TestCompleteWorkflowWithRestart(t *testing.T) { // TestSmtRestorationWithBlockVerification tests that SMT restoration verifies against existing blocks func TestSmtRestorationWithBlockVerification(t *testing.T) { - storage, cleanup := testutil.SetupTestStorage(t, conf) - defer cleanup() + storage := testutil.SetupTestStorage(t, conf) ctx := context.Background() testLogger, err := logger.New("info", "text", "stdout", false) @@ -279,15 +281,17 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { // Create a block with the expected root hash block := &models.Block{ - Index: api.NewBigInt(big.NewInt(1)), - ChainID: "test-chain", - Version: "1.0.0", - ForkID: "test-fork", - RootHash: api.HexBytes(expectedRootHashBytes), // Use bytes, not hex string - PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000000"), - NoDeletionProofHash: api.HexBytes(""), - CreatedAt: api.NewTimestamp(time.Now()), - UnicityCertificate: api.HexBytes("certificate_data"), + Index: api.NewBigInt(big.NewInt(1)), + ChainID: "test-chain", + ShardID: 0, + Version: "1.0.0", + ForkID: "test-fork", + RootHash: api.HexBytes(expectedRootHashBytes), // Use bytes, not hex string + PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000000"), + NoDeletionProofHash: api.HexBytes(""), + CreatedAt: api.NewTimestamp(time.Now()), + UnicityCertificate: api.HexBytes("certificate_data"), + ParentMerkleTreePath: nil, } // Store the block @@ -298,7 +302,7 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { cfg := &config.Config{ Processing: config.ProcessingConfig{RoundDuration: time.Second}, } - rm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + rm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create RoundManager") // Persist SMT nodes to storage @@ -307,7 +311,7 @@ func TestSmtRestorationWithBlockVerification(t *testing.T) { // Test 1: Successful verification (matching root hash) t.Run("SuccessfulVerification", func(t *testing.T) { - successRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + successRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create RoundManager") err = successRm.Start(ctx) @@ -323,22 +327,24 @@ func 
TestSmtRestorationWithBlockVerification(t *testing.T) { t.Run("FailedVerification", func(t *testing.T) { // Create a block with a different root hash to simulate mismatch wrongBlock := &models.Block{ - Index: api.NewBigInt(big.NewInt(2)), - ChainID: "test-chain", - Version: "1.0.0", - ForkID: "test-fork", - RootHash: api.HexBytes("wrong_root_hash_value"), // Intentionally wrong hash - PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000001"), - NoDeletionProofHash: api.HexBytes(""), - CreatedAt: api.NewTimestamp(time.Now()), - UnicityCertificate: api.HexBytes("certificate_data"), + Index: api.NewBigInt(big.NewInt(2)), + ChainID: "test-chain", + ShardID: 0, + Version: "1.0.0", + ForkID: "test-fork", + RootHash: api.HexBytes("wrong_root_hash_value"), // Intentionally wrong hash + PreviousBlockHash: api.HexBytes("0000000000000000000000000000000000000000000000000000000000000001"), + NoDeletionProofHash: api.HexBytes(""), + CreatedAt: api.NewTimestamp(time.Now()), + UnicityCertificate: api.HexBytes("certificate_data"), + ParentMerkleTreePath: nil, } // Store the wrong block (this will become the "latest" block) err = storage.BlockStorage().Store(ctx, wrongBlock) require.NoError(t, err, "Should store wrong test block") - failRm, err := NewRoundManager(ctx, cfg, testLogger, storage.CommitmentQueue(), storage, state.NewSyncStateTracker()) + failRm, err := NewRoundManager(ctx, cfg, testLogger, smt.NewSparseMerkleTree(api.SHA256, 16+256), storage.CommitmentQueue(), storage, nil, state.NewSyncStateTracker()) require.NoError(t, err, "Should create RoundManager") // This should fail because the restored SMT root hash doesn't match the latest block diff --git a/internal/service/parent_service.go b/internal/service/parent_service.go new file mode 100644 index 0000000..5f62f5e --- /dev/null +++ b/internal/service/parent_service.go @@ -0,0 +1,273 @@ +package service + +import ( + "context" + "errors" + "fmt" + "math/big" + + "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/models" + "github.com/unicitynetwork/aggregator-go/internal/round" + "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// ParentAggregatorService implements the business logic for the parent aggregator +type ParentAggregatorService struct { + config *config.Config + logger *logger.Logger + storage interfaces.Storage + parentRoundManager *round.ParentRoundManager + leaderSelector LeaderSelector +} + +func (pas *ParentAggregatorService) isLeader(ctx context.Context) (bool, error) { + if pas.leaderSelector == nil { + return true, nil + } + + isLeader, err := pas.leaderSelector.IsLeader(ctx) + if err != nil { + pas.logger.WithContext(ctx).Error("Failed to determine leadership status", "error", err.Error()) + return false, err + } + + return isLeader, nil +} + +// NewParentAggregatorService creates a new parent aggregator service +func NewParentAggregatorService(cfg *config.Config, logger *logger.Logger, parentRoundManager *round.ParentRoundManager, storage interfaces.Storage, leaderSelector LeaderSelector) *ParentAggregatorService { + return &ParentAggregatorService{ + config: cfg, + logger: logger, + storage: storage, + parentRoundManager: parentRoundManager, + leaderSelector: leaderSelector, + } +} + +// Start starts the parent aggregator service +func (pas *ParentAggregatorService) Start(ctx 
context.Context) error { + pas.logger.WithContext(ctx).Info("Starting Parent Aggregator Service") + + if err := pas.parentRoundManager.Start(ctx); err != nil { + return fmt.Errorf("failed to start parent round manager: %w", err) + } + + pas.logger.WithContext(ctx).Info("Parent Aggregator Service started successfully") + return nil +} + +// Stop stops the parent aggregator service +func (pas *ParentAggregatorService) Stop(ctx context.Context) error { + pas.logger.WithContext(ctx).Info("Stopping Parent Aggregator Service") + + if err := pas.parentRoundManager.Stop(ctx); err != nil { + pas.logger.WithContext(ctx).Error("Failed to stop parent round manager", "error", err.Error()) + return fmt.Errorf("failed to stop parent round manager: %w", err) + } + + pas.logger.WithContext(ctx).Info("Parent Aggregator Service stopped successfully") + return nil +} + +// SubmitShardRoot handles shard root submission from child aggregators +func (pas *ParentAggregatorService) SubmitShardRoot(ctx context.Context, req *api.SubmitShardRootRequest) (*api.SubmitShardRootResponse, error) { + isLeader, err := pas.isLeader(ctx) + if err != nil { + return &api.SubmitShardRootResponse{ + Status: api.ShardRootStatusInternalError, + }, nil + } + + if !isLeader { + pas.logger.WithContext(ctx).Warn("Rejecting shard root submission because node is not leader", + "shardId", req.ShardID) + return &api.SubmitShardRootResponse{ + Status: api.ShardRootStatusNotLeader, + }, nil + } + + if err := pas.validateShardID(req.ShardID); err != nil { + pas.logger.WithContext(ctx).Warn("Invalid shard ID", "shardId", req.ShardID, "error", err.Error()) + return &api.SubmitShardRootResponse{ + Status: api.ShardRootStatusInvalidShardID, + }, nil + } + + update := models.NewShardRootUpdate(req.ShardID, req.RootHash) + + err = pas.parentRoundManager.SubmitShardRoot(ctx, update) + if err != nil { + pas.logger.WithContext(ctx).Error("Failed to submit shard root to round manager", "error", err.Error()) + return &api.SubmitShardRootResponse{ + Status: api.ShardRootStatusInternalError, + }, nil + } + + pas.logger.WithContext(ctx).Info("Shard root update accepted", + "shardId", req.ShardID, + "rootHash", req.RootHash.String()) + + return &api.SubmitShardRootResponse{ + Status: api.ShardRootStatusSuccess, + }, nil +} + +// GetShardProof handles shard proof requests +func (pas *ParentAggregatorService) GetShardProof(ctx context.Context, req *api.GetShardProofRequest) (*api.GetShardProofResponse, error) { + if pas.leaderSelector != nil { + isLeader, err := pas.isLeader(ctx) + if err != nil { + return nil, fmt.Errorf("failed to determine leadership status: %w", err) + } + if !isLeader { + pas.logger.WithContext(ctx).Debug("Serving shard proof while node is follower") + } + } + + if err := pas.validateShardID(req.ShardID); err != nil { + pas.logger.WithContext(ctx).Warn("Invalid shard ID", "shardId", req.ShardID, "error", err.Error()) + return nil, fmt.Errorf("invalid shard ID: %w", err) + } + + pas.logger.WithContext(ctx).Info("Shard proof requested", "shardId", req.ShardID) + + shardPath := new(big.Int).SetInt64(int64(req.ShardID)) + merkleTreePath, err := pas.parentRoundManager.GetSMT().GetPath(shardPath) + if err != nil { + return nil, fmt.Errorf("failed to get merkle tree path: %w", err) + } + + var proofPath *api.MerkleTreePath + if len(merkleTreePath.Steps) > 0 && merkleTreePath.Steps[0].Data != nil { + proofPath = merkleTreePath + pas.logger.WithContext(ctx).Info("Generated shard proof from current state", + "shardId", req.ShardID, + "rootHash", 
merkleTreePath.Root) + } else { + proofPath = nil + pas.logger.WithContext(ctx).Info("Shard has not submitted root yet, returning nil proof", + "shardId", req.ShardID) + } + + latestBlock, err := pas.storage.BlockStorage().GetLatest(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get latest block: %w", err) + } + + var unicityCertificate api.HexBytes + if latestBlock != nil { + unicityCertificate = latestBlock.UnicityCertificate + } + + return &api.GetShardProofResponse{ + MerkleTreePath: proofPath, + UnicityCertificate: unicityCertificate, + }, nil +} + +func (pas *ParentAggregatorService) validateShardID(shardID api.ShardID) error { + if shardID <= 1 { + return errors.New("shard ID must be positive and have at least 2 bits") + } + + // ShardID encoding: MSB=1 (prefix bit) + ShardIDLength bits for actual shard identifier + // For ShardIDLength=1: valid range is [2,3] (0b10, 0b11) + // For ShardIDLength=2: valid range is [4,7] (0b100-0b111) + // The prefix bit ensures leading zeros are preserved in the path calculation + + minShardID := int64(1 << pas.config.Sharding.ShardIDLength) + maxShardID := int64((1 << (pas.config.Sharding.ShardIDLength + 1)) - 1) + + shardValue := int64(shardID) + + if shardValue < minShardID { + return fmt.Errorf("shard ID %d is below minimum %d for %d-bit shard IDs (MSB prefix bit not set)", + shardValue, minShardID, pas.config.Sharding.ShardIDLength) + } + + if shardValue > maxShardID { + return fmt.Errorf("shard ID %d exceeds maximum %d for %d-bit shard IDs", + shardValue, maxShardID, pas.config.Sharding.ShardIDLength) + } + + return nil +} + +// SubmitCommitment - not used in parent mode +func (pas *ParentAggregatorService) SubmitCommitment(ctx context.Context, req *api.SubmitCommitmentRequest) (*api.SubmitCommitmentResponse, error) { + return nil, fmt.Errorf("submit_commitment is not supported in parent mode - use submit_shard_root instead") +} + +// GetInclusionProof - not used in parent mode +func (pas *ParentAggregatorService) GetInclusionProof(ctx context.Context, req *api.GetInclusionProofRequest) (*api.GetInclusionProofResponse, error) { + return nil, fmt.Errorf("get_inclusion_proof is not supported in parent mode - use get_shard_proof instead") +} + +// GetNoDeletionProof - TODO: implement +func (pas *ParentAggregatorService) GetNoDeletionProof(ctx context.Context) (*api.GetNoDeletionProofResponse, error) { + return nil, fmt.Errorf("get_no_deletion_proof not implemented yet in parent mode") +} + +// GetBlockHeight retrieves the current parent block height +func (pas *ParentAggregatorService) GetBlockHeight(ctx context.Context) (*api.GetBlockHeightResponse, error) { + latestBlockNumber, err := pas.storage.BlockStorage().GetLatestNumber(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get latest parent block number: %w", err) + } + + return &api.GetBlockHeightResponse{ + BlockNumber: latestBlockNumber, + }, nil +} + +// GetBlock - TODO: implement +func (pas *ParentAggregatorService) GetBlock(ctx context.Context, req *api.GetBlockRequest) (*api.GetBlockResponse, error) { + return nil, fmt.Errorf("get_block not implemented yet in parent mode") +} + +// GetBlockCommitments - TODO: implement +func (pas *ParentAggregatorService) GetBlockCommitments(ctx context.Context, req *api.GetBlockCommitmentsRequest) (*api.GetBlockCommitmentsResponse, error) { + return nil, fmt.Errorf("get_block_commitments not implemented yet in parent mode") +} + +// GetHealthStatus retrieves the health status of the parent aggregator service +func (pas 
*ParentAggregatorService) GetHealthStatus(ctx context.Context) (*api.HealthStatus, error) { + // Check if HA is enabled and determine role + var role string + if pas.leaderSelector != nil { + isLeader, err := pas.leaderSelector.IsLeader(ctx) + if err != nil { + pas.logger.WithContext(ctx).Warn("Failed to check leadership status", "error", err.Error()) + // Don't fail health check on leadership query failure + isLeader = false + } + + if isLeader { + role = "parent-leader" + } else { + role = "parent-follower" + } + } else { + role = "parent-standalone" + } + + sharding := api.Sharding{ + Mode: pas.config.Sharding.Mode.String(), + ShardIDLen: pas.config.Sharding.ShardIDLength, + } + status := models.NewHealthStatus(role, pas.config.HA.ServerID, sharding) + + // Add database connectivity check + if err := pas.storage.Ping(ctx); err != nil { + status.AddDetail("database", "disconnected") + pas.logger.WithContext(ctx).Error("Database health check failed", "error", err.Error()) + } else { + status.AddDetail("database", "connected") + } + + return modelToAPIHealthStatus(status), nil +} diff --git a/internal/service/parent_service_test.go b/internal/service/parent_service_test.go new file mode 100644 index 0000000..5cb2adf --- /dev/null +++ b/internal/service/parent_service_test.go @@ -0,0 +1,425 @@ +package service + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/round" + "github.com/unicitynetwork/aggregator-go/internal/smt" + "github.com/unicitynetwork/aggregator-go/internal/storage/mongodb" + "github.com/unicitynetwork/aggregator-go/internal/testutil" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +// ParentServiceTestSuite is the test suite for parent aggregator service +type ParentServiceTestSuite struct { + suite.Suite + cfg *config.Config + logger *logger.Logger + storage *mongodb.Storage + cleanup func() + service *ParentAggregatorService + prm *round.ParentRoundManager +} + +type staticLeaderSelector struct { + leader bool + err error +} + +func (s *staticLeaderSelector) IsLeader(ctx context.Context) (bool, error) { + if s.err != nil { + return false, s.err + } + return s.leader, nil +} + +// SetupSuite runs once before all tests +func (suite *ParentServiceTestSuite) SetupSuite() { + var err error + suite.logger, err = logger.New("info", "text", "stdout", false) + require.NoError(suite.T(), err, "Should create logger") + + suite.cfg = &config.Config{ + Sharding: config.ShardingConfig{ + Mode: config.ShardingModeParent, + ShardIDLength: 2, // 2 bits = 4 possible shards [4,5,6,7] + }, + Database: config.DatabaseConfig{ + Database: "test_parent_service", + ConnectTimeout: 5 * time.Second, + }, + BFT: config.BFTConfig{ + Enabled: false, // Use BFT stub + }, + Processing: config.ProcessingConfig{ + RoundDuration: 100 * time.Millisecond, + }, + } + + suite.storage = testutil.SetupTestStorage(suite.T(), *suite.cfg) +} + +// TearDownSuite runs once after all tests +func (suite *ParentServiceTestSuite) TearDownSuite() { + if suite.cleanup != nil { + suite.cleanup() + } +} + +// SetupTest runs before each test +func (suite *ParentServiceTestSuite) SetupTest() { + ctx := context.Background() + + // Create parent round manager + var err error + suite.prm, err = round.NewParentRoundManager(ctx, suite.cfg, suite.logger, 
suite.storage) + require.NoError(suite.T(), err, "Should create parent round manager") + require.NotNil(suite.T(), suite.prm, "Parent round manager should not be nil") + + // Start the round manager + err = suite.prm.Start(ctx) + require.NoError(suite.T(), err, "Should start parent round manager") + + // Activate the round manager (starts rounds) + err = suite.prm.Activate(ctx) + require.NoError(suite.T(), err, "Should activate parent round manager") + + // Create parent service with the round manager + suite.service = NewParentAggregatorService(suite.cfg, suite.logger, suite.prm, suite.storage, nil) + require.NotNil(suite.T(), suite.service, "Parent service should not be nil") +} + +// TearDownTest runs after each test +func (suite *ParentServiceTestSuite) TearDownTest() { + ctx := context.Background() + + // Stop round manager + if suite.prm != nil { + suite.prm.Stop(ctx) + } + + // Clean all collections + if err := suite.storage.CleanAllCollections(ctx); err != nil { + suite.T().Logf("Warning: failed to clean collections: %v", err) + } +} + +// Test helpers +func makeTestHash(value byte) []byte { + hash := make([]byte, 32) + hash[0] = value + return hash +} + +// waitForShardToExist polls until the shard proof is available or times out +func (suite *ParentServiceTestSuite) waitForShardToExist(ctx context.Context, shardID api.ShardID) { + timeout := time.After(5 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + suite.T().Fatalf("Timeout waiting for shard %d to be processed", shardID) + case <-ticker.C: + response, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shardID}) + if err == nil && response.MerkleTreePath != nil { + return // Shard exists and has a proof! 
+ } + // Continue polling + } + } +} + +// SubmitShardRoot - Valid Submission +func (suite *ParentServiceTestSuite) TestSubmitShardRoot_ValidSubmission() { + ctx := context.Background() + + request := &api.SubmitShardRootRequest{ + ShardID: 4, // 0b100 - valid for 2-bit sharding + RootHash: makeTestHash(0xAA), + } + + response, err := suite.service.SubmitShardRoot(ctx, request) + suite.Require().NoError(err, "Valid submission should succeed") + suite.Require().NotNil(response, "Response should not be nil") + suite.Assert().Equal(api.ShardRootStatusSuccess, response.Status, "Response should indicate success") + + suite.T().Log("✓ Valid shard root submission accepted") +} + +func (suite *ParentServiceTestSuite) TestSubmitShardRoot_NotLeader() { + ctx := context.Background() + + notLeaderSelector := &staticLeaderSelector{leader: false} + suite.service = NewParentAggregatorService(suite.cfg, suite.logger, suite.prm, suite.storage, notLeaderSelector) + + request := &api.SubmitShardRootRequest{ + ShardID: 4, + RootHash: makeTestHash(0xAA), + } + + response, err := suite.service.SubmitShardRoot(ctx, request) + suite.Require().NoError(err, "Follower rejection should not return Go error") + suite.Require().NotNil(response, "Response should not be nil") + suite.Assert().Equal(api.ShardRootStatusNotLeader, response.Status, "Follower should reject shard submissions with NOT_LEADER") +} + +// SubmitShardRoot - Invalid ShardID (empty) +func (suite *ParentServiceTestSuite) TestSubmitShardRoot_EmptyShardID() { + ctx := context.Background() + + request := &api.SubmitShardRootRequest{ + ShardID: 0, // Empty + RootHash: makeTestHash(0xAA), + } + + response, err := suite.service.SubmitShardRoot(ctx, request) + suite.Require().NoError(err, "Should not return Go error") + suite.Require().NotNil(response, "Response should not be nil") + suite.Assert().Equal(api.ShardRootStatusInvalidShardID, response.Status, "Should return INVALID_SHARD_ID status") + + suite.T().Log("✓ Empty shard ID rejected correctly") +} + +// SubmitShardRoot - Invalid ShardID (out of range) +func (suite *ParentServiceTestSuite) TestSubmitShardRoot_OutOfRange() { + ctx := context.Background() + + // Test shard ID below minimum (MSB not set) + // For ShardIDLength=2, valid range is [4,7] (0b100-0b111) + // ShardID=1 (0b001) has no MSB prefix bit + requestLow := &api.SubmitShardRootRequest{ + ShardID: 1, // Below minimum + RootHash: makeTestHash(0xAA), + } + + responseLow, err := suite.service.SubmitShardRoot(ctx, requestLow) + suite.Require().NoError(err, "Should not return Go error") + suite.Require().NotNil(responseLow, "Response should not be nil") + suite.Assert().Equal(api.ShardRootStatusInvalidShardID, responseLow.Status, "Should return INVALID_SHARD_ID for below minimum") + + // Test shard ID above maximum + // ShardID=8 (0b1000) exceeds the 2-bit range + requestHigh := &api.SubmitShardRootRequest{ + ShardID: 8, // Above maximum + RootHash: makeTestHash(0xBB), + } + + responseHigh, err := suite.service.SubmitShardRoot(ctx, requestHigh) + suite.Require().NoError(err, "Should not return Go error") + suite.Require().NotNil(responseHigh, "Response should not be nil") + suite.Assert().Equal(api.ShardRootStatusInvalidShardID, responseHigh.Status, "Should return INVALID_SHARD_ID for above maximum") + + suite.T().Log("✓ Out of range shard IDs rejected correctly") +} + +// SubmitShardRoot - Verify Update is Queued +func (suite *ParentServiceTestSuite) TestSubmitShardRoot_UpdateQueued() { + ctx := context.Background() + + shard0ID := 0 + 
shard0Root := makeTestHash(0xAA) + + request := &api.SubmitShardRootRequest{ + ShardID: shard0ID, + RootHash: shard0Root, + } + + response, err := suite.service.SubmitShardRoot(ctx, request) + suite.Require().NoError(err, "Submission should succeed") + suite.Require().NotNil(response, "Response should not be nil") + + // Wait for round to process + time.Sleep(150 * time.Millisecond) + + // Verify that the parent SMT root has changed (indicating the update was processed) + rootHash := suite.prm.GetSMT().GetRootHash() + suite.Assert().NotEmpty(rootHash, "Parent SMT root should be calculated") + + suite.T().Log("✓ Shard update was queued and processed correctly") +} + +// GetShardProof - Success - Full E2E test with child SMT, proof joining, and verification +func (suite *ParentServiceTestSuite) TestGetShardProof_Success() { + ctx := context.Background() + + shard0ID := 4 // 0b100 - valid for 2-bit sharding (ShardIDLength=2) + + // 1. Create a real child SMT with some data (simulating what child aggregator does) + childSMT := smt.NewChildSparseMerkleTree(api.SHA256, 4, shard0ID) + + // Add some leaves to the child SMT + testLeafPath := big.NewInt(0b10000) // Request ID within this shard + testLeafValue := []byte{0x61} // Commitment data + err := childSMT.AddLeaf(testLeafPath, testLeafValue) + suite.Require().NoError(err, "Should add leaf to child SMT") + + // 2. Extract child root hash (what child aggregator would submit to parent) + childRootHash := childSMT.GetRootHash() + suite.Require().NotEmpty(childRootHash, "Child SMT should have root hash") + suite.T().Logf("Child SMT root hash: %x", childRootHash) + + // 3. Submit child root to parent (strip algorithm prefix - first 2 bytes) + // This is required for JoinPaths to work: parent stores raw 32-byte hashes + childRootRaw := childRootHash[2:] // Remove algorithm identifier (2 bytes) + suite.Require().True(len(childRootRaw) == 32, "Root hash should be 32 bytes after stripping prefix") + suite.T().Logf("Sending %d bytes to parent (WITHOUT algorithm prefix)", len(childRootRaw)) + + submitReq := &api.SubmitShardRootRequest{ + ShardID: shard0ID, + RootHash: childRootRaw, + } + _, err = suite.service.SubmitShardRoot(ctx, submitReq) + suite.Require().NoError(err) + + // 4. Wait for round to process + suite.waitForShardToExist(ctx, shard0ID) + + // 5. Get child proof from child SMT + childProof, err := childSMT.GetPath(testLeafPath) + suite.Require().NoError(err, "Should get child proof") + suite.Require().NotNil(childProof, "Child proof should not be nil") + suite.T().Logf("Child proof has %d steps", len(childProof.Steps)) + + // 6. Request parent proof from parent aggregator + proofReq := &api.GetShardProofRequest{ + ShardID: shard0ID, + } + parentResponse, err := suite.service.GetShardProof(ctx, proofReq) + suite.Require().NoError(err, "Should get parent proof successfully") + suite.Require().NotNil(parentResponse, "Parent response should not be nil") + suite.Require().NotNil(parentResponse.MerkleTreePath, "Parent proof should not be nil") + suite.T().Logf("Parent proof has %d steps", len(parentResponse.MerkleTreePath.Steps)) + + // 7. Join child and parent proofs + joinedProof, err := smt.JoinPaths(childProof, parentResponse.MerkleTreePath) + suite.Require().NoError(err, "Should join proofs successfully") + suite.Require().NotNil(joinedProof, "Joined proof should not be nil") + suite.T().Logf("Joined proof has %d steps", len(joinedProof.Steps)) + + // 8. 
Verify the joined proof + result, err := joinedProof.Verify(testLeafPath) + suite.Require().NoError(err, "Proof verification should not error") + suite.Require().NotNil(result, "Verification result should not be nil") + + // Both PathValid and PathIncluded should be true + suite.Assert().True(result.PathValid, "Joined proof path must be valid") + suite.Assert().True(result.PathIncluded, "Joined proof should show path is included") + suite.Assert().True(result.Result, "Overall verification result should be true") + + suite.T().Log("✓ End-to-end test: child SMT → parent submission → proof joining → verification SUCCESS") +} + +// GetShardProof - Non-existent Shard (returns nil MerkleTreePath) +func (suite *ParentServiceTestSuite) TestGetShardProof_NonExistentShard() { + ctx := context.Background() + + // Submit one shard + shard0ID := 0b100 // 4 - valid for 2-bit sharding + submitReq := &api.SubmitShardRootRequest{ + ShardID: shard0ID, + RootHash: makeTestHash(0xAA), + } + _, err := suite.service.SubmitShardRoot(ctx, submitReq) + suite.Require().NoError(err) + + // Wait for round to process + suite.waitForShardToExist(ctx, shard0ID) + + // Request proof for a shard that was never submitted + shard5ID := 0b101 // 5 - valid for 2-bit sharding + proofReq := &api.GetShardProofRequest{ + ShardID: shard5ID, + } + + response, err := suite.service.GetShardProof(ctx, proofReq) + suite.Require().NoError(err, "Should not return error for non-existent shard") + suite.Require().NotNil(response, "Response should not be nil") + suite.Assert().Nil(response.MerkleTreePath, "MerkleTreePath should be nil for non-existent shard") + + suite.T().Log("✓ GetShardProof returns nil MerkleTreePath for non-existent shard") +} + +// GetShardProof - Empty Tree (no shards submitted yet) +func (suite *ParentServiceTestSuite) TestGetShardProof_EmptyTree() { + ctx := context.Background() + + // Request proof before any shards have been submitted + shard0ID := 4 // 0b100 - valid for 2-bit sharding + proofReq := &api.GetShardProofRequest{ + ShardID: shard0ID, + } + + response, err := suite.service.GetShardProof(ctx, proofReq) + suite.Require().NoError(err, "Should not return error for empty tree") + suite.Require().NotNil(response, "Response should not be nil") + suite.Assert().Nil(response.MerkleTreePath, "MerkleTreePath should be nil when no shards submitted") + + suite.T().Log("✓ GetShardProof returns nil MerkleTreePath for empty tree") +} + +// GetShardProof - Multiple Shards (verify each has correct proof) +func (suite *ParentServiceTestSuite) TestGetShardProof_MultipleShards() { + ctx := context.Background() + + // Submit 3 different shards + shard2ID := 0b100 + shard0ID := 0b101 + shard1ID := 0b111 + + _, err := suite.service.SubmitShardRoot(ctx, &api.SubmitShardRootRequest{ + ShardID: shard0ID, + RootHash: makeTestHash(0xAA), + }) + suite.Require().NoError(err) + + _, err = suite.service.SubmitShardRoot(ctx, &api.SubmitShardRootRequest{ + ShardID: shard1ID, + RootHash: makeTestHash(0xBB), + }) + suite.Require().NoError(err) + + _, err = suite.service.SubmitShardRoot(ctx, &api.SubmitShardRootRequest{ + ShardID: shard2ID, + RootHash: makeTestHash(0xCC), + }) + suite.Require().NoError(err) + + // Wait for all shards to be processed + suite.waitForShardToExist(ctx, shard0ID) + suite.waitForShardToExist(ctx, shard1ID) + suite.waitForShardToExist(ctx, shard2ID) + + // Get proofs for all 3 shards + proof0, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard0ID}) + suite.Require().NoError(err, 
"Should get proof for shard 0") + suite.Assert().NotNil(proof0.MerkleTreePath, "Proof 0 should not be nil") + + proof1, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard1ID}) + suite.Require().NoError(err, "Should get proof for shard 1") + suite.Assert().NotNil(proof1.MerkleTreePath, "Proof 1 should not be nil") + + proof2, err := suite.service.GetShardProof(ctx, &api.GetShardProofRequest{ShardID: shard2ID}) + suite.Require().NoError(err, "Should get proof for shard 2") + suite.Assert().NotNil(proof2.MerkleTreePath, "Proof 2 should not be nil") + + // All proofs should have the same root (same parent SMT) + suite.Assert().Equal(proof0.MerkleTreePath.Root, proof1.MerkleTreePath.Root, "All proofs should have same root") + suite.Assert().Equal(proof0.MerkleTreePath.Root, proof2.MerkleTreePath.Root, "All proofs should have same root") + + suite.T().Log("✓ GetShardProof returns valid proofs for multiple shards with same root") +} + +// TestParentServiceSuite runs the test suite +func TestParentServiceSuite(t *testing.T) { + suite.Run(t, new(ParentServiceTestSuite)) +} diff --git a/internal/service/service.go b/internal/service/service.go index 90c6f32..0b434e7 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -10,17 +10,56 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/internal/round" "github.com/unicitynetwork/aggregator-go/internal/signing" + "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" "github.com/unicitynetwork/aggregator-go/pkg/api" ) +// Service defines the common interface that all service implementations must satisfy +type Service interface { + // JSON-RPC methods + SubmitCommitment(ctx context.Context, req *api.SubmitCommitmentRequest) (*api.SubmitCommitmentResponse, error) + GetInclusionProof(ctx context.Context, req *api.GetInclusionProofRequest) (*api.GetInclusionProofResponse, error) + GetNoDeletionProof(ctx context.Context) (*api.GetNoDeletionProofResponse, error) + GetBlockHeight(ctx context.Context) (*api.GetBlockHeightResponse, error) + GetBlock(ctx context.Context, req *api.GetBlockRequest) (*api.GetBlockResponse, error) + GetBlockCommitments(ctx context.Context, req *api.GetBlockCommitmentsRequest) (*api.GetBlockCommitmentsResponse, error) + GetHealthStatus(ctx context.Context) (*api.HealthStatus, error) + + // Parent mode specific methods + SubmitShardRoot(ctx context.Context, req *api.SubmitShardRootRequest) (*api.SubmitShardRootResponse, error) + GetShardProof(ctx context.Context, req *api.GetShardProofRequest) (*api.GetShardProofResponse, error) +} + +// NewService creates the appropriate service based on sharding mode +func NewService(ctx context.Context, cfg *config.Config, logger *logger.Logger, roundManager round.Manager, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, leaderSelector LeaderSelector) (Service, error) { + switch cfg.Sharding.Mode { + case config.ShardingModeStandalone: + rm, ok := roundManager.(*round.RoundManager) + if !ok { + return nil, fmt.Errorf("invalid round manager type for standalone mode") + } + return NewAggregatorService(cfg, logger, rm, commitmentQueue, storage, leaderSelector), nil + case config.ShardingModeParent: + prm, ok := roundManager.(*round.ParentRoundManager) + if !ok { + return nil, fmt.Errorf("invalid round manager type for parent mode") + } + return NewParentAggregatorService(cfg, logger, prm, storage, 
leaderSelector), nil + case config.ShardingModeChild: + return NewAggregatorService(cfg, logger, roundManager, commitmentQueue, storage, leaderSelector), nil + default: + return nil, fmt.Errorf("unsupported sharding mode: %s", cfg.Sharding.Mode) + } +} + // AggregatorService implements the business logic for the aggregator type AggregatorService struct { config *config.Config logger *logger.Logger commitmentQueue interfaces.CommitmentQueue storage interfaces.Storage - roundManager *round.RoundManager + roundManager round.Manager leaderSelector LeaderSelector commitmentValidator *signing.CommitmentValidator } @@ -31,20 +70,6 @@ type LeaderSelector interface { // Conversion functions between API and internal model types -func modelToAPIBigInt(modelBigInt *api.BigInt) *api.BigInt { - if modelBigInt == nil { - return nil - } - return &api.BigInt{Int: modelBigInt.Int} -} - -func apiToModelBigInt(apiBigInt *api.BigInt) *api.BigInt { - if apiBigInt == nil { - return nil - } - return &api.BigInt{Int: apiBigInt.Int} -} - func modelToAPIAggregatorRecord(modelRecord *models.AggregatorRecord) *api.AggregatorRecord { return &api.AggregatorRecord{ RequestID: modelRecord.RequestID, @@ -56,8 +81,8 @@ func modelToAPIAggregatorRecord(modelRecord *models.AggregatorRecord) *api.Aggre StateHash: api.StateHash(modelRecord.Authenticator.StateHash.String()), }, AggregateRequestCount: modelRecord.AggregateRequestCount, - BlockNumber: modelToAPIBigInt(modelRecord.BlockNumber), - LeafIndex: modelToAPIBigInt(modelRecord.LeafIndex), + BlockNumber: modelRecord.BlockNumber, + LeafIndex: modelRecord.LeafIndex, CreatedAt: modelRecord.CreatedAt, FinalizedAt: modelRecord.FinalizedAt, } @@ -65,15 +90,17 @@ func modelToAPIAggregatorRecord(modelRecord *models.AggregatorRecord) *api.Aggre func modelToAPIBlock(modelBlock *models.Block) *api.Block { return &api.Block{ - Index: modelToAPIBigInt(modelBlock.Index), - ChainID: modelBlock.ChainID, - Version: modelBlock.Version, - ForkID: modelBlock.ForkID, - RootHash: modelBlock.RootHash, - PreviousBlockHash: modelBlock.PreviousBlockHash, - NoDeletionProofHash: modelBlock.NoDeletionProofHash, - CreatedAt: modelBlock.CreatedAt, - UnicityCertificate: modelBlock.UnicityCertificate, + Index: modelBlock.Index, + ChainID: modelBlock.ChainID, + ShardID: modelBlock.ShardID, + Version: modelBlock.Version, + ForkID: modelBlock.ForkID, + RootHash: modelBlock.RootHash, + PreviousBlockHash: modelBlock.PreviousBlockHash, + NoDeletionProofHash: modelBlock.NoDeletionProofHash, + CreatedAt: modelBlock.CreatedAt, + UnicityCertificate: modelBlock.UnicityCertificate, + ParentMerkleTreePath: modelBlock.ParentMerkleTreePath, } } @@ -82,12 +109,13 @@ func modelToAPIHealthStatus(modelHealth *models.HealthStatus) *api.HealthStatus Status: modelHealth.Status, Role: modelHealth.Role, ServerID: modelHealth.ServerID, + Sharding: modelHealth.Sharding, Details: modelHealth.Details, } } // NewAggregatorService creates a new aggregator service -func NewAggregatorService(cfg *config.Config, logger *logger.Logger, roundManager *round.RoundManager, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, leaderSelector LeaderSelector) *AggregatorService { +func NewAggregatorService(cfg *config.Config, logger *logger.Logger, roundManager round.Manager, commitmentQueue interfaces.CommitmentQueue, storage interfaces.Storage, leaderSelector LeaderSelector) *AggregatorService { return &AggregatorService{ config: cfg, logger: logger, @@ -95,7 +123,7 @@ func NewAggregatorService(cfg *config.Config, logger 
*logger.Logger, roundManage storage: storage, roundManager: roundManager, leaderSelector: leaderSelector, - commitmentValidator: signing.NewCommitmentValidator(), + commitmentValidator: signing.NewCommitmentValidator(cfg.Sharding), } } @@ -186,17 +214,19 @@ func (as *AggregatorService) SubmitCommitment(ctx context.Context, req *api.Subm // GetInclusionProof retrieves inclusion proof for a commitment func (as *AggregatorService) GetInclusionProof(ctx context.Context, req *api.GetInclusionProofRequest) (*api.GetInclusionProofResponse, error) { - // First check if commitment exists in aggregator records (finalized) - record, err := as.storage.AggregatorRecordStorage().GetByRequestID(ctx, req.RequestID) - if err != nil { - return nil, fmt.Errorf("failed to get aggregator record: %w", err) + // verify that the request ID matches the shard ID of this aggregator + if err := as.commitmentValidator.ValidateShardID(req.RequestID); err != nil { + return nil, fmt.Errorf("request ID validation failed: %w", err) } path, err := req.RequestID.GetPath() if err != nil { return nil, fmt.Errorf("failed to get path for request ID %s: %w", req.RequestID, err) } - merkleTreePath := as.roundManager.GetSMT().GetPath(path) + merkleTreePath, err := as.roundManager.GetSMT().GetPath(path) + if err != nil { + return nil, fmt.Errorf("failed to get inclusion proof for request ID %s: %w", req.RequestID, err) + } // Find the latest block that matches the current SMT root hash rootHash, err := api.NewHexBytesFromString(merkleTreePath.Root) @@ -211,6 +241,19 @@ func (as *AggregatorService) GetInclusionProof(ctx context.Context, req *api.Get return nil, fmt.Errorf("no block found with root hash %s", rootHash) } + // Join parent and child SMT paths if sharding mode is enabled + if as.config.Sharding.Mode == config.ShardingModeChild { + merkleTreePath, err = smt.JoinPaths(merkleTreePath, block.ParentMerkleTreePath) + if err != nil { + return nil, fmt.Errorf("failed to join parent and child aggregator paths: %w", err) + } + } + + // First check if commitment exists in aggregator records (finalized) + record, err := as.storage.AggregatorRecordStorage().GetByRequestID(ctx, req.RequestID) + if err != nil { + return nil, fmt.Errorf("failed to get aggregator record: %w", err) + } if record == nil { // Non-inclusion proof return &api.GetInclusionProofResponse{ @@ -222,17 +265,9 @@ func (as *AggregatorService) GetInclusionProof(ctx context.Context, req *api.Get }, }, nil } - - authenticator := &api.Authenticator{ - Algorithm: record.Authenticator.Algorithm, - PublicKey: record.Authenticator.PublicKey, - Signature: record.Authenticator.Signature, - StateHash: record.Authenticator.StateHash, - } - return &api.GetInclusionProofResponse{ InclusionProof: &api.InclusionProof{ - Authenticator: authenticator, + Authenticator: record.Authenticator.ToAPI(), MerkleTreePath: merkleTreePath, TransactionHash: &record.TransactionHash, UnicityCertificate: block.UnicityCertificate, @@ -259,7 +294,7 @@ func (as *AggregatorService) GetBlockHeight(ctx context.Context) (*api.GetBlockH } return &api.GetBlockHeightResponse{ - BlockNumber: modelToAPIBigInt(latestBlockNumber), + BlockNumber: latestBlockNumber, }, nil } @@ -317,7 +352,7 @@ func (as *AggregatorService) GetBlock(ctx context.Context, req *api.GetBlockRequ // GetBlockCommitments retrieves all commitments in a block func (as *AggregatorService) GetBlockCommitments(ctx context.Context, req *api.GetBlockCommitmentsRequest) (*api.GetBlockCommitmentsResponse, error) { - records, err := 
as.storage.AggregatorRecordStorage().GetByBlockNumber(ctx, apiToModelBigInt(req.BlockNumber)) + records, err := as.storage.AggregatorRecordStorage().GetByBlockNumber(ctx, req.BlockNumber) if err != nil { return nil, fmt.Errorf("failed to get block commitments: %w", err) } @@ -354,7 +389,12 @@ func (as *AggregatorService) GetHealthStatus(ctx context.Context) (*api.HealthSt role = "standalone" } - status := models.NewHealthStatus(role, as.config.HA.ServerID) + sharding := api.Sharding{ + Mode: as.config.Sharding.Mode.String(), + ShardIDLen: as.config.Sharding.ShardIDLength, + ShardID: as.config.Sharding.Child.ShardID, + } + status := models.NewHealthStatus(role, as.config.HA.ServerID, sharding) // Add database connectivity check if err := as.storage.Ping(ctx); err != nil { @@ -389,3 +429,13 @@ func (as *AggregatorService) GetHealthStatus(ctx context.Context) (*api.HealthSt return modelToAPIHealthStatus(status), nil } + +// SubmitShardRoot - not supported in standalone mode +func (as *AggregatorService) SubmitShardRoot(ctx context.Context, req *api.SubmitShardRootRequest) (*api.SubmitShardRootResponse, error) { + return nil, fmt.Errorf("submit_shard_root is not supported in standalone mode") +} + +// GetShardProof - not supported in standalone mode +func (as *AggregatorService) GetShardProof(ctx context.Context, req *api.GetShardProofRequest) (*api.GetShardProofResponse, error) { + return nil, fmt.Errorf("get_shard_proof is not supported in standalone mode") +} diff --git a/internal/service/service_test.go b/internal/service/service_test.go index 9304a5b..a3f56bf 100644 --- a/internal/service/service_test.go +++ b/internal/service/service_test.go @@ -27,7 +27,9 @@ import ( "github.com/unicitynetwork/aggregator-go/internal/ha/state" "github.com/unicitynetwork/aggregator-go/internal/logger" "github.com/unicitynetwork/aggregator-go/internal/round" + "github.com/unicitynetwork/aggregator-go/internal/sharding" "github.com/unicitynetwork/aggregator-go/internal/signing" + "github.com/unicitynetwork/aggregator-go/internal/smt" mongodbStorage "github.com/unicitynetwork/aggregator-go/internal/storage/mongodb" "github.com/unicitynetwork/aggregator-go/pkg/api" "github.com/unicitynetwork/aggregator-go/pkg/jsonrpc" @@ -120,7 +122,8 @@ func setupMongoDBAndAggregator(t *testing.T, ctx context.Context) (string, func( commitmentQueue := mongoStorage.CommitmentQueue() // Initialize round manager - roundManager, err := round.NewRoundManager(ctx, cfg, log, commitmentQueue, mongoStorage, state.NewSyncStateTracker()) + rootAggregatorClient := sharding.NewRootAggregatorClientStub() + roundManager, err := round.NewRoundManager(ctx, cfg, log, smt.NewSparseMerkleTree(api.SHA256, 16+256), commitmentQueue, mongoStorage, rootAggregatorClient, state.NewSyncStateTracker()) require.NoError(t, err) // Start the round manager (restores SMT) diff --git a/internal/sharding/root_aggregator_client.go b/internal/sharding/root_aggregator_client.go new file mode 100644 index 0000000..93c7570 --- /dev/null +++ b/internal/sharding/root_aggregator_client.go @@ -0,0 +1,112 @@ +package sharding + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync/atomic" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +type ( + RootAggregatorClient struct { + rpcURL string + httpClient *http.Client + requestIDC *atomic.Int64 + } + + jsonRpcRequest struct { + JsonRpc string `json:"jsonrpc"` + Method string `json:"method"` + Params interface{} `json:"params"` + ID int64 `json:"id"` + } + + jsonRpcResponse[T any] 
struct { + JsonRpc string `json:"jsonrpc"` + Result *T `json:"result,omitempty"` + Error *jsonRpcError `json:"error,omitempty"` + ID int64 `json:"id"` + } + + jsonRpcError struct { + Code int `json:"code"` + Message string `json:"message"` + Data string `json:"data,omitempty"` + } +) + +func NewRootAggregatorClient(rpcURL string) *RootAggregatorClient { + return &RootAggregatorClient{ + rpcURL: rpcURL, + httpClient: &http.Client{}, + requestIDC: new(atomic.Int64), + } +} + +func (c *RootAggregatorClient) SubmitShardRoot(ctx context.Context, req *api.SubmitShardRootRequest) error { + result, err := doRpcRequest[api.SubmitShardRootResponse](ctx, c, "submit_shard_root", req) + if err != nil { + return fmt.Errorf("failed to submit shard root: %w", err) + } + if result.Status != api.ShardRootStatusSuccess { + return fmt.Errorf("unexpected status: %s", result.Status) + } + return nil +} + +func (c *RootAggregatorClient) GetShardProof(ctx context.Context, req *api.GetShardProofRequest) (*api.RootShardInclusionProof, error) { + response, err := doRpcRequest[api.RootShardInclusionProof](ctx, c, "get_shard_proof", req) + if err != nil { + return nil, fmt.Errorf("failed to fetch shard proof: %w", err) + } + return response, nil +} + +func doRpcRequest[T any](ctx context.Context, c *RootAggregatorClient, method string, params interface{}) (*T, error) { + rpcReq := jsonRpcRequest{ + JsonRpc: "2.0", + Method: method, + Params: params, + ID: c.requestIDC.Add(1), + } + + reqBody, err := json.Marshal(rpcReq) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.rpcURL, bytes.NewReader(reqBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("http error: %s, body: %s", resp.Status, body) + } + + var rpcResp jsonRpcResponse[T] + if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if rpcResp.Error != nil { + return nil, fmt.Errorf("rpc error: %s", rpcResp.Error.Message) + } + if rpcResp.Result == nil { + return nil, fmt.Errorf("rpc error: result is nil") + } + return rpcResp.Result, nil +} diff --git a/internal/sharding/root_aggregator_client_stub.go b/internal/sharding/root_aggregator_client_stub.go new file mode 100644 index 0000000..f933b46 --- /dev/null +++ b/internal/sharding/root_aggregator_client_stub.go @@ -0,0 +1,74 @@ +package sharding + +import ( + "context" + "sync" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +type RootAggregatorClientStub struct { + mu sync.Mutex + submissionCount int + returnedProofCount int + submissions map[int]*api.SubmitShardRootRequest // shardID => last request + submittedRootHash api.HexBytes + submissionError error +} + +func NewRootAggregatorClientStub() *RootAggregatorClientStub { + return &RootAggregatorClientStub{ + submissions: make(map[int]*api.SubmitShardRootRequest), + } +} + +func (m *RootAggregatorClientStub) SubmitShardRoot(ctx context.Context, request *api.SubmitShardRootRequest) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.submissionError != nil { + return m.submissionError + } + m.submissionCount++ + 
m.submissions[request.ShardID] = request + m.submittedRootHash = request.RootHash + return nil +} + +func (m *RootAggregatorClientStub) GetShardProof(ctx context.Context, request *api.GetShardProofRequest) (*api.RootShardInclusionProof, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.submissions[request.ShardID] != nil { + m.returnedProofCount++ + submittedRootHash := m.submittedRootHash.String() + return &api.RootShardInclusionProof{ + UnicityCertificate: api.HexBytes("1234"), + MerkleTreePath: &api.MerkleTreePath{ + Steps: []api.MerkleTreeStep{{Data: &submittedRootHash}}, + }, + }, nil + } + return nil, nil +} + +func (m *RootAggregatorClientStub) SubmissionCount() int { + m.mu.Lock() + defer m.mu.Unlock() + + return m.submissionCount +} + +func (m *RootAggregatorClientStub) ProofCount() int { + m.mu.Lock() + defer m.mu.Unlock() + + return m.returnedProofCount +} + +func (m *RootAggregatorClientStub) SetSubmissionError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.submissionError = err +} diff --git a/internal/sharding/root_aggregator_client_test.go b/internal/sharding/root_aggregator_client_test.go new file mode 100644 index 0000000..49b88e3 --- /dev/null +++ b/internal/sharding/root_aggregator_client_test.go @@ -0,0 +1,127 @@ +package sharding + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +func TestRootAggregatorClient_SubmitShardRoot(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req jsonRpcRequest + err := json.NewDecoder(r.Body).Decode(&req) + require.NoError(t, err) + + require.Equal(t, "2.0", req.JsonRpc) + require.Equal(t, int64(1), req.ID) + require.Equal(t, "submit_shard_root", req.Method) + + params := req.Params.(map[string]interface{}) + require.Equal(t, float64(4), params["shardId"]) + require.Equal(t, "010203", params["rootHash"]) + + resp := jsonRpcResponse[api.SubmitShardRootResponse]{ + JsonRpc: "2.0", + Result: &api.SubmitShardRootResponse{Status: api.ShardRootStatusSuccess}, + ID: req.ID, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewRootAggregatorClient(server.URL) + rootHash, err := api.NewHexBytesFromString("010203") + require.NoError(t, err) + req := &api.SubmitShardRootRequest{ + ShardID: 4, + RootHash: rootHash, + } + err = client.SubmitShardRoot(context.Background(), req) + require.NoError(t, err) +} + +func TestRootAggregatorClient_GetShardProof(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req jsonRpcRequest + err := json.NewDecoder(r.Body).Decode(&req) + require.NoError(t, err) + + require.Equal(t, "2.0", req.JsonRpc) + require.Equal(t, "get_shard_proof", req.Method) + + params := req.Params.(map[string]interface{}) + require.Equal(t, float64(4), params["shardId"]) + + proof := &api.RootShardInclusionProof{ + MerkleTreePath: &api.MerkleTreePath{Root: "0x1234"}, + UnicityCertificate: api.HexBytes("0xabcdef"), + } + + resp := jsonRpcResponse[api.RootShardInclusionProof]{ + JsonRpc: "2.0", + Result: proof, + ID: req.ID, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewRootAggregatorClient(server.URL) + req := &api.GetShardProofRequest{ + ShardID: 4, + } + proof, err := client.GetShardProof(context.Background(), req) + + require.NoError(t, err) + require.NotNil(t, proof) + require.Equal(t, 
"0x1234", proof.MerkleTreePath.Root) + require.Equal(t, api.HexBytes("0xabcdef"), proof.UnicityCertificate) +} + +func TestRootAggregatorClient_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer server.Close() + + client := NewRootAggregatorClient(server.URL) + _, err := doRpcRequest[api.RootShardInclusionProof](context.Background(), client, "get_shard_proof", &api.GetShardProofRequest{ShardID: 1}) + require.ErrorContains(t, err, "http error") +} + +func TestRootAggregatorClient_RPCErrorResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := jsonRpcResponse[api.SubmitShardRootResponse]{ + JsonRpc: "2.0", + Error: &jsonRpcError{ + Code: -32000, + Message: "invalid shard", + Data: "shard not found", + }, + ID: 1, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewRootAggregatorClient(server.URL) + err := client.SubmitShardRoot(context.Background(), &api.SubmitShardRootRequest{ShardID: 99}) + require.ErrorContains(t, err, "rpc error") +} + +func TestRootAggregatorClient_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("{not-json")) + })) + defer server.Close() + + client := NewRootAggregatorClient(server.URL) + _, err := client.GetShardProof(context.Background(), &api.GetShardProofRequest{ShardID: 5}) + require.ErrorContains(t, err, "decode response") +} diff --git a/internal/signing/commitment_validator.go b/internal/signing/commitment_validator.go index 5f4cdb0..db30e05 100644 --- a/internal/signing/commitment_validator.go +++ b/internal/signing/commitment_validator.go @@ -2,8 +2,11 @@ package signing import ( "encoding/hex" + "errors" "fmt" + "math/big" + "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/pkg/api" ) @@ -20,6 +23,7 @@ const ( ValidationStatusInvalidStateHashFormat ValidationStatusInvalidTransactionHashFormat ValidationStatusUnsupportedAlgorithm + ValidationStatusShardMismatch ) func (s ValidationStatus) String() string { @@ -40,6 +44,8 @@ func (s ValidationStatus) String() string { return "INVALID_TRANSACTION_HASH_FORMAT" case ValidationStatusUnsupportedAlgorithm: return "UNSUPPORTED_ALGORITHM" + case ValidationStatusShardMismatch: + return "INVALID_SHARD" default: return "UNKNOWN" } @@ -54,12 +60,14 @@ type ValidationResult struct { // CommitmentValidator validates commitment signatures and request IDs type CommitmentValidator struct { signingService *SigningService + shardConfig config.ShardingConfig } // NewCommitmentValidator creates a new commitment validator -func NewCommitmentValidator() *CommitmentValidator { +func NewCommitmentValidator(shardConfig config.ShardingConfig) *CommitmentValidator { return &CommitmentValidator{ signingService: NewSigningService(), + shardConfig: shardConfig, } } @@ -74,6 +82,14 @@ func (v *CommitmentValidator) ValidateCommitment(commitment *models.Commitment) } } + // 1.1 Verify correct shard + if err := v.ValidateShardID(commitment.RequestID); err != nil { + return ValidationResult{ + Status: ValidationStatusShardMismatch, + Error: fmt.Errorf("invalid shard: %w", err), + } + } + // 2. 
Parse and validate public key // HexBytes already contains the binary data, no need to decode publicKeyBytes := []byte(commitment.Authenticator.PublicKey) @@ -201,3 +217,49 @@ func (v *CommitmentValidator) ValidateCommitment(commitment *models.Commitment) Error: nil, } } + +// ValidateShardID verifies if the request id belongs to the configured shard +func (v *CommitmentValidator) ValidateShardID(requestID api.RequestID) error { + if !v.shardConfig.Mode.IsChild() { + return nil + } + ok, err := verifyShardID(requestID.String(), v.shardConfig.Child.ShardID) + if err != nil { + return fmt.Errorf("error verifying shard id: %w", err) + } + if !ok { + return errors.New("request ID shard part does not match the current shard identifier") + } + return nil +} + +// verifyShardID Checks if commitmentID's least significant bits match the shard bitmask. +func verifyShardID(commitmentID string, shardBitmask int) (bool, error) { + // convert to big.Ints + bytes, err := hex.DecodeString(commitmentID) + if err != nil { + return false, fmt.Errorf("failed to decode commitment ID: %w", err) + } + commitmentIdBigInt := new(big.Int).SetBytes(bytes) + shardBitmaskBigInt := new(big.Int).SetInt64(int64(shardBitmask)) + + // find position of MSB e.g. + // 0b111 -> BitLen=3 -> 3-1=2 + msbPos := shardBitmaskBigInt.BitLen() - 1 + + // build a mask covering bits below MSB e.g. + // 1<<2=0b100; 0b100-1=0b11; compareMask=0b11 + compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) + + // remove MSB from shardBitmask to get expected value e.g. + // 0b111 & 0b11 = 0b11 + expected := new(big.Int).And(shardBitmaskBigInt, compareMask) + + // extract low bits from commitment e.g. + // commitment=0b11111111 & 0b11 = 0b11 + commitmentLowBits := new(big.Int).And(commitmentIdBigInt, compareMask) + + // return true if the commitment low bits match bitmask bits e.g. 
+ // 0b11 == 0b11 + return commitmentLowBits.Cmp(expected) == 0, nil +} diff --git a/internal/signing/commitment_validator_test.go b/internal/signing/commitment_validator_test.go index 35e7e8b..927258f 100644 --- a/internal/signing/commitment_validator_test.go +++ b/internal/signing/commitment_validator_test.go @@ -8,21 +8,13 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/stretchr/testify/require" + "github.com/unicitynetwork/aggregator-go/internal/config" "github.com/unicitynetwork/aggregator-go/internal/models" "github.com/unicitynetwork/aggregator-go/pkg/api" ) -// Helper function to convert hex string to HexBytes for tests -func hexStringToHexBytes(hexStr string) api.HexBytes { - data, err := hex.DecodeString(hexStr) - if err != nil { - panic(err) - } - return data -} - -func TestCommitmentValidator_ValidateCommitment_Success(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_Success(t *testing.T) { + validator := newDefaultCommitmentValidator() // Generate a test key pair for signing privateKey, err := btcec.NewPrivateKey() @@ -69,8 +61,8 @@ func TestCommitmentValidator_ValidateCommitment_Success(t *testing.T) { require.NoError(t, result.Error, "Expected no error") } -func TestCommitmentValidator_ValidateCommitment_UnsupportedAlgorithm(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_UnsupportedAlgorithm(t *testing.T) { + validator := newDefaultCommitmentValidator() commitment := &models.Commitment{ RequestID: api.RequestID("00000123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"), @@ -89,8 +81,8 @@ func TestCommitmentValidator_ValidateCommitment_UnsupportedAlgorithm(t *testing. require.Error(t, result.Error, "Expected error for unsupported algorithm") } -func TestCommitmentValidator_ValidateCommitment_InvalidPublicKeyFormat(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_InvalidPublicKeyFormat(t *testing.T) { + validator := newDefaultCommitmentValidator() commitment := &models.Commitment{ RequestID: api.RequestID("00000123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"), @@ -111,8 +103,8 @@ func TestCommitmentValidator_ValidateCommitment_InvalidPublicKeyFormat(t *testin } } -func TestCommitmentValidator_ValidateCommitment_InvalidStateHashFormat(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_InvalidStateHashFormat(t *testing.T) { + validator := newDefaultCommitmentValidator() // Create valid public key privateKey, _ := btcec.NewPrivateKey() @@ -137,8 +129,8 @@ func TestCommitmentValidator_ValidateCommitment_InvalidStateHashFormat(t *testin } } -func TestCommitmentValidator_ValidateCommitment_RequestIDMismatch(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_RequestIDMismatch(t *testing.T) { + validator := newDefaultCommitmentValidator() // Create valid public key and state hash privateKey, _ := btcec.NewPrivateKey() @@ -167,8 +159,87 @@ func TestCommitmentValidator_ValidateCommitment_RequestIDMismatch(t *testing.T) } } -func TestCommitmentValidator_ValidateCommitment_InvalidSignatureFormat(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_ShardID(t *testing.T) { + tests := []struct { + commitmentID string + shardBitmask int + match bool + }{ + // === TWO SHARD CONFIG === + // shard1=bitmask 0b10 + // shard2=bitmask 0b11 + + // commitment ending with 0b00000000 belongs to shard1 + {"00000000000000000000000000000000000000000000000000000000000000000000", 0b10, true}, + 
{"00000000000000000000000000000000000000000000000000000000000000000000", 0b11, false}, + + // commitment ending with 0b00000001 belongs to shard2 + {"00000000000000000000000000000000000000000000000000000000000000000001", 0b10, false}, + {"00000000000000000000000000000000000000000000000000000000000000000001", 0b11, true}, + + // commitment ending with 0b00000010 belongs to shard1 + {"00000000000000000000000000000000000000000000000000000000000000000002", 0b10, true}, + {"00000000000000000000000000000000000000000000000000000000000000000002", 0b11, false}, + + // commitment ending with 0b00000011 belongs to shard2 + {"00000000000000000000000000000000000000000000000000000000000000000003", 0b10, false}, + {"00000000000000000000000000000000000000000000000000000000000000000003", 0b11, true}, + + // commitment ending with 0b11111111 belongs to shard2 + {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b10, false}, + {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b11, true}, + + // === END TWO SHARD CONFIG === + + // === FOUR SHARD CONFIG === + // shard1=0b100 + // shard2=0b110 + // shard3=0b101 + // shard4=0b111 + + // commitment ending with 0b00000000 belongs to shard1 + {"00000000000000000000000000000000000000000000000000000000000000000000", 0b111, false}, + {"00000000000000000000000000000000000000000000000000000000000000000000", 0b101, false}, + {"00000000000000000000000000000000000000000000000000000000000000000000", 0b110, false}, + {"00000000000000000000000000000000000000000000000000000000000000000000", 0b100, true}, + + // commitment ending with 0b00000010 belongs to shard2 + {"00000000000000000000000000000000000000000000000000000000000000000002", 0b111, false}, + {"00000000000000000000000000000000000000000000000000000000000000000002", 0b100, false}, + {"00000000000000000000000000000000000000000000000000000000000000000002", 0b101, false}, + {"00000000000000000000000000000000000000000000000000000000000000000002", 0b110, true}, + + // commitment ending with 0b00000001 belongs to shard3 + {"00000000000000000000000000000000000000000000000000000000000000000001", 0b111, false}, + {"00000000000000000000000000000000000000000000000000000000000000000001", 0b101, true}, + {"00000000000000000000000000000000000000000000000000000000000000000001", 0b110, false}, + {"00000000000000000000000000000000000000000000000000000000000000000001", 0b100, false}, + + // commitment ending with 0b00000011 belongs to shard4 + {"00000000000000000000000000000000000000000000000000000000000000000003", 0b111, true}, + {"00000000000000000000000000000000000000000000000000000000000000000003", 0b101, false}, + {"00000000000000000000000000000000000000000000000000000000000000000003", 0b110, false}, + {"00000000000000000000000000000000000000000000000000000000000000000003", 0b100, false}, + + // commitment ending with 0b11111111 belongs to shard4 + {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b111, true}, + {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b101, false}, + {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b110, false}, + {"000000000000000000000000000000000000000000000000000000000000000000FF", 0b100, false}, + + // === END FOUR SHARD CONFIG === + } + for _, tc := range tests { + match, err := verifyShardID(tc.commitmentID, tc.shardBitmask) + require.NoError(t, err) + if match != tc.match { + t.Errorf("commitmentID=%s shardBitmask=%b expected %v got %v", tc.commitmentID, tc.shardBitmask, 
tc.match, match) + } + } +} + +func TestValidator_InvalidSignatureFormat(t *testing.T) { + validator := newDefaultCommitmentValidator() // Create valid data except signature privateKey, _ := btcec.NewPrivateKey() @@ -197,8 +268,8 @@ func TestCommitmentValidator_ValidateCommitment_InvalidSignatureFormat(t *testin } } -func TestCommitmentValidator_ValidateCommitment_InvalidTransactionHashFormat(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_InvalidTransactionHashFormat(t *testing.T) { + validator := newDefaultCommitmentValidator() // Create valid data except transaction hash privateKey, _ := btcec.NewPrivateKey() @@ -227,8 +298,8 @@ func TestCommitmentValidator_ValidateCommitment_InvalidTransactionHashFormat(t * } } -func TestCommitmentValidator_ValidateCommitment_SignatureVerificationFailed(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_SignatureVerificationFailed(t *testing.T) { + validator := newDefaultCommitmentValidator() // Generate a test key pair privateKey, _ := btcec.NewPrivateKey() @@ -265,9 +336,9 @@ func TestCommitmentValidator_ValidateCommitment_SignatureVerificationFailed(t *t } } -func TestCommitmentValidator_ValidateCommitment_RealSecp256k1Data(t *testing.T) { +func TestValidator_RealSecp256k1Data(t *testing.T) { // Test with real secp256k1 cryptographic operations to ensure compatibility - validator := NewCommitmentValidator() + validator := newDefaultCommitmentValidator() // Use known test vectors privateKeyHex := "c28a9f80738afe1441ba9a68e72033f4c8d52b4f5d6d8f1e6a6b1c4a7b8e9c1f" @@ -323,7 +394,7 @@ func TestCommitmentValidator_ValidateCommitment_RealSecp256k1Data(t *testing.T) } } -func TestCommitmentValidator_ValidationStatusString(t *testing.T) { +func TestValidator_ValidationStatusString(t *testing.T) { tests := []struct { status ValidationStatus expected string @@ -336,6 +407,7 @@ func TestCommitmentValidator_ValidationStatusString(t *testing.T) { {ValidationStatusInvalidStateHashFormat, "INVALID_STATE_HASH_FORMAT"}, {ValidationStatusInvalidTransactionHashFormat, "INVALID_TRANSACTION_HASH_FORMAT"}, {ValidationStatusUnsupportedAlgorithm, "UNSUPPORTED_ALGORITHM"}, + {ValidationStatusShardMismatch, "INVALID_SHARD"}, {ValidationStatus(999), "UNKNOWN"}, // Test unknown status } @@ -347,8 +419,8 @@ func TestCommitmentValidator_ValidationStatusString(t *testing.T) { } } -func TestCommitmentValidator_ValidateCommitment_vsTS(t *testing.T) { - validator := NewCommitmentValidator() +func TestValidator_vsTS(t *testing.T) { + validator := newDefaultCommitmentValidator() requestJson := `{ "authenticator": { @@ -385,3 +457,8 @@ func TestCommitmentValidator_ValidateCommitment_vsTS(t *testing.T) { t.Errorf("Expected no error with real secp256k1 data, got: %v", result.Error) } } + +func newDefaultCommitmentValidator() *CommitmentValidator { + // use standalone sharding mode to skip shard id validation + return &CommitmentValidator{shardConfig: config.ShardingConfig{Mode: config.ShardingModeStandalone}} +} diff --git a/internal/smt/smt.go b/internal/smt/smt.go index dabc915..1228ce1 100644 --- a/internal/smt/smt.go +++ b/internal/smt/smt.go @@ -13,12 +13,15 @@ import ( var ( ErrDuplicateLeaf = errors.New("smt: duplicate leaf") ErrLeafModification = errors.New("smt: attempt to modify an existing leaf") + ErrKeyLength = errors.New("smt: invalid key length") + ErrWrongShard = errors.New("smt: key does not belong in this shard") ) type ( - // SparseMerkleTree implements a sparse merkle tree compatible with Unicity SDK + // 
SparseMerkleTree implements a sparse Merkle tree compatible with Unicity SDK SparseMerkleTree struct { - keyLength int // bit length of the keys in the tree + parentMode bool // true if this tree operates in "parent mode" + keyLength int // bit length of the keys in the tree algorithm api.HashAlgorithm root *NodeBranch isSnapshot bool // true if this is a snapshot, false if original tree @@ -31,24 +34,77 @@ type ( } ) -// NewSparseMerkleTree creates a new sparse merkle tree +// NewSparseMerkleTree creates a new sparse Merkle tree for a monolithic aggregator func NewSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int) *SparseMerkleTree { if keyLength <= 0 { panic("SMT key length must be positive") } return &SparseMerkleTree{ + parentMode: false, keyLength: keyLength, algorithm: algorithm, - root: newRootNode(nil, nil), + root: newRootBranch(big.NewInt(1), nil, nil), isSnapshot: false, original: nil, } } +// NewChildSparseMerkleTree creates a new sparse Merkle tree for a child aggregator in sharded setup +func NewChildSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int, shardID api.ShardID) *SparseMerkleTree { + if keyLength <= 0 { + panic("SMT key length must be positive") + } + if shardID <= 1 { + panic("Shard ID must be positive and have at least 2 bits") + } + path := big.NewInt(int64(shardID)) + if path.BitLen() > keyLength { + panic("Shard ID must be shorter than SMT key length") + } + return &SparseMerkleTree{ + parentMode: false, + keyLength: keyLength, + algorithm: algorithm, + root: newRootBranch(path, nil, nil), + isSnapshot: false, + original: nil, + } +} + +// NewParentSparseMerkleTree creates a new sparse Merkle tree for the parent aggregator in sharded setup +func NewParentSparseMerkleTree(algorithm api.HashAlgorithm, keyLength int) *SparseMerkleTree { + tree := NewSparseMerkleTree(algorithm, keyLength) + tree.parentMode = true + + // Populate all leaves with null hashes + // To allow the child aggregators to compute the correct root hashes + // for their respective leaves in the parent aggregator's tree, the + // parent tree is fully populated (thus not a sparse tree at all) + // It is expected that these nulls will be replaced from the stored + // state at once after the tree is initially constructed, but still + // better to ensure all the leaves exist; otherwise the hash values + // of siblings of the missing nodes would not match the structure of + // the tree and the corresponding inclusion proofs would fail to verify + tree.root.Left = populate(0b10, keyLength) + tree.root.Right = populate(0b11, keyLength) + + return tree +} + +func populate(path, levels int) branch { + if levels == 1 { + return newChildLeafBranch(big.NewInt(int64(path)), nil) + } + left := populate(0b10, levels-1) + right := populate(0b11, levels-1) + return newNodeBranch(big.NewInt(int64(path)), left, right) +} + // CreateSnapshot creates a snapshot of the current SMT state // The snapshot shares nodes with the original tree (copy-on-write) func (smt *SparseMerkleTree) CreateSnapshot() *SmtSnapshot { snapshot := &SparseMerkleTree{ + parentMode: smt.parentMode, keyLength: smt.keyLength, algorithm: smt.algorithm, root: smt.root, // Share the root initially @@ -66,11 +122,15 @@ func (snapshot *SmtSnapshot) Commit() { } // AddLeaf adds a single leaf to the snapshot +// In regular and child mode, only new leaves can be added and any attempt +// to overwrite an existing leaf is an error; in parent mode, updates are allowed func (snapshot *SmtSnapshot) AddLeaf(path *big.Int, value []byte) 
error {
+ return snapshot.SparseMerkleTree.AddLeaf(path, value)
+ }
+
+// AddLeaves adds multiple leaves to the snapshot
+// In regular and child mode, only new leaves can be added and any attempt
+// to overwrite an existing leaf is an error; in parent mode, updates are allowed
 func (snapshot *SmtSnapshot) AddLeaves(leaves []*Leaf) error {
 return snapshot.SparseMerkleTree.AddLeaves(leaves)
 }
@@ -93,7 +153,7 @@ func (smt *SparseMerkleTree) CanModify() bool {
 // copyOnWriteRoot creates a new root if this snapshot is sharing it with the original
 func (smt *SparseMerkleTree) copyOnWriteRoot() *NodeBranch {
 if smt.original != nil && smt.root == smt.original.root {
- return newRootNode(smt.root.Left, smt.root.Right)
+ return newRootBranch(smt.root.Path, smt.root.Left, smt.root.Right)
 }
 return smt.root
 }
@@ -106,7 +166,10 @@ func (smt *SparseMerkleTree) cloneBranch(branch branch) branch {
 
 if branch.isLeaf() {
 leafBranch := branch.(*LeafBranch)
- return newLeafBranch(leafBranch.Path, leafBranch.Value)
+ cloned := newLeafBranch(leafBranch.Path, leafBranch.Value)
+ // Preserve the isChild flag for parent mode trees
+ cloned.isChild = leafBranch.isChild
+ return cloned
 } else {
 nodeBranch := branch.(*NodeBranch)
 return newNodeBranch(nodeBranch.Path, nodeBranch.Left, nodeBranch.Right)
@@ -122,36 +185,58 @@ type branch interface {
 
 // LeafBranch represents a leaf node
 type LeafBranch struct {
- Path *big.Int
- Value []byte
- hash *api.DataHash
+ Path *big.Int
+ Value []byte
+ hash *api.DataHash
+ isChild bool // true if this is the root hash from a child aggregator
 }
 
 // NodeBranch represents an internal node
 type NodeBranch struct {
- Path *big.Int
- Left branch
- Right branch
- hash *api.DataHash
+ Path *big.Int
+ Left branch
+ Right branch
+ hash *api.DataHash
+ isRoot bool // true if this is the root node
 }
 
-// NewLeafBranch creates a leaf branch
+// NewLeafBranch creates a regular leaf branch
 func newLeafBranch(path *big.Int, value []byte) *LeafBranch {
 return &LeafBranch{
- Path: new(big.Int).Set(path),
- Value: append([]byte(nil), value...),
+ Path: new(big.Int).Set(path),
+ Value: append([]byte(nil), value...),
+ isChild: false,
 // Hash will be computed on demand
 }
 }
 
+// NewChildLeafBranch creates a parent tree leaf containing the root hash of a child tree
+func newChildLeafBranch(path *big.Int, value []byte) *LeafBranch {
+ if value != nil {
+ value = append([]byte(nil), value...)
+ }
+ return &LeafBranch{
+ Path: new(big.Int).Set(path),
+ Value: value,
+ isChild: true,
+ // Hash will be set on demand
+ }
+}
+
 func (l *LeafBranch) calculateHash(hasher *api.DataHasher) *api.DataHash {
 if l.hash != nil {
 return l.hash
 }
- pathBytes := api.BigintEncode(l.Path)
- l.hash = hasher.Reset().AddData(api.CborArray(2)).
- AddCborBytes(pathBytes).AddCborBytes(l.Value).GetHash()
+ if l.isChild {
+ if l.Value != nil {
+ l.hash = api.NewDataHash(hasher.GetAlgorithm(), l.Value)
+ }
+ } else {
+ pathBytes := api.BigintEncode(l.Path)
+ l.hash = hasher.Reset().AddData(api.CborArray(2)).
+ AddCborBytes(pathBytes).AddCborBytes(l.Value).GetHash() + } return l.hash } @@ -163,17 +248,24 @@ func (l *LeafBranch) isLeaf() bool { return true } -// NewRootNode creates a new root node -func newRootNode(left, right branch) *NodeBranch { - return newNodeBranch(big.NewInt(1), left, right) +// NewNodeBranch creates a regular node branch +func newNodeBranch(path *big.Int, left, right branch) *NodeBranch { + return &NodeBranch{ + Path: new(big.Int).Set(path), + Left: left, + Right: right, + isRoot: false, + // Hash will be computed on demand + } } -// NewNodeBranch creates a node branch -func newNodeBranch(path *big.Int, left, right branch) *NodeBranch { +// NewRootBranch creates a root node branch +func newRootBranch(path *big.Int, left, right branch) *NodeBranch { return &NodeBranch{ - Path: new(big.Int).Set(path), - Left: left, - Right: right, + Path: new(big.Int).Set(path), + Left: left, + Right: right, + isRoot: true, // Hash will be computed on demand } } @@ -196,8 +288,16 @@ func (n *NodeBranch) calculateHash(hasher *api.DataHasher) *api.DataHash { hasher.Reset().AddData(api.CborArray(3)) - pathBytes := api.BigintEncode(n.Path) - hasher.AddCborBytes(pathBytes) + if n.isRoot && n.Path.BitLen() > 1 { + // This is root of a child tree in sharded setup + // The path to add is the last bit of the shard ID + pos := n.Path.BitLen() - 2 + path := big.NewInt(int64(2 + n.Path.Bit(pos))) + hasher.AddCborBytes(api.BigintEncode(path)) + } else { + // In all other cases we just add the actual path + hasher.AddCborBytes(api.BigintEncode(n.Path)) + } if leftHash == nil { hasher.AddCborNull() @@ -226,7 +326,10 @@ func (n *NodeBranch) isLeaf() bool { // AddLeaf adds a single leaf to the tree func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { if path.BitLen()-1 != smt.keyLength { - return fmt.Errorf("invalid key length %d, should be %d", path.BitLen()-1, smt.keyLength) + return ErrKeyLength + } + if calculateCommonPath(path, smt.root.Path).BitLen() != smt.root.Path.BitLen() { + return ErrWrongShard } // Implement copy-on-write for snapshots only @@ -234,8 +337,8 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { smt.root = smt.copyOnWriteRoot() } - // TypeScript: const isRight = path & 1n; - isRight := path.Bit(0) == 1 + shifted := new(big.Int).Rsh(path, uint(smt.root.Path.BitLen()-1)) + isRight := shifted.Bit(0) == 1 var left, right branch @@ -249,13 +352,13 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { } else { rightBranch = smt.root.Right } - newRight, err := smt.buildTree(rightBranch, path, value) + newRight, err := smt.buildTree(rightBranch, shifted, value) if err != nil { return err } right = newRight } else { - right = newLeafBranch(path, value) + right = newLeafBranch(shifted, value) } } else { if smt.root.Left != nil { @@ -266,18 +369,18 @@ func (smt *SparseMerkleTree) AddLeaf(path *big.Int, value []byte) error { } else { leftBranch = smt.root.Left } - newLeft, err := smt.buildTree(leftBranch, path, value) + newLeft, err := smt.buildTree(leftBranch, shifted, value) if err != nil { return err } left = newLeft } else { - left = newLeafBranch(path, value) + left = newLeafBranch(shifted, value) } right = smt.root.Right } - smt.root = newRootNode(left, right) + smt.root = newRootBranch(smt.root.Path, left, right) return nil } @@ -334,12 +437,12 @@ func (smt *SparseMerkleTree) findLeafInBranch(branch branch, targetPath *big.Int commonPath := calculateCommonPath(targetPath, b.Path) // Check if targetPath can be in this 
subtree - if commonPath.path.Cmp(targetPath) == 0 { + if commonPath.Cmp(targetPath) == 0 { return nil, fmt.Errorf("leaf not found") } // Navigate using the same logic as buildTree - shifted := new(big.Int).Rsh(targetPath, commonPath.length) + shifted := new(big.Int).Rsh(targetPath, uint(commonPath.BitLen()-1)) isRight := shifted.Bit(0) == 1 // KEY FIX: Pass the shifted path to match tree construction @@ -361,7 +464,9 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va // Special checks for adding a leaf that already exists in the tree if branch.isLeaf() && branch.getPath().Cmp(remainingPath) == 0 { leafBranch := branch.(*LeafBranch) - if bytes.Equal(leafBranch.Value, value) { + if leafBranch.isChild { + return newChildLeafBranch(leafBranch.Path, value), nil + } else if bytes.Equal(leafBranch.Value, value) { return nil, ErrDuplicateLeaf } else { return nil, ErrLeafModification @@ -369,59 +474,59 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va } commonPath := calculateCommonPath(remainingPath, branch.getPath()) - shifted := new(big.Int).Rsh(remainingPath, commonPath.length) + shifted := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) isRight := shifted.Bit(0) == 1 - if commonPath.path.Cmp(remainingPath) == 0 { - return nil, fmt.Errorf("cannot add leaf inside branch, commonPath: '%s', remainingPath: '%s'", commonPath.path, remainingPath) + if commonPath.Cmp(remainingPath) == 0 { + return nil, fmt.Errorf("cannot add leaf inside branch, commonPath: '%s', remainingPath: '%s'", commonPath, remainingPath) } // If a leaf must be split from the middle if branch.isLeaf() { leafBranch := branch.(*LeafBranch) - if commonPath.path.Cmp(leafBranch.Path) == 0 { + if commonPath.Cmp(leafBranch.Path) == 0 { return nil, fmt.Errorf("cannot extend tree through leaf") } // TypeScript: branch.path >> commonPath.length - oldBranchPath := new(big.Int).Rsh(leafBranch.Path, commonPath.length) + oldBranchPath := new(big.Int).Rsh(leafBranch.Path, uint(commonPath.BitLen()-1)) oldBranch := newLeafBranch(oldBranchPath, leafBranch.Value) // TypeScript: remainingPath >> commonPath.length - newBranchPath := new(big.Int).Rsh(remainingPath, commonPath.length) + newBranchPath := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) newBranch := newLeafBranch(newBranchPath, value) if isRight { - return newNodeBranch(commonPath.path, oldBranch, newBranch), nil + return newNodeBranch(commonPath, oldBranch, newBranch), nil } else { - return newNodeBranch(commonPath.path, newBranch, oldBranch), nil + return newNodeBranch(commonPath, newBranch, oldBranch), nil } } // If node branch is split in the middle nodeBranch := branch.(*NodeBranch) - if commonPath.path.Cmp(nodeBranch.Path) < 0 { - newBranchPath := new(big.Int).Rsh(remainingPath, commonPath.length) + if commonPath.Cmp(nodeBranch.Path) < 0 { + newBranchPath := new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) newBranch := newLeafBranch(newBranchPath, value) - oldBranchPath := new(big.Int).Rsh(nodeBranch.Path, commonPath.length) + oldBranchPath := new(big.Int).Rsh(nodeBranch.Path, uint(commonPath.BitLen()-1)) oldBranch := newNodeBranch(oldBranchPath, nodeBranch.Left, nodeBranch.Right) if isRight { - return newNodeBranch(commonPath.path, oldBranch, newBranch), nil + return newNodeBranch(commonPath, oldBranch, newBranch), nil } else { - return newNodeBranch(commonPath.path, newBranch, oldBranch), nil + return newNodeBranch(commonPath, newBranch, oldBranch), nil } } if isRight { - 
newRight, err := smt.buildTree(nodeBranch.Right, new(big.Int).Rsh(remainingPath, commonPath.length), value) + newRight, err := smt.buildTree(nodeBranch.Right, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), value) if err != nil { return nil, err } return newNodeBranch(nodeBranch.Path, nodeBranch.Left, newRight), nil } else { - newLeft, err := smt.buildTree(nodeBranch.Left, new(big.Int).Rsh(remainingPath, commonPath.length), value) + newLeft, err := smt.buildTree(nodeBranch.Left, new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)), value) if err != nil { return nil, err } @@ -429,11 +534,12 @@ func (smt *SparseMerkleTree) buildTree(branch branch, remainingPath *big.Int, va } } -func (smt *SparseMerkleTree) GetPath(path *big.Int) *api.MerkleTreePath { +func (smt *SparseMerkleTree) GetPath(path *big.Int) (*api.MerkleTreePath, error) { if path.BitLen()-1 != smt.keyLength { - // TODO: better error handling - fmt.Printf("SparseMerkleTree.GetPath(): invalid key length %d, should be %d", path.BitLen()-1, smt.keyLength) - return nil + return nil, ErrKeyLength + } + if calculateCommonPath(path, smt.root.Path).BitLen() != smt.root.Path.BitLen() { + return nil, ErrWrongShard } // Create a new hasher to ensure thread safety @@ -445,7 +551,7 @@ func (smt *SparseMerkleTree) GetPath(path *big.Int) *api.MerkleTreePath { return &api.MerkleTreePath{ Root: rootHash.ToHex(), Steps: steps, - } + }, nil } // generatePath recursively generates the Merkle tree path steps @@ -459,9 +565,13 @@ func (smt *SparseMerkleTree) generatePath(hasher *api.DataHasher, remainingPath // Create the corresponding leaf hash step currentLeaf, _ := currentNode.(*LeafBranch) path := currentLeaf.Path.String() - data := hex.EncodeToString(currentLeaf.Value) + var data *string + if currentLeaf.Value != nil { + tmp := hex.EncodeToString(currentLeaf.Value) + data = &tmp + } return []api.MerkleTreeStep{ - {Path: path, Data: &data}, + {Path: path, Data: data}, } } @@ -470,74 +580,70 @@ func (smt *SparseMerkleTree) generatePath(hasher *api.DataHasher, remainingPath panic("Unknown target branch type") } + var path *big.Int + if currentBranch.isRoot && currentBranch.Path.BitLen() > 1 { + // This is root of a child tree in sharded setup + // The path to add is the last bit of the shard ID + pos := currentBranch.Path.BitLen() - 2 + path = big.NewInt(int64(0b10 | currentBranch.Path.Bit(pos))) + } else { + // In all other cases we just add the actual path + path = currentBranch.Path + } + + var leftHash, rightHash *string + if currentBranch.Left != nil { + hash := currentBranch.Left.calculateHash(hasher) + if hash != nil { + tmp := hex.EncodeToString(hash.RawHash) + leftHash = &tmp + } + } + if currentBranch.Right != nil { + hash := currentBranch.Right.calculateHash(hasher) + if hash != nil { + tmp := hex.EncodeToString(hash.RawHash) + rightHash = &tmp + } + } + commonPath := calculateCommonPath(remainingPath, currentBranch.Path) - if commonPath.length < uint(currentBranch.Path.BitLen()-1) { + if currentBranch != smt.root && commonPath.BitLen() < currentBranch.Path.BitLen() { // Remaining path diverges or ends here - // Root node is a special case, because of its empty path // Create the corresponding 2-step proof - // No nil children in non-root nodes - leftHash := hex.EncodeToString(currentBranch.Left.calculateHash(hasher).RawHash) - rightHash := hex.EncodeToString(currentBranch.Right.calculateHash(hasher).RawHash) - // This looks weird, but see the effect in api.MerkleTreePath.Verify() return []api.MerkleTreeStep{ - 
{Path: "0", Data: &rightHash}, - {Path: currentBranch.Path.String(), Data: &leftHash}, + {Path: "0", Data: leftHash}, + {Path: path.String(), Data: rightHash}, } } // Trim remaining path for descending into subtree - remainingPath = new(big.Int).Rsh(remainingPath, commonPath.length) + remainingPath = new(big.Int).Rsh(remainingPath, uint(commonPath.BitLen()-1)) - var target, sibling branch + var step api.MerkleTreeStep + var steps []api.MerkleTreeStep if remainingPath.Bit(0) == 0 { - // Target in the left child - target = currentBranch.Left - sibling = currentBranch.Right - } else { - // Target in the right child - target = currentBranch.Right - sibling = currentBranch.Left - } - - if target == nil { - // Target branch empty - // This can happen only at the root node - // Create the 2-step exclusion proof - // There may be nil children here - var leftHash, rightHash *string - if currentBranch.Left != nil { - tmp := hex.EncodeToString(currentBranch.Left.calculateHash(hasher).RawHash) - leftHash = &tmp - } - if currentBranch.Right != nil { - tmp := hex.EncodeToString(currentBranch.Right.calculateHash(hasher).RawHash) - rightHash = &tmp + // Target in the left child, right child is sibling + step = api.MerkleTreeStep{Path: path.String(), Data: rightHash} + if leftHash == nil { + steps = []api.MerkleTreeStep{{Path: "0", Data: nil}} + } else { + steps = smt.generatePath(hasher, remainingPath, currentBranch.Left) } - // This looks weird, but see the effect in api.MerkleTreePath.Verify() - return []api.MerkleTreeStep{ - {Path: "0", Data: rightHash}, - {Path: "1", Data: leftHash}, + } else { + step = api.MerkleTreeStep{Path: path.String(), Data: leftHash} + // Target in the right child, left child is sibling + if rightHash == nil { + steps = []api.MerkleTreeStep{{Path: "1", Data: nil}} + } else { + steps = smt.generatePath(hasher, remainingPath, currentBranch.Right) } } - - steps := smt.generatePath(hasher, remainingPath, target) - - // Add the step for the current branch - step := api.MerkleTreeStep{ - Path: currentBranch.Path.String(), - } - if sibling != nil { - tmp := hex.EncodeToString(sibling.calculateHash(hasher).RawHash) - step.Data = &tmp - } return append(steps, step) } // calculateCommonPath computes the longest common prefix of path1 and path2 -func calculateCommonPath(path1, path2 *big.Int) struct { - length uint - path *big.Int -} { +func calculateCommonPath(path1, path2 *big.Int) *big.Int { if path1.Sign() != 1 || path2.Sign() != 1 { panic("Non-positive path value") } @@ -553,10 +659,7 @@ func calculateCommonPath(path1, path2 *big.Int) struct { res.And(res, path1) // res &= path res.Or(res, mask) // res |= mask - return struct { - length uint - path *big.Int - }{uint(pos), res} + return res } // Leaf represents a leaf to be inserted (for batch operations) @@ -572,3 +675,40 @@ func NewLeaf(path *big.Int, value []byte) *Leaf { Value: append([]byte(nil), value...), } } + +// JoinPaths joins the hash proofs from a child and parent in sharded setting +func JoinPaths(child, parent *api.MerkleTreePath) (*api.MerkleTreePath, error) { + if child == nil { + return nil, errors.New("nil child path") + } + if parent == nil { + return nil, errors.New("nil parent path") + } + + // Root hashes are hex-encoded imprints, the first 4 characters are hash function identifiers + if len(child.Root) < 4 { + return nil, errors.New("invalid child root hash format") + } + if len(parent.Root) < 4 { + return nil, errors.New("invalid parent root hash format") + } + if child.Root[:4] != parent.Root[:4] { + return 
nil, errors.New("can't join paths: child hash algorithm does not match parent") + } + + if len(parent.Steps) == 0 { + return nil, errors.New("empty parent hash steps") + } + if parent.Steps[0].Data == nil || *parent.Steps[0].Data != child.Root[4:] { + return nil, errors.New("can't join paths: child root hash does not match parent input hash") + } + + steps := make([]api.MerkleTreeStep, len(child.Steps)+len(parent.Steps)-1) + copy(steps, child.Steps) + copy(steps[len(child.Steps):], parent.Steps[1:]) + + return &api.MerkleTreePath{ + Root: parent.Root, + Steps: steps, + }, nil +} diff --git a/internal/smt/smt_debug_test.go b/internal/smt/smt_debug_test.go index a96e8c5..6fcb7c7 100644 --- a/internal/smt/smt_debug_test.go +++ b/internal/smt/smt_debug_test.go @@ -26,16 +26,12 @@ func TestAddLeaves_DebugInvalidPath(t *testing.T) { leafValue, err := commitment.CreateLeafValue() require.NoError(t, err) - leaf := &Leaf{ - Path: path, - Value: leafValue, - } - - err = tree.AddLeaves([]*Leaf{leaf}) + err = tree.AddLeaves([]*Leaf{NewLeaf(path, leafValue)}) require.NoError(t, err, "Expected error due to invalid path") // now validate the path of request - merkleTreePath := tree.GetPath(path) + merkleTreePath, err := tree.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Expected non-nil Merkle tree path for valid request ID") res, err := merkleTreePath.Verify(path) @@ -84,7 +80,8 @@ func TestAddLeaves_DebugInvalidPath(t *testing.T) { require.NoError(t, err, "Failed to create request ID") path, err := req.GetPath() require.NoError(t, err) - merkleTreePath := _smt.GetPath(path) + merkleTreePath, err := _smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkleTreePath, "Expected non-nil Merkle tree path for valid request ID") res, err := merkleTreePath.Verify(path) diff --git a/internal/smt/smt_test.go b/internal/smt/smt_test.go index 9272e0a..2850866 100644 --- a/internal/smt/smt_test.go +++ b/internal/smt/smt_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/unicitynetwork/aggregator-go/pkg/api" "github.com/stretchr/testify/require" @@ -15,12 +16,14 @@ import ( // TestSMTGetRoot test basic SMT root hash computation func TestSMTGetRoot(t *testing.T) { + // "Singleton" example from the spec t.Run("EmptyTree", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) expected := "00001e54402898172f2948615fb17627733abbd120a85381c624ad060d28321be672" require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Left Child Only" example from the spec t.Run("LeftLeaf", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) @@ -29,6 +32,7 @@ func TestSMTGetRoot(t *testing.T) { require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Right Child Only" example from the spec t.Run("RightLeaf", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b111), []byte{0x62}) @@ -37,6 +41,7 @@ func TestSMTGetRoot(t *testing.T) { require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Two Leaves" example from the spec t.Run("TwoLeaves", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) @@ -46,6 +51,7 @@ func TestSMTGetRoot(t *testing.T) { require.Equal(t, expected, smt.GetRootHashHex()) }) + // "Four Leaves" example from the spec t.Run("FourLeaves", func(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 3) smt.AddLeaf(big.NewInt(0b1000), []byte{0x61}) @@ -58,6 +64,72 @@ func TestSMTGetRoot(t 
*testing.T) { }) } +func TestChildSMTGetRoot(t *testing.T) { + // Left child of the "Two Leaves, Sharded" example from the spec + t.Run("LeftOfTwoLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 2, 0b10) + smt.AddLeaf(big.NewInt(0b100), []byte{0x61}) + + expected := "0000256aedd9f31e69a4b0803616beab77234bae5dff519a10e519a0753be49f0534" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Right child of the "Two Leaves, Sharded" example from the spec + t.Run("RightOfTwoLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 2, 0b11) + smt.AddLeaf(big.NewInt(0b111), []byte{0x62}) + + expected := "0000e777763b4ce391c2f8acdf480dd64758bc8063a3aa5f62670a499a61d3bc7b9a" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Left child of the "Four Leaves, Sharded" example from the spec + t.Run("LeftOfFourLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 4, 0b110) + smt.AddLeaf(big.NewInt(0b10010), []byte{0x61}) + smt.AddLeaf(big.NewInt(0b11010), []byte{0x62}) + + expected := "000010c1dc89e30d51613f2c1a182d16f87fe6709b9735db612adaadaa91955bdaf0" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Right child of the "Four Leaves, Sharded" example from the spec + t.Run("RightOfFourLeaves", func(t *testing.T) { + smt := NewChildSparseMerkleTree(api.SHA256, 4, 0b101) + smt.AddLeaf(big.NewInt(0b10101), []byte{0x63}) + smt.AddLeaf(big.NewInt(0b11101), []byte{0x64}) + + expected := "0000981d2f4e01189506c5a36430e7774e3f9498c1c4cc27801d8e6400d4965a8860" + require.Equal(t, expected, smt.GetRootHashHex()) + }) +} + +func TestParentSMTGetRoot(t *testing.T) { + // Parent of the "Two Leaves, Sharded" example from the spec + t.Run("TwoLeaves", func(t *testing.T) { + left, _ := hex.DecodeString("256aedd9f31e69a4b0803616beab77234bae5dff519a10e519a0753be49f0534") + right, _ := hex.DecodeString("e777763b4ce391c2f8acdf480dd64758bc8063a3aa5f62670a499a61d3bc7b9a") + smt := NewParentSparseMerkleTree(api.SHA256, 1) + smt.AddLeaf(big.NewInt(0b10), left) + smt.AddLeaf(big.NewInt(0b11), right) + + expected := "0000413b961d0069adfea0b4e122cf6dbf98e0a01ef7fd573d68c084ddfa03e4f9d6" + require.Equal(t, expected, smt.GetRootHashHex()) + }) + + // Parent of the "Four Leaves, Sharded" example from the spec + t.Run("FourLeaves", func(t *testing.T) { + left, _ := hex.DecodeString("10c1dc89e30d51613f2c1a182d16f87fe6709b9735db612adaadaa91955bdaf0") + right, _ := hex.DecodeString("981d2f4e01189506c5a36430e7774e3f9498c1c4cc27801d8e6400d4965a8860") + smt := NewParentSparseMerkleTree(api.SHA256, 2) + smt.AddLeaf(big.NewInt(0b110), left) + smt.AddLeaf(big.NewInt(0b101), right) + + expected := "0000eb1a95574056c988f441a50bd18d0555f038276aecf3d155eb9e008a72afcb45" + require.Equal(t, expected, smt.GetRootHashHex()) + }) +} + // TestSMTBatchOperations tests batch functionality func TestSMTBatchOperations(t *testing.T) { t.Run("SimpleRetrievalTest", func(t *testing.T) { @@ -184,18 +256,16 @@ func TestSMTCommonPath(t *testing.T) { testCases := []struct { path1 *big.Int path2 *big.Int - expLen uint expPath *big.Int }{ - {big.NewInt(0b11), big.NewInt(0b111101111), 1, big.NewInt(0b11)}, - {big.NewInt(0b111101111), big.NewInt(0b11), 1, big.NewInt(0b11)}, - {big.NewInt(0b110010000), big.NewInt(0b100010000), 7, big.NewInt(0b10010000)}, + {big.NewInt(0b11), big.NewInt(0b111101111), big.NewInt(0b11)}, + {big.NewInt(0b111101111), big.NewInt(0b11), big.NewInt(0b11)}, + {big.NewInt(0b110010000), big.NewInt(0b100010000), big.NewInt(0b10010000)}, } for i, tc 
:= range testCases { result := calculateCommonPath(tc.path1, tc.path2) - assert.Equal(t, tc.expLen, result.length, "Test %d: length mismatch", i) - assert.Equal(t, tc.expPath, result.path, "Test %d: path mismatch", i) + assert.Equal(t, tc.expPath, result, "Test %d: path mismatch", i) } } @@ -417,7 +487,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test getting path for an existing leaf - merklePath := smt.GetPath(path) + merklePath, err := smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merklePath, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), merklePath.Root, "Root hash should match expected value") require.NotNil(t, merklePath.Steps, "Steps should not be nil") @@ -443,7 +514,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test getting path for an existing leaf - path := smt.GetPath(big.NewInt(0b10)) + path, err := smt.GetPath(big.NewInt(0b10)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path") require.NotEmpty(t, path.Root, "Root hash should not be empty") require.NotNil(t, path.Steps, "Steps should not be nil") @@ -463,7 +535,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test getting path for a non-existent leaf - path := smt.GetPath(big.NewInt(0b11)) + path, err := smt.GetPath(big.NewInt(0b11)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path even for non-existent leaves") require.NotEmpty(t, path.Root, "Root hash should not be empty") require.NotNil(t, path.Steps, "Steps should not be nil") @@ -482,7 +555,8 @@ func TestSMTGetPath(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Test path structure - path := smt.GetPath(big.NewInt(0b10)) + path, err := smt.GetPath(big.NewInt(0b10)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path") // Verify step structure @@ -498,7 +572,8 @@ func TestSMTGetPath(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 1) // Test getting path from empty tree - path := smt.GetPath(big.NewInt(0b10)) + path, err := smt.GetPath(big.NewInt(0b10)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path even for empty tree") require.NotEmpty(t, path.Root, "Root hash should not be empty even for empty tree") require.NotNil(t, path.Steps, "Steps should not be nil") @@ -519,7 +594,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Get path for the leaf - path := smt.GetPath(leafPath) + path, err := smt.GetPath(leafPath) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), path.Root, "Path root should match tree root") @@ -545,13 +621,15 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf 2 failed") // Get path for first leaf - merkPath1 := smt.GetPath(path1) + merkPath1, err := smt.GetPath(path1) + require.NoError(t, err) require.NotNil(t, merkPath1, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), merkPath1.Root, "Path root should match tree root") require.NotEmpty(t, merkPath1.Steps, "Should have steps") // Get path for second leaf - merkPath2 := smt.GetPath(path2) + merkPath2, err := smt.GetPath(path2) + require.NoError(t, err) require.NotNil(t, merkPath2, "GetPath should return a path") require.Equal(t, smt.GetRootHashHex(), merkPath2.Root, "Path root should match tree root") require.NotEmpty(t, merkPath2.Steps, 
"Should have steps") @@ -585,8 +663,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { // Try to get path for non-existent leaf nonExistentPath := big.NewInt(0b111) - merkPath := smt.GetPath(nonExistentPath) - + merkPath, err := smt.GetPath(nonExistentPath) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path even for non-existent paths") require.Equal(t, smt.GetRootHashHex(), merkPath.Root, "Path root should match tree root") require.NotEmpty(t, merkPath.Steps, "Should have steps even for non-existent path") @@ -624,7 +702,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { rootHash := smt.GetRootHashHex() for i, path := range testPaths { - merkPath := smt.GetPath(path) + merkPath, err := smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path for leaf %d", i) require.Equal(t, rootHash, merkPath.Root, "All paths should have same root") require.NotEmpty(t, merkPath.Steps, "Path should have steps for leaf %d", i) @@ -652,14 +731,15 @@ func TestSMTGetPathComprehensive(t *testing.T) { smt := NewSparseMerkleTree(api.SHA256, 2) // Get path from empty tree - path := smt.GetPath(big.NewInt(0b101)) + path, err := smt.GetPath(big.NewInt(0b101)) + require.NoError(t, err) require.NotNil(t, path, "GetPath should return a path even for empty tree") require.NotEmpty(t, path.Root, "Root should not be empty even for empty tree") require.NotNil(t, path.Steps, "Steps should not be nil") require.Len(t, path.Steps, 2, "Should have two steps") step0 := path.Steps[0] - require.Equal(t, "0", step0.Path, "Input step path") + require.Equal(t, "1", step0.Path, "Input step path") require.Nil(t, step0.Data, "Empty tree step should have no data") step1 := path.Steps[1] @@ -690,7 +770,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { // Get paths and validate structure for _, leaf := range testLeaves { - merkPath := smt.GetPath(leaf.path) + merkPath, err := smt.GetPath(leaf.path) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path") // Validate path structure for verification compatibility @@ -727,10 +808,10 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf failed") // Get paths multiple times and verify consistency - merkPath1a := smt.GetPath(path1) - merkPath1b := smt.GetPath(path1) - merkPath2a := smt.GetPath(path2) - merkPath2b := smt.GetPath(path2) + merkPath1a, _ := smt.GetPath(path1) + merkPath1b, _ := smt.GetPath(path1) + merkPath2a, _ := smt.GetPath(path2) + merkPath2b, _ := smt.GetPath(path2) // Same path should return identical results require.Equal(t, merkPath1a.Root, merkPath1b.Root, "Same path should have same root") @@ -777,7 +858,8 @@ func TestSMTGetPathComprehensive(t *testing.T) { require.NoError(t, err, "AddLeaf failed for %s", tc.name) // Get path - merkPath := smt.GetPath(path) + merkPath, err := smt.GetPath(path) + require.NoError(t, err) require.NotNil(t, merkPath, "GetPath should return a path for %s", tc.name) // Verify the path representation in steps @@ -1220,3 +1302,176 @@ func TestSMTAddingLeafAboveNode(t *testing.T) { } require.Error(t, smt2.AddLeaves(leaves2), "SMT should not allow adding leaves above existing nodes, even in a batch") } + +func TestJoinPaths(t *testing.T) { + // "Two Leaves, Sharded" example from the spec + t.Run("TwoLeaves", func(t *testing.T) { + left := NewChildSparseMerkleTree(api.SHA256, 2, 0b10) + left.AddLeaf(big.NewInt(0b100), []byte{0x61}) + + right := NewChildSparseMerkleTree(api.SHA256, 2, 0b11) + 
right.AddLeaf(big.NewInt(0b111), []byte{0x62}) + + parent := NewParentSparseMerkleTree(api.SHA256, 1) + parent.AddLeaf(big.NewInt(0b10), left.GetRootHash()[2:]) + parent.AddLeaf(big.NewInt(0b11), right.GetRootHash()[2:]) + + leftChild, _ := left.GetPath(big.NewInt(0b100)) + leftParent, _ := parent.GetPath(big.NewInt(0b10)) + leftPath, err := JoinPaths(leftChild, leftParent) + assert.Nil(t, err) + assert.NotNil(t, leftPath) + leftRes, err := leftPath.Verify(big.NewInt(0b100)) + assert.Nil(t, err) + assert.NotNil(t, leftRes) + assert.True(t, leftRes.PathValid) + assert.True(t, leftRes.PathIncluded) + + rightChild, _ := right.GetPath(big.NewInt(0b111)) + rightParent, _ := parent.GetPath(big.NewInt(0b11)) + rightPath, err := JoinPaths(rightChild, rightParent) + assert.Nil(t, err) + assert.NotNil(t, rightPath) + rightRes, err := rightPath.Verify(big.NewInt(0b111)) + assert.Nil(t, err) + assert.NotNil(t, rightRes) + assert.True(t, rightRes.PathValid) + assert.True(t, rightRes.PathIncluded) + }) + + // "Four Leaves, Sharded" example from the spec + t.Run("FourLeaves", func(t *testing.T) { + left := NewChildSparseMerkleTree(api.SHA256, 4, 0b110) + left.AddLeaf(big.NewInt(0b10010), []byte{0x61}) + left.AddLeaf(big.NewInt(0b11010), []byte{0x62}) + + right := NewChildSparseMerkleTree(api.SHA256, 4, 0b101) + right.AddLeaf(big.NewInt(0b10101), []byte{0x63}) + right.AddLeaf(big.NewInt(0b11101), []byte{0x64}) + + parent := NewParentSparseMerkleTree(api.SHA256, 2) + parent.AddLeaf(big.NewInt(0b110), left.GetRootHash()[2:]) + parent.AddLeaf(big.NewInt(0b101), right.GetRootHash()[2:]) + + child1, _ := left.GetPath(big.NewInt(0b10010)) + parent1, _ := parent.GetPath(big.NewInt(0b110)) + path1, err := JoinPaths(child1, parent1) + assert.Nil(t, err) + assert.NotNil(t, path1) + res1, err := path1.Verify(big.NewInt(0b10010)) + assert.Nil(t, err) + assert.NotNil(t, res1) + assert.True(t, res1.PathValid) + assert.True(t, res1.PathIncluded) + + child2, _ := left.GetPath(big.NewInt(0b11010)) + parent2, _ := parent.GetPath(big.NewInt(0b110)) + path2, err := JoinPaths(child2, parent2) + assert.Nil(t, err) + assert.NotNil(t, path2) + res2, err := path2.Verify(big.NewInt(0b11010)) + assert.Nil(t, err) + assert.NotNil(t, res2) + assert.True(t, res2.PathValid) + assert.True(t, res2.PathIncluded) + + child3, _ := right.GetPath(big.NewInt(0b10101)) + parent3, _ := parent.GetPath(big.NewInt(0b101)) + path3, err := JoinPaths(child3, parent3) + assert.Nil(t, err) + assert.NotNil(t, path3) + res3, err := path3.Verify(big.NewInt(0b10101)) + assert.Nil(t, err) + assert.NotNil(t, res3) + assert.True(t, res3.PathValid) + assert.True(t, res3.PathIncluded) + + child4, _ := right.GetPath(big.NewInt(0b11101)) + parent4, _ := parent.GetPath(big.NewInt(0b101)) + path4, err := JoinPaths(child4, parent4) + assert.Nil(t, err) + assert.NotNil(t, path4) + res4, err := path4.Verify(big.NewInt(0b11101)) + assert.Nil(t, err) + assert.NotNil(t, res4) + assert.True(t, res4.PathValid) + assert.True(t, res4.PathIncluded) + }) + + t.Run("NilInputDoesNotPanic", func(t *testing.T) { + dummy := &api.MerkleTreePath{Root: "0000"} + + joinedPath, err := JoinPaths(nil, dummy) + assert.ErrorContains(t, err, "nil child path") + assert.Nil(t, joinedPath) + + joinedPath, err = JoinPaths(dummy, nil) + assert.ErrorContains(t, err, "nil parent path") + assert.Nil(t, joinedPath) + }) + + t.Run("NilRootDoesNotPanic", func(t *testing.T) { + dummyNil := &api.MerkleTreePath{} + dummyShort := &api.MerkleTreePath{Root: ""} + dummyOK := &api.MerkleTreePath{Root: "0000"} + 
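+		// Both a missing Root and a Root shorter than the 4-character hash
+		// algorithm prefix must be rejected, on the child and parent side alike.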
+ joinedPath, err := JoinPaths(dummyNil, dummyOK) + assert.ErrorContains(t, err, "invalid child root hash format") + assert.Nil(t, joinedPath) + + joinedPath, err = JoinPaths(dummyShort, dummyOK) + assert.ErrorContains(t, err, "invalid child root hash format") + assert.Nil(t, joinedPath) + + joinedPath, err = JoinPaths(dummyOK, dummyNil) + assert.ErrorContains(t, err, "invalid parent root hash format") + assert.Nil(t, joinedPath) + + joinedPath, err = JoinPaths(dummyOK, dummyShort) + assert.ErrorContains(t, err, "invalid parent root hash format") + assert.Nil(t, joinedPath) + }) + + t.Run("HashFunctionMismatch", func(t *testing.T) { + child := &api.MerkleTreePath{Root: "0000"} + parent := &api.MerkleTreePath{Root: "0001"} + + joinedPath, err := JoinPaths(child, parent) + assert.ErrorContains(t, err, "child hash algorithm does not match parent") + assert.Nil(t, joinedPath) + }) + + t.Run("HashValueMismatch", func(t *testing.T) { + smt := NewSparseMerkleTree(api.SHA256, 1) + smt.AddLeaf(big.NewInt(0b10), []byte{0}) + path, _ := smt.GetPath(big.NewInt(0b10)) + + joinedPath, err := JoinPaths(path, path) + assert.ErrorContains(t, err, "child root hash does not match parent input hash") + assert.Nil(t, joinedPath) + }) +} + +// TestParentSMTSnapshotUpdateLeaf tests that parent SMT snapshots can update pre-populated leaves +func TestParentSMTSnapshotUpdateLeaf(t *testing.T) { + // Create parent SMT with ShardIDLength=1 (creates pre-populated leaves at paths 2 and 3) + parentSMT := NewParentSparseMerkleTree(api.SHA256, 1) + + // Create snapshot for copy-on-write semantics + snapshot := parentSMT.CreateSnapshot() + + // Update a pre-populated leaf (shard ID 2) + path := big.NewInt(2) + value := []byte{0x01, 0x02, 0x03} + + err := snapshot.AddLeaf(path, value) + if err != nil { + t.Fatalf("Failed to update leaf in parent SMT snapshot: %v", err) + } + + // Verify the update worked by checking root hash changed + rootHash := snapshot.GetRootHash() + if len(rootHash) == 0 { + t.Fatal("Expected non-empty root hash after updating leaf") + } +} diff --git a/internal/smt/thread_safe_smt.go b/internal/smt/thread_safe_smt.go index c3478aa..6efee2a 100644 --- a/internal/smt/thread_safe_smt.go +++ b/internal/smt/thread_safe_smt.go @@ -45,6 +45,17 @@ func (ts *ThreadSafeSMT) AddLeaf(path *big.Int, value []byte) error { return ts.smt.AddLeaf(path, value) } +// AddPreHashedLeaf adds a leaf where the value is already a hash calculated externally +// This operation is exclusive and blocks all other operations +func (ts *ThreadSafeSMT) AddPreHashedLeaf(path *big.Int, hash []byte) error { + ts.rwMux.Lock() + defer ts.rwMux.Unlock() + + // TODO(SMT): Implement AddPreHashedLeaf in SparseMerkleTree + //return ts.smt.AddPreHashedLeaf(path, hash) + return nil +} + // GetRootHash returns the current root hash // This is a read operation that can be performed concurrently func (ts *ThreadSafeSMT) GetRootHash() string { @@ -65,7 +76,7 @@ func (ts *ThreadSafeSMT) GetLeaf(path *big.Int) (*LeafBranch, error) { // GetPath generates a Merkle tree path for the given path // This is a read operation and allows concurrent access -func (ts *ThreadSafeSMT) GetPath(path *big.Int) *api.MerkleTreePath { +func (ts *ThreadSafeSMT) GetPath(path *big.Int) (*api.MerkleTreePath, error) { ts.rwMux.RLock() defer ts.rwMux.RUnlock() return ts.smt.GetPath(path) diff --git a/internal/smt/thread_safe_smt_snapshot.go b/internal/smt/thread_safe_smt_snapshot.go index ee3c89b..2e6e145 100644 --- a/internal/smt/thread_safe_smt_snapshot.go +++ 
b/internal/smt/thread_safe_smt_snapshot.go @@ -65,7 +65,7 @@ func (tss *ThreadSafeSmtSnapshot) GetRootHash() string { return tss.snapshot.GetRootHashHex() } -func (tss *ThreadSafeSmtSnapshot) GetPath(path *big.Int) *api.MerkleTreePath { +func (tss *ThreadSafeSmtSnapshot) GetPath(path *big.Int) (*api.MerkleTreePath, error) { tss.rwMux.RLock() defer tss.rwMux.RUnlock() diff --git a/internal/smt/thread_safe_smt_snapshot_test.go b/internal/smt/thread_safe_smt_snapshot_test.go index 547af66..59f8d1f 100644 --- a/internal/smt/thread_safe_smt_snapshot_test.go +++ b/internal/smt/thread_safe_smt_snapshot_test.go @@ -167,7 +167,8 @@ func TestThreadSafeSMTSnapshot(t *testing.T) { assert.Equal(t, value, leaf.Value, "Retrieved leaf value should match") // Test path generation on original SMT - merkleTreePath := threadSafeSMT.GetPath(path) + merkleTreePath, err := threadSafeSMT.GetPath(path) + require.NoError(t, err) assert.NotNil(t, merkleTreePath, "Should be able to get Merkle tree path from original SMT") assert.NotEmpty(t, merkleTreePath.Root, "Root should not be empty") assert.NotEmpty(t, merkleTreePath.Steps, "Steps should not be empty") diff --git a/internal/storage/interfaces/interfaces.go b/internal/storage/interfaces/interfaces.go index 479a4b6..ea8ce05 100644 --- a/internal/storage/interfaces/interfaces.go +++ b/internal/storage/interfaces/interfaces.go @@ -100,6 +100,9 @@ type SmtStorage interface { // StoreBatch stores multiple SMT nodes StoreBatch(ctx context.Context, nodes []*models.SmtNode) error + // UpsertBatch stores or updates multiple SMT nodes, replacing existing values + UpsertBatch(ctx context.Context, nodes []*models.SmtNode) error + // GetByKey retrieves an SMT node by key GetByKey(ctx context.Context, key api.HexBytes) (*models.SmtNode, error) @@ -144,6 +147,7 @@ type BlockRecordsStorage interface { GetLatestBlock(ctx context.Context) (*models.BlockRecords, error) } + // LeadershipStorage handles high availability leadership state type LeadershipStorage interface { // TryAcquireLock attempts to acquire the leadership lock, diff --git a/internal/storage/mongodb/aggregator_record.go b/internal/storage/mongodb/aggregator_record.go index 81ce632..4822e45 100644 --- a/internal/storage/mongodb/aggregator_record.go +++ b/internal/storage/mongodb/aggregator_record.go @@ -29,9 +29,11 @@ func NewAggregatorRecordStorage(db *mongo.Database) *AggregatorRecordStorage { // Store stores a new aggregator record func (ars *AggregatorRecordStorage) Store(ctx context.Context, record *models.AggregatorRecord) error { - recordBSON := record.ToBSON() - _, err := ars.collection.InsertOne(ctx, recordBSON) + recordBSON, err := record.ToBSON() if err != nil { + return fmt.Errorf("failed to marshal aggregator record to BSON: %w", err) + } + if _, err := ars.collection.InsertOne(ctx, recordBSON); err != nil { return fmt.Errorf("failed to store aggregator record: %w", err) } return nil @@ -43,13 +45,17 @@ func (ars *AggregatorRecordStorage) StoreBatch(ctx context.Context, records []*m return nil } + var err error documents := make([]interface{}, len(records)) for i, record := range records { - documents[i] = record.ToBSON() + documents[i], err = record.ToBSON() + if err != nil { + return fmt.Errorf("failed to marshal aggregator record to BSON: %w", err) + } } opts := options.InsertMany().SetOrdered(false) - _, err := ars.collection.InsertMany(ctx, documents, opts) + _, err = ars.collection.InsertMany(ctx, documents, opts) if err != nil { if mongo.IsDuplicateKeyError(err) { return nil @@ -79,20 +85,19 @@ 
func (ars *AggregatorRecordStorage) GetByRequestID(ctx context.Context, requestI // GetByBlockNumber retrieves all records for a specific block func (ars *AggregatorRecordStorage) GetByBlockNumber(ctx context.Context, blockNumber *api.BigInt) ([]*models.AggregatorRecord, error) { - filter := bson.M{"blockNumber": blockNumber.String()} + filter := bson.M{"blockNumber": bigIntToDecimal128(blockNumber)} cursor, err := ars.collection.Find(ctx, filter) if err != nil { return nil, fmt.Errorf("failed to find records by block number: %w", err) } defer cursor.Close(ctx) - var records []*models.AggregatorRecord + records := make([]*models.AggregatorRecord, 0) for cursor.Next(ctx) { var recordBSON models.AggregatorRecordBSON if err := cursor.Decode(&recordBSON); err != nil { return nil, fmt.Errorf("failed to decode aggregator record: %w", err) } - record, err := recordBSON.FromBSON() if err != nil { return nil, fmt.Errorf("failed to convert from BSON: %w", err) diff --git a/internal/storage/mongodb/aggregator_record_test.go b/internal/storage/mongodb/aggregator_record_test.go index 7a74f5c..9b6bcff 100644 --- a/internal/storage/mongodb/aggregator_record_test.go +++ b/internal/storage/mongodb/aggregator_record_test.go @@ -17,7 +17,7 @@ import ( ) // setupAggregatorRecordTestDB creates a test database connection using Testcontainers -func setupAggregatorRecordTestDB(t *testing.T) (*mongo.Database, func()) { +func setupAggregatorRecordTestDB(t *testing.T) *mongo.Database { ctx := context.Background() // Create MongoDB container @@ -46,7 +46,7 @@ func setupAggregatorRecordTestDB(t *testing.T) (*mongo.Database, func()) { db := client.Database("test_aggregator_records") - cleanup := func() { + t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -57,9 +57,9 @@ func setupAggregatorRecordTestDB(t *testing.T) (*mongo.Database, func()) { if err := mongoContainer.Terminate(ctx); err != nil { t.Logf("Failed to terminate MongoDB container: %v", err) } - } + }) - return db, cleanup + return db } // createTestAggregatorRecord creates a test aggregator record @@ -84,8 +84,7 @@ func createTestAggregatorRecord(requestID string, blockNumber int64, leafIndex i } func TestAggregatorRecordStorage_StoreBatch_DuplicateHandling(t *testing.T) { - db, cleanup := setupAggregatorRecordTestDB(t) - defer cleanup() + db := setupAggregatorRecordTestDB(t) storage := NewAggregatorRecordStorage(db) ctx := context.Background() @@ -131,3 +130,69 @@ func TestAggregatorRecordStorage_StoreBatch_DuplicateHandling(t *testing.T) { require.NoError(t, err, "GetByRequestID should not return an error") assert.NotNil(t, record1, "Original record should still exist") } + +func TestAggregatorRecordStorage_GetByBlockNumber(t *testing.T) { + db := setupAggregatorRecordTestDB(t) + storage := NewAggregatorRecordStorage(db) + ctx := context.Background() + + // Create indexes first + err := storage.CreateIndexes(ctx) + require.NoError(t, err, "CreateIndexes should not return an error") + + t.Run("should return empty slice when no records exist", func(t *testing.T) { + blockNum := api.NewBigInt(big.NewInt(100)) + retrieved, err := storage.GetByBlockNumber(ctx, blockNum) + require.NoError(t, err) + require.Len(t, retrieved, 0) + }) + + // Store some test records + records := []*models.AggregatorRecord{ + createTestAggregatorRecord("req1-b100", 100, 0), + createTestAggregatorRecord("req2-b100", 100, 1), + createTestAggregatorRecord("req3-b100", 100, 2), + createTestAggregatorRecord("req1-b101", 101, 0), 
+ createTestAggregatorRecord("req2-b101", 101, 1), + createTestAggregatorRecord("req1-b0", 0, 0), + } + err = storage.StoreBatch(ctx, records) + require.NoError(t, err, "StoreBatch should not return an error") + + largeBlockNumberRecord := createTestAggregatorRecord("req1-large", 99999999999999999, 0) + err = storage.Store(ctx, largeBlockNumberRecord) + require.NoError(t, err, "Store should not return an error for large block number") + + t.Run("should return records for a specific block number", func(t *testing.T) { + blockNum := api.NewBigInt(big.NewInt(100)) + retrieved, err := storage.GetByBlockNumber(ctx, blockNum) + require.NoError(t, err) + require.NotNil(t, retrieved) + require.Len(t, retrieved, 3) + + // Check request IDs to be sure + requestIDs := make(map[api.RequestID]bool) + for _, r := range retrieved { + requestIDs[r.RequestID] = true + } + require.True(t, requestIDs["req1-b100"]) + require.True(t, requestIDs["req2-b100"]) + require.True(t, requestIDs["req3-b100"]) + }) + + t.Run("should return empty slice for non-existent block number", func(t *testing.T) { + blockNum := api.NewBigInt(big.NewInt(999)) + retrieved, err := storage.GetByBlockNumber(ctx, blockNum) + require.NoError(t, err) + require.Len(t, retrieved, 0) + }) + + t.Run("should handle zero block number", func(t *testing.T) { + blockNum := api.NewBigInt(big.NewInt(0)) + retrieved, err := storage.GetByBlockNumber(ctx, blockNum) + require.NoError(t, err) + require.NotNil(t, retrieved) + require.Len(t, retrieved, 1) + require.Equal(t, api.RequestID("req1-b0"), retrieved[0].RequestID) + }) +} diff --git a/internal/storage/mongodb/block.go b/internal/storage/mongodb/block.go index f0e926e..8f144ca 100644 --- a/internal/storage/mongodb/block.go +++ b/internal/storage/mongodb/block.go @@ -40,7 +40,11 @@ func bigIntToDecimal128(bigInt *api.BigInt) primitive.Decimal128 { // Store stores a new block func (bs *BlockStorage) Store(ctx context.Context, block *models.Block) error { - _, err := bs.collection.InsertOne(ctx, block.ToBSON()) + blockBSON, err := block.ToBSON() + if err != nil { + return fmt.Errorf("failed to convert block to bson: %w", err) + } + _, err = bs.collection.InsertOne(ctx, blockBSON) if err != nil { return fmt.Errorf("failed to store block: %w", err) } diff --git a/internal/storage/mongodb/block_records.go b/internal/storage/mongodb/block_records.go index f58d442..59ed35e 100644 --- a/internal/storage/mongodb/block_records.go +++ b/internal/storage/mongodb/block_records.go @@ -33,8 +33,11 @@ func (brs *BlockRecordsStorage) Store(ctx context.Context, records *models.Block if records == nil { return errors.New("block records is nil") } - _, err := brs.collection.InsertOne(ctx, records.ToBSON()) + recordsBSON, err := records.ToBSON() if err != nil { + return fmt.Errorf("failed to convert block records to BSON: %w", err) + } + if _, err = brs.collection.InsertOne(ctx, recordsBSON); err != nil { return fmt.Errorf("failed to store block records: %w", err) } return nil @@ -61,7 +64,7 @@ func (brs *BlockRecordsStorage) GetByBlockNumber(ctx context.Context, blockNumbe // GetByRequestID retrieves the block number for a request ID func (brs *BlockRecordsStorage) GetByRequestID(ctx context.Context, requestID api.RequestID) (*api.BigInt, error) { - filter := bson.M{"requestIds": requestID} + filter := bson.M{"requestIds": requestID.String()} opts := options.FindOne().SetProjection(bson.M{"blockNumber": 1}) var result struct { diff --git a/internal/storage/mongodb/block_records_test.go 
b/internal/storage/mongodb/block_records_test.go index bd7cc6d..abc97e3 100644 --- a/internal/storage/mongodb/block_records_test.go +++ b/internal/storage/mongodb/block_records_test.go @@ -750,16 +750,21 @@ func TestBlockRecordsStorage_Store_BSON(t *testing.T) { } originalRecords := createTestBlockRecords(blockNumber, requestIDs) + originalBSON, err := originalRecords.ToBSON() + require.NoError(t, err) // Marshal to BSON - bsonData, err := bson.Marshal(originalRecords) + bsonData, err := bson.Marshal(originalBSON) require.NoError(t, err, "Should be able to marshal BlockRecords to BSON") // Unmarshal from BSON - var unmarshaledRecords models.BlockRecords - err = bson.Unmarshal(bsonData, &unmarshaledRecords) + var unmarshaledRecordsBSON models.BlockRecordsBSON + err = bson.Unmarshal(bsonData, &unmarshaledRecordsBSON) require.NoError(t, err, "Should be able to unmarshal BlockRecords from BSON") + unmarshaledRecords, err := unmarshaledRecordsBSON.FromBSON() + require.NoError(t, err) + // Verify the data matches assert.Equal(t, originalRecords.BlockNumber.String(), unmarshaledRecords.BlockNumber.String()) assert.Equal(t, len(originalRecords.RequestIDs), len(unmarshaledRecords.RequestIDs)) @@ -778,15 +783,19 @@ func TestBlockRecordsStorage_Store_BSON(t *testing.T) { requestIDs := []api.RequestID{} originalRecords := createTestBlockRecords(blockNumber, requestIDs) + originalBSON, err := originalRecords.ToBSON() + require.NoError(t, err) // Marshal to BSON - bsonData, err := bson.Marshal(originalRecords) + bsonData, err := bson.Marshal(originalBSON) require.NoError(t, err, "Should be able to marshal BlockRecords with empty requestIDs to BSON") // Unmarshal from BSON - var unmarshaledRecords models.BlockRecords - err = bson.Unmarshal(bsonData, &unmarshaledRecords) + var unmarshaledRecordsBSON models.BlockRecordsBSON + err = bson.Unmarshal(bsonData, &unmarshaledRecordsBSON) require.NoError(t, err, "Should be able to unmarshal BlockRecords with empty requestIDs from BSON") + unmarshaledRecords, err := unmarshaledRecordsBSON.FromBSON() + require.NoError(t, err) // Verify the data matches assert.Equal(t, originalRecords.BlockNumber.String(), unmarshaledRecords.BlockNumber.String()) @@ -796,7 +805,7 @@ func TestBlockRecordsStorage_Store_BSON(t *testing.T) { t.Run("should marshal and unmarshal large block numbers", func(t *testing.T) { // Create test data with large block number - largeNumber, ok := new(big.Int).SetString("999999999999999999999999999999999999999999", 10) + largeNumber, ok := new(big.Int).SetString("999999999999999999999999999999", 10) require.True(t, ok, "Should be able to create large big.Int") blockNumber := api.NewBigInt(largeNumber) @@ -805,15 +814,19 @@ func TestBlockRecordsStorage_Store_BSON(t *testing.T) { } originalRecords := createTestBlockRecords(blockNumber, requestIDs) + originalBSON, err := originalRecords.ToBSON() + require.NoError(t, err) // Marshal to BSON - bsonData, err := bson.Marshal(originalRecords) + bsonData, err := bson.Marshal(originalBSON) require.NoError(t, err, "Should be able to marshal BlockRecords with large block number to BSON") // Unmarshal from BSON - var unmarshaledRecords models.BlockRecords - err = bson.Unmarshal(bsonData, &unmarshaledRecords) + var unmarshaledRecordsBSON models.BlockRecordsBSON + err = bson.Unmarshal(bsonData, &unmarshaledRecordsBSON) require.NoError(t, err, "Should be able to unmarshal BlockRecords with large block number from BSON") + unmarshaledRecords, err := unmarshaledRecordsBSON.FromBSON() + require.NoError(t, err) // Verify 
the data matches assert.Equal(t, originalRecords.BlockNumber.String(), unmarshaledRecords.BlockNumber.String()) @@ -842,6 +855,8 @@ func TestBlockRecordsStorage_Store_Comprehensive(t *testing.T) { // Create BlockRecords originalRecords := createTestBlockRecords(blockNumber, requestIDs) + originalRecordsBSON, err := originalRecords.ToBSON() + require.NoError(t, err) // Verify structure is correct assert.NotNil(t, originalRecords) @@ -850,12 +865,14 @@ func TestBlockRecordsStorage_Store_Comprehensive(t *testing.T) { assert.NotNil(t, originalRecords.CreatedAt) // Test BSON round-trip - bsonData, err := bson.Marshal(originalRecords) + bsonData, err := bson.Marshal(originalRecordsBSON) require.NoError(t, err, "Should marshal BlockRecords to BSON") - var unmarshaledRecords models.BlockRecords - err = bson.Unmarshal(bsonData, &unmarshaledRecords) + var unmarshaledRecordsBSON models.BlockRecordsBSON + err = bson.Unmarshal(bsonData, &unmarshaledRecordsBSON) require.NoError(t, err, "Should unmarshal BlockRecords from BSON") + unmarshaledRecords, err := unmarshaledRecordsBSON.FromBSON() + require.NoError(t, err) // Verify all data is preserved through BSON round-trip assert.Equal(t, originalRecords.BlockNumber.String(), unmarshaledRecords.BlockNumber.String()) @@ -913,7 +930,7 @@ func TestBlockRecordsStorage_Store_Comprehensive(t *testing.T) { { name: "large block number", blockNumber: func() *api.BigInt { - large, _ := new(big.Int).SetString("999999999999999999999999999999999999999999", 10) + large, _ := new(big.Int).SetString("999999999999999999999999999999999", 10) return api.NewBigInt(large) }(), requestIDs: []api.RequestID{"ffff000000000000000000000000000000000000000000000000000000000000"}, @@ -923,14 +940,18 @@ func TestBlockRecordsStorage_Store_Comprehensive(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { records := createTestBlockRecords(tc.blockNumber, tc.requestIDs) + recordsBSON, err := records.ToBSON() + require.NoError(t, err) // Test BSON serialization - bsonData, err := bson.Marshal(records) + bsonData, err := bson.Marshal(recordsBSON) require.NoError(t, err, "Should marshal edge case to BSON") - var unmarshaled models.BlockRecords - err = bson.Unmarshal(bsonData, &unmarshaled) + var unmarshaledBSON models.BlockRecordsBSON + err = bson.Unmarshal(bsonData, &unmarshaledBSON) require.NoError(t, err, "Should unmarshal edge case from BSON") + unmarshaled, err := unmarshaledBSON.FromBSON() + require.NoError(t, err) // Verify data integrity assert.Equal(t, records.BlockNumber.String(), unmarshaled.BlockNumber.String()) diff --git a/internal/storage/mongodb/commitment.go b/internal/storage/mongodb/commitment.go index b96b125..150ed2d 100644 --- a/internal/storage/mongodb/commitment.go +++ b/internal/storage/mongodb/commitment.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" @@ -31,7 +32,7 @@ func NewCommitmentStorage(db *mongo.Database) *CommitmentStorage { // Store stores a new commitment func (cs *CommitmentStorage) Store(ctx context.Context, commitment *models.Commitment) error { - _, err := cs.collection.InsertOne(ctx, commitment) + _, err := cs.collection.InsertOne(ctx, commitment.ToBSON()) if err != nil { return fmt.Errorf("failed to store commitment: %w", err) } @@ -40,19 +41,25 @@ func (cs *CommitmentStorage) Store(ctx context.Context, commitment *models.Commi // GetByRequestID retrieves a commitment by request ID func (cs *CommitmentStorage) 
GetByRequestID(ctx context.Context, requestID api.RequestID) (*models.Commitment, error) { - var commitment models.Commitment - err := cs.collection.FindOne(ctx, bson.M{"requestId": requestID}).Decode(&commitment) + var commitmentBSON models.CommitmentBSON + err := cs.collection.FindOne(ctx, bson.M{"requestId": requestID}).Decode(&commitmentBSON) if err != nil { if errors.Is(err, mongo.ErrNoDocuments) { return nil, nil } return nil, fmt.Errorf("failed to get commitment by request ID: %w", err) } + + commitment, err := commitmentBSON.FromBSON() + if err != nil { + return nil, fmt.Errorf("failed to convert commitment from BSON: %w", err) + } + // Handle backward compatibility: default to 1 if AggregateRequestCount is 0 if commitment.AggregateRequestCount == 0 { commitment.AggregateRequestCount = 1 } - return &commitment, nil + return commitment, nil } // GetUnprocessedBatch retrieves a batch of unprocessed commitments @@ -87,15 +94,20 @@ func (cs *CommitmentStorage) GetUnprocessedBatchWithCursor(ctx context.Context, var newCursor string for cursor.Next(ctx) { - var commitment models.Commitment - if err := cursor.Decode(&commitment); err != nil { + var commitmentBSON models.CommitmentBSON + if err := cursor.Decode(&commitmentBSON); err != nil { return nil, "", fmt.Errorf("failed to decode commitment: %w", err) } + commitment, err := commitmentBSON.FromBSON() + if err != nil { + return nil, "", fmt.Errorf("failed to convert commitment from BSON: %w", err) + } + // Handle backward compatibility: default to 1 if AggregateRequestCount is 0 if commitment.AggregateRequestCount == 0 { commitment.AggregateRequestCount = 1 } - commitments = append(commitments, &commitment) + commitments = append(commitments, commitment) // Update cursor to the last fetched ID if commitment.ID != primitive.NilObjectID { newCursor = commitment.ID.Hex() @@ -121,7 +133,7 @@ func (cs *CommitmentStorage) MarkProcessed(ctx context.Context, entries []interf } filter := bson.M{"requestId": bson.M{"$in": requestIDs}} - update := bson.M{"$set": bson.M{"processedAt": api.Now()}} + update := bson.M{"$set": bson.M{"processedAt": time.Now()}} _, err := cs.collection.UpdateMany(ctx, filter, update) if err != nil { diff --git a/internal/storage/mongodb/connection.go b/internal/storage/mongodb/connection.go index cbff0ff..95df37a 100644 --- a/internal/storage/mongodb/connection.go +++ b/internal/storage/mongodb/connection.go @@ -122,6 +122,22 @@ func (s *Storage) Close(ctx context.Context) error { return s.client.Disconnect(ctx) } +// CleanAllCollections drops all collections in the database (useful for testing) +func (s *Storage) CleanAllCollections(ctx context.Context) error { + collections, err := s.database.ListCollectionNames(ctx, map[string]interface{}{}) + if err != nil { + return fmt.Errorf("failed to list collections: %w", err) + } + + for _, collName := range collections { + if err := s.database.Collection(collName).Drop(ctx); err != nil { + return fmt.Errorf("failed to drop collection %s: %w", collName, err) + } + } + + return nil +} + // WithTransaction executes a function within a MongoDB transaction func (s *Storage) WithTransaction(ctx context.Context, fn func(context.Context) error) error { session, err := s.client.StartSession() diff --git a/internal/storage/mongodb/smt.go b/internal/storage/mongodb/smt.go index a1ecaf5..5780846 100644 --- a/internal/storage/mongodb/smt.go +++ b/internal/storage/mongodb/smt.go @@ -29,14 +29,8 @@ func NewSmtStorage(db *mongo.Database) *SmtStorage { // Store stores a new SMT node using 
upsert to handle duplicates gracefully func (ss *SmtStorage) Store(ctx context.Context, node *models.SmtNode) error { - filter := bson.M{"key": node.Key} - update := bson.M{ - "$setOnInsert": bson.M{ - "key": node.Key, - "value": node.Value, - "createdAt": node.CreatedAt, - }, - } + filter := bson.M{"key": node.Key.String()} + update := bson.M{"$setOnInsert": node.ToBSON()} opts := options.Update().SetUpsert(true) _, err := ss.collection.UpdateOne(ctx, filter, update, opts) @@ -46,6 +40,36 @@ func (ss *SmtStorage) Store(ctx context.Context, node *models.SmtNode) error { return nil } +// UpsertBatch stores or updates multiple SMT nodes, replacing existing values for the same keys +func (ss *SmtStorage) UpsertBatch(ctx context.Context, nodes []*models.SmtNode) error { + if len(nodes) == 0 { + return nil + } + + // Use bulk write operations for efficiency + var operations []mongo.WriteModel + for _, node := range nodes { + filter := bson.M{"key": node.Key.String()} + update := bson.M{"$set": node.ToBSON()} + + operation := mongo.NewUpdateOneModel() + operation.SetFilter(filter) + operation.SetUpdate(update) + operation.SetUpsert(true) + + operations = append(operations, operation) + } + + // Execute bulk write + opts := options.BulkWrite().SetOrdered(false) // Allow partial success + _, err := ss.collection.BulkWrite(ctx, operations, opts) + if err != nil { + return fmt.Errorf("failed to upsert SMT nodes batch: %w", err) + } + + return nil +} + // StoreBatch stores multiple SMT nodes using insert operations, skipping duplicates func (ss *SmtStorage) StoreBatch(ctx context.Context, nodes []*models.SmtNode) error { if len(nodes) == 0 { @@ -54,11 +78,7 @@ func (ss *SmtStorage) StoreBatch(ctx context.Context, nodes []*models.SmtNode) e documents := make([]interface{}, len(nodes)) for i, node := range nodes { - documents[i] = bson.M{ - "key": node.Key, - "value": node.Value, - "createdAt": node.CreatedAt, - } + documents[i] = node.ToBSON() } opts := options.InsertMany().SetOrdered(false) @@ -74,15 +94,15 @@ func (ss *SmtStorage) StoreBatch(ctx context.Context, nodes []*models.SmtNode) e // GetByKey retrieves an SMT node by key func (ss *SmtStorage) GetByKey(ctx context.Context, key api.HexBytes) (*models.SmtNode, error) { - var node models.SmtNode - err := ss.collection.FindOne(ctx, bson.M{"key": key}).Decode(&node) + var nodeBSON models.SmtNodeBSON + err := ss.collection.FindOne(ctx, bson.M{"key": key.String()}).Decode(&nodeBSON) if err != nil { if errors.Is(err, mongo.ErrNoDocuments) { return nil, nil } return nil, fmt.Errorf("failed to get SMT node by key: %w", err) } - return &node, nil + return nodeBSON.FromBSON() } // GetByKeys retrieves multiple SMT nodes by their keys @@ -90,7 +110,13 @@ func (ss *SmtStorage) GetByKeys(ctx context.Context, keys []api.HexBytes) ([]*mo if len(keys) == 0 { return []*models.SmtNode{}, nil } - filter := bson.M{"key": bson.M{"$in": keys}} + + keyStrings := make([]string, len(keys)) + for i, key := range keys { + keyStrings[i] = key.String() + } + + filter := bson.M{"key": bson.M{"$in": keyStrings}} cursor, err := ss.collection.Find(ctx, filter) if err != nil { return nil, fmt.Errorf("failed to query SMT nodes by keys: %w", err) @@ -98,15 +124,27 @@ func (ss *SmtStorage) GetByKeys(ctx context.Context, keys []api.HexBytes) ([]*mo defer cursor.Close(ctx) var nodes []*models.SmtNode - if err := cursor.All(ctx, &nodes); err != nil { - return nil, fmt.Errorf("failed to decode SMT nodes: %w", err) + for cursor.Next(ctx) { + var nodeBSON models.SmtNodeBSON + if err := 
cursor.Decode(&nodeBSON); err != nil { + return nil, fmt.Errorf("failed to decode SMT node: %w", err) + } + node, err := nodeBSON.FromBSON() + if err != nil { + return nil, fmt.Errorf("failed to convert SMT node from BSON: %w", err) + } + nodes = append(nodes, node) + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error: %w", err) } return nodes, nil } // Delete removes an SMT node func (ss *SmtStorage) Delete(ctx context.Context, key api.HexBytes) error { - _, err := ss.collection.DeleteOne(ctx, bson.M{"key": key}) + _, err := ss.collection.DeleteOne(ctx, bson.M{"key": key.String()}) if err != nil { return fmt.Errorf("failed to delete SMT node: %w", err) } @@ -119,7 +157,12 @@ func (ss *SmtStorage) DeleteBatch(ctx context.Context, keys []api.HexBytes) erro return nil } - filter := bson.M{"key": bson.M{"$in": keys}} + keyStrings := make([]string, len(keys)) + for i, key := range keys { + keyStrings[i] = key.String() + } + + filter := bson.M{"key": bson.M{"$in": keyStrings}} _, err := ss.collection.DeleteMany(ctx, filter) if err != nil { return fmt.Errorf("failed to delete SMT nodes batch: %w", err) @@ -147,11 +190,15 @@ func (ss *SmtStorage) GetAll(ctx context.Context) ([]*models.SmtNode, error) { var nodes []*models.SmtNode for cursor.Next(ctx) { - var node models.SmtNode - if err := cursor.Decode(&node); err != nil { + var nodeBSON models.SmtNodeBSON + if err := cursor.Decode(&nodeBSON); err != nil { return nil, fmt.Errorf("failed to decode SMT node: %w", err) } - nodes = append(nodes, &node) + node, err := nodeBSON.FromBSON() + if err != nil { + return nil, fmt.Errorf("failed to decode SMT node: %w", err) + } + nodes = append(nodes, node) } if err := cursor.Err(); err != nil { @@ -176,11 +223,15 @@ func (ss *SmtStorage) GetChunked(ctx context.Context, offset, limit int) ([]*mod var nodes []*models.SmtNode for cursor.Next(ctx) { - var node models.SmtNode - if err := cursor.Decode(&node); err != nil { + var nodeBSON models.SmtNodeBSON + if err := cursor.Decode(&nodeBSON); err != nil { + return nil, fmt.Errorf("failed to decode SMT node: %w", err) + } + node, err := nodeBSON.FromBSON() + if err != nil { return nil, fmt.Errorf("failed to decode SMT node: %w", err) } - nodes = append(nodes, &node) + nodes = append(nodes, node) } if err := cursor.Err(); err != nil { diff --git a/internal/storage/mongodb/smt_test.go b/internal/storage/mongodb/smt_test.go index 4bdc1db..3125bf5 100644 --- a/internal/storage/mongodb/smt_test.go +++ b/internal/storage/mongodb/smt_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go/modules/mongodb" "go.mongodb.org/mongo-driver/mongo" @@ -22,7 +21,7 @@ const ( ) // setupSmtTestDB creates a test database connection using Testcontainers -func setupSmtTestDB(t *testing.T) (*mongo.Database, func()) { +func setupSmtTestDB(t *testing.T) *mongo.Database { ctx := context.Background() // Create MongoDB container @@ -33,29 +32,24 @@ func setupSmtTestDB(t *testing.T) (*mongo.Database, func()) { // Get connection URI mongoURI, err := mongoContainer.ConnectionString(ctx) - if err != nil { - t.Fatalf("Failed to get MongoDB connection string: %v", err) - } + require.NoError(t, err) // Connect to MongoDB connectCtx, cancel := context.WithTimeout(ctx, smtTestTimeout) defer cancel() client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(mongoURI)) - if err != nil { - t.Fatalf("Failed to connect to MongoDB: %v", err) - } + 
require.NoError(t, err) // Ping to verify connection - if err := client.Ping(connectCtx, nil); err != nil { - t.Fatalf("Failed to ping MongoDB: %v", err) - } + err = client.Ping(connectCtx, nil) + require.NoError(t, err) // Create test database db := client.Database("test_smt_db") // Cleanup function - cleanup := func() { + t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), smtTestTimeout) defer cancel() @@ -73,18 +67,9 @@ func setupSmtTestDB(t *testing.T) (*mongo.Database, func()) { if err := mongoContainer.Terminate(ctx); err != nil { t.Logf("Failed to terminate MongoDB container: %v", err) } - } - - return db, cleanup -} + }) -// createTestSmtNode creates a test SMT node -func createTestSmtNode(key api.HexBytes, value []byte) *models.SmtNode { - return &models.SmtNode{ - Key: key, - Value: value, - CreatedAt: api.Now(), - } + return db } // createTestSmtNodes creates multiple test SMT nodes with truly unique keys @@ -106,14 +91,13 @@ func createTestSmtNodes(count int) []*models.SmtNode { value[30] = byte((i + 100) >> 8) value[31] = byte(i + 100) - nodes[i] = createTestSmtNode(api.HexBytes(key), value) + nodes[i] = models.NewSmtNode(key, value) } return nodes } func TestSmtStorage_Store(t *testing.T) { - db, cleanup := setupSmtTestDB(t) - defer cleanup() + db := setupSmtTestDB(t) storage := NewSmtStorage(db) ctx := context.Background() @@ -124,10 +108,9 @@ func TestSmtStorage_Store(t *testing.T) { t.Run("should store valid SMT node", func(t *testing.T) { // Create test data - key := api.HexBytes([]byte("test_key_1234567890abcdef1234567890ab")) - value := []byte("test_value_1234567890abcdef") - - node := createTestSmtNode(key, value) + key := api.HexBytes("test_key_1234567890abcdef1234567890ab") + value := api.HexBytes("test_value_1234567890abcdef") + node := models.NewSmtNode(key, value) // Store the node err := storage.Store(ctx, node) @@ -139,9 +122,9 @@ func TestSmtStorage_Store(t *testing.T) { require.NotNil(t, storedNode, "Retrieved node should not be nil") // Verify the stored data matches - assert.Equal(t, key, storedNode.Key) - assert.Equal(t, api.HexBytes(value), storedNode.Value) - assert.NotNil(t, storedNode.CreatedAt) + require.Equal(t, key, storedNode.Key) + require.Equal(t, value, storedNode.Value) + require.NotNil(t, storedNode.CreatedAt) }) t.Run("should handle duplicate key on single insert", func(t *testing.T) { @@ -150,8 +133,8 @@ func TestSmtStorage_Store(t *testing.T) { value1 := []byte("first_value_123456789") value2 := []byte("second_value_987654321") - node1 := createTestSmtNode(key, value1) - node2 := createTestSmtNode(key, value2) + node1 := models.NewSmtNode(key, value1) + node2 := models.NewSmtNode(key, value2) // Store the first node err := storage.Store(ctx, node1) @@ -159,7 +142,7 @@ func TestSmtStorage_Store(t *testing.T) { // Attempt to store the second node with the same key - should succeed (ignore duplicate) err = storage.Store(ctx, node2) - assert.NoError(t, err, "Second store should not return an error (duplicate should be ignored)") + require.NoError(t, err, "Second store should not return an error (duplicate should be ignored)") // Verify only the first node remains in storage storedNode, err := storage.GetByKey(ctx, key) @@ -167,8 +150,8 @@ func TestSmtStorage_Store(t *testing.T) { require.NotNil(t, storedNode, "Retrieved node should not be nil") // Should have the first node's data (not overwritten) - assert.Equal(t, key, storedNode.Key) - assert.Equal(t, api.HexBytes(value1), storedNode.Value) + require.Equal(t, key, 
storedNode.Key) + require.Equal(t, api.HexBytes(value1), storedNode.Value) }) t.Run("should store multiple different nodes", func(t *testing.T) { @@ -187,16 +170,14 @@ func TestSmtStorage_Store(t *testing.T) { require.NoError(t, err, "Should be able to retrieve stored node %d", i) require.NotNil(t, storedNode, "Retrieved node %d should not be nil", i) - assert.Equal(t, []byte(node.Key), []byte(storedNode.Key)) - assert.Equal(t, node.Value, storedNode.Value) + require.Equal(t, []byte(node.Key), []byte(storedNode.Key)) + require.Equal(t, node.Value, storedNode.Value) } }) } func TestSmtStorage_StoreBatch(t *testing.T) { - db, cleanup := setupSmtTestDB(t) - defer cleanup() - + db := setupSmtTestDB(t) storage := NewSmtStorage(db) ctx := context.Background() @@ -218,8 +199,8 @@ func TestSmtStorage_StoreBatch(t *testing.T) { require.NoError(t, err, "Should be able to retrieve stored node %d", i) require.NotNil(t, storedNode, "Retrieved node %d should not be nil", i) - assert.Equal(t, []byte(node.Key), []byte(storedNode.Key)) - assert.Equal(t, node.Value, storedNode.Value) + require.Equal(t, []byte(node.Key), []byte(storedNode.Key)) + require.Equal(t, node.Value, storedNode.Value) } }) @@ -235,8 +216,8 @@ func TestSmtStorage_StoreBatch(t *testing.T) { batch2 := make([]*models.SmtNode, 5) // First 2 nodes have duplicate keys from batch1 - batch2[0] = createTestSmtNode(batch1[0].Key, []byte("new_value_1")) - batch2[1] = createTestSmtNode(batch1[1].Key, []byte("new_value_2")) + batch2[0] = models.NewSmtNode(batch1[0].Key, []byte("new_value_1")) + batch2[1] = models.NewSmtNode(batch1[1].Key, []byte("new_value_2")) // Last 3 nodes have new unique keys for i := 2; i < 5; i++ { @@ -248,12 +229,12 @@ func TestSmtStorage_StoreBatch(t *testing.T) { key[31] = byte(i + 200) value[31] = byte(i + 50) - batch2[i] = createTestSmtNode(api.HexBytes(key), value) + batch2[i] = models.NewSmtNode(api.HexBytes(key), value) } // Store second batch - should succeed despite duplicates err = storage.StoreBatch(ctx, batch2) - assert.NoError(t, err, "Second StoreBatch should not return an error (duplicates should be ignored)") + require.NoError(t, err, "Second StoreBatch should not return an error (duplicates should be ignored)") // Verify original nodes are unchanged (not overwritten) for i := 0; i < 2; i++ { @@ -262,8 +243,8 @@ func TestSmtStorage_StoreBatch(t *testing.T) { require.NotNil(t, storedNode, "Retrieved node %d should not be nil", i) // Should have original values, not the new ones - assert.Equal(t, batch1[i].Key, storedNode.Key) - assert.Equal(t, batch1[i].Value, storedNode.Value) + require.Equal(t, batch1[i].Key, storedNode.Key) + require.Equal(t, batch1[i].Value, storedNode.Value) } // Verify new unique nodes were stored @@ -272,8 +253,8 @@ func TestSmtStorage_StoreBatch(t *testing.T) { require.NoError(t, err, "Should be able to retrieve new node %d", i) require.NotNil(t, storedNode, "Retrieved new node %d should not be nil", i) - assert.Equal(t, batch2[i].Key, storedNode.Key) - assert.Equal(t, batch2[i].Value, storedNode.Value) + require.Equal(t, batch2[i].Key, storedNode.Key) + require.Equal(t, batch2[i].Value, storedNode.Value) } }) @@ -288,12 +269,12 @@ func TestSmtStorage_StoreBatch(t *testing.T) { for i := 0; i < 3; i++ { newValue := []byte(fmt.Sprintf("duplicate_value_%d", i)) - duplicateBatch[i] = createTestSmtNode(originalBatch[i].Key, newValue) + duplicateBatch[i] = models.NewSmtNode(originalBatch[i].Key, newValue) } // Store duplicate batch - should succeed and ignore all duplicates err = 
storage.StoreBatch(ctx, duplicateBatch) - assert.NoError(t, err, "Duplicate StoreBatch should not return an error") + require.NoError(t, err, "Duplicate StoreBatch should not return an error") // Verify original data is unchanged for i := 0; i < 3; i++ { @@ -302,19 +283,19 @@ func TestSmtStorage_StoreBatch(t *testing.T) { require.NotNil(t, storedNode, "Retrieved node %d should not be nil", i) // Should have original values, not the duplicate values - assert.Equal(t, originalBatch[i].Key, storedNode.Key) - assert.Equal(t, originalBatch[i].Value, storedNode.Value) + require.Equal(t, originalBatch[i].Key, storedNode.Key) + require.Equal(t, originalBatch[i].Value, storedNode.Value) } }) t.Run("should handle empty batch", func(t *testing.T) { // Store empty batch err := storage.StoreBatch(ctx, []*models.SmtNode{}) - assert.NoError(t, err, "Empty StoreBatch should not return an error") + require.NoError(t, err, "Empty StoreBatch should not return an error") // Store nil batch err = storage.StoreBatch(ctx, nil) - assert.NoError(t, err, "Nil StoreBatch should not return an error") + require.NoError(t, err, "Nil StoreBatch should not return an error") }) t.Run("should handle large batch with duplicates", func(t *testing.T) { @@ -332,7 +313,7 @@ func TestSmtStorage_StoreBatch(t *testing.T) { for i := 0; i < 50; i++ { newValue := []byte(fmt.Sprintf("duplicate_large_value_%d", i)) - mixedBatch[i] = createTestSmtNode(largeBatch[i].Key, newValue) + mixedBatch[i] = models.NewSmtNode(largeBatch[i].Key, newValue) } // Last 50 are new @@ -346,12 +327,12 @@ func TestSmtStorage_StoreBatch(t *testing.T) { key[31] = byte(i + 200) value[31] = byte(i % 256) - mixedBatch[i] = createTestSmtNode(api.HexBytes(key), value) + mixedBatch[i] = models.NewSmtNode(api.HexBytes(key), value) } // Store mixed batch - should succeed err = storage.StoreBatch(ctx, mixedBatch) - assert.NoError(t, err, "Mixed batch with duplicates should not return an error") + require.NoError(t, err, "Mixed batch with duplicates should not return an error") // Verify original data is preserved for duplicates for i := 0; i < 50; i++ { @@ -359,7 +340,7 @@ func TestSmtStorage_StoreBatch(t *testing.T) { require.NoError(t, err, "Should be able to retrieve original node %d", i) // Should have original values - assert.Equal(t, largeBatch[i].Value, storedNode.Value) + require.Equal(t, largeBatch[i].Value, storedNode.Value) } // Verify new data is stored @@ -368,15 +349,13 @@ func TestSmtStorage_StoreBatch(t *testing.T) { require.NoError(t, err, "Should be able to retrieve new node %d", i) // Should have new values - assert.Equal(t, mixedBatch[i].Value, storedNode.Value) + require.Equal(t, mixedBatch[i].Value, storedNode.Value) } }) } func TestSmtStorage_Count(t *testing.T) { - db, cleanup := setupSmtTestDB(t) - defer cleanup() - + db := setupSmtTestDB(t) storage := NewSmtStorage(db) ctx := context.Background() @@ -387,7 +366,7 @@ func TestSmtStorage_Count(t *testing.T) { // Initially should be 0 count, err := storage.Count(ctx) require.NoError(t, err, "Count should not return an error") - assert.Equal(t, int64(0), count, "Initial count should be 0") + require.Equal(t, int64(0), count, "Initial count should be 0") // Store some nodes nodes := createTestSmtNodes(5) @@ -397,7 +376,7 @@ func TestSmtStorage_Count(t *testing.T) { // Count should be 5 count, err = storage.Count(ctx) require.NoError(t, err, "Count should not return an error") - assert.Equal(t, int64(5), count, "Count should be 5 after storing 5 nodes") + require.Equal(t, int64(5), count, "Count 
should be 5 after storing 5 nodes") // Store duplicates err = storage.StoreBatch(ctx, nodes) @@ -406,14 +385,12 @@ func TestSmtStorage_Count(t *testing.T) { // Count should still be 5 (duplicates ignored) count, err = storage.Count(ctx) require.NoError(t, err, "Count should not return an error") - assert.Equal(t, int64(5), count, "Count should still be 5 after storing duplicates") + require.Equal(t, int64(5), count, "Count should still be 5 after storing duplicates") }) } func TestSmtStorage_GetChunked(t *testing.T) { - db, cleanup := setupSmtTestDB(t) - defer cleanup() - + db := setupSmtTestDB(t) storage := NewSmtStorage(db) ctx := context.Background() @@ -433,33 +410,31 @@ func TestSmtStorage_GetChunked(t *testing.T) { // Retrieve all nodes in chunks chunk1, err := storage.GetChunked(ctx, 0, 5) require.NoError(t, err, "GetChunked should not return an error") - assert.Len(t, chunk1, 5, "First chunk should have 5 nodes") + require.Len(t, chunk1, 5, "First chunk should have 5 nodes") chunk2, err := storage.GetChunked(ctx, 5, 5) require.NoError(t, err, "GetChunked should not return an error") - assert.Len(t, chunk2, 5, "Second chunk should have 5 nodes") + require.Len(t, chunk2, 5, "Second chunk should have 5 nodes") // Third chunk should be empty chunk3, err := storage.GetChunked(ctx, 10, 5) require.NoError(t, err, "GetChunked should not return an error") - assert.Len(t, chunk3, 0, "Third chunk should be empty") + require.Len(t, chunk3, 0, "Third chunk should be empty") // Verify we got all unique nodes (no duplicates) allChunkedKeys := make(map[string]bool) for _, node := range append(chunk1, chunk2...) { keyStr := string(node.Key) - assert.False(t, allChunkedKeys[keyStr], "Should not have duplicate keys in chunked results") + require.False(t, allChunkedKeys[keyStr], "Should not have duplicate keys in chunked results") allChunkedKeys[keyStr] = true } - assert.Len(t, allChunkedKeys, 10, "Should have exactly 10 unique keys") + require.Len(t, allChunkedKeys, 10, "Should have exactly 10 unique keys") }) } func TestSmtStorage_StoreBatch_DuplicateHandling(t *testing.T) { - db, cleanup := setupSmtTestDB(t) - defer cleanup() - + db := setupSmtTestDB(t) storage := NewSmtStorage(db) ctx := context.Background() @@ -472,7 +447,7 @@ func TestSmtStorage_StoreBatch_DuplicateHandling(t *testing.T) { nodes2 := []*models.SmtNode{ nodes1[0], // Duplicate of first node nodes1[1], // Duplicate of second node - createTestSmtNode(api.HexBytes("newkey"), []byte("newvalue")), // New node + models.NewSmtNode(api.HexBytes("newkey"), api.HexBytes("newvalue")), } // Store first batch @@ -481,16 +456,376 @@ func TestSmtStorage_StoreBatch_DuplicateHandling(t *testing.T) { // Store second batch with duplicates - should not error err = storage.StoreBatch(ctx, nodes2) - assert.NoError(t, err, "StoreBatch with duplicates should not return an error") + require.NoError(t, err, "StoreBatch with duplicates should not return an error") // Verify that we only have 4 unique nodes (3 from first batch + 1 new from second) count, err := storage.Count(ctx) require.NoError(t, err, "Count should not return an error") - assert.Equal(t, int64(4), count, "Should have exactly 4 nodes (duplicates ignored)") + require.Equal(t, int64(4), count, "Should have exactly 4 nodes (duplicates ignored)") // Verify the new node was stored newNode, err := storage.GetByKey(ctx, api.HexBytes("newkey")) require.NoError(t, err, "GetByKey should not return an error") - assert.NotNil(t, newNode, "New node should be found") - assert.Equal(t, []byte("newvalue"), 
[]byte(newNode.Value), "New node should have correct value") + require.NotNil(t, newNode, "New node should be found") + require.Equal(t, []byte("newvalue"), []byte(newNode.Value), "New node should have correct value") +} + +func TestSmtStorage_GetByKeys(t *testing.T) { + db := setupSmtTestDB(t) + storage := NewSmtStorage(db) + ctx := context.Background() + + err := storage.CreateIndexes(ctx) + require.NoError(t, err, "CreateIndexes should not return an error") + + nodesToStore := createTestSmtNodes(10) + err = storage.StoreBatch(ctx, nodesToStore) + require.NoError(t, err, "StoreBatch should not return an error") + + t.Run("should retrieve multiple existing keys", func(t *testing.T) { + keysToGet := []api.HexBytes{ + nodesToStore[1].Key, + nodesToStore[3].Key, + nodesToStore[5].Key, + } + + retrievedNodes, err := storage.GetByKeys(ctx, keysToGet) + require.NoError(t, err, "GetByKeys should not return an error") + require.NotNil(t, retrievedNodes) + require.Len(t, retrievedNodes, 3, "Should retrieve exactly 3 nodes") + + // Verify the correct nodes were returned + expectedNodes := make(map[string]*models.SmtNode) + expectedNodes[string(nodesToStore[1].Key)] = nodesToStore[1] + expectedNodes[string(nodesToStore[3].Key)] = nodesToStore[3] + expectedNodes[string(nodesToStore[5].Key)] = nodesToStore[5] + + for _, retrieved := range retrievedNodes { + expected, ok := expectedNodes[string(retrieved.Key)] + require.True(t, ok, "Retrieved a key that was not requested") + require.Equal(t, expected.Value, retrieved.Value, "Value for key %s does not match", retrieved.Key) + } + }) + + t.Run("should handle a mix of existing and non-existing keys", func(t *testing.T) { + nonExistentKey := api.HexBytes("deadbeefdeadbeefdeadbeefdeadbeef") + keysToGet := []api.HexBytes{ + nodesToStore[0].Key, + nonExistentKey, + nodesToStore[2].Key, + } + + retrievedNodes, err := storage.GetByKeys(ctx, keysToGet) + require.NoError(t, err, "GetByKeys should not return an error") + require.NotNil(t, retrievedNodes) + + // Verify the results - should only get the 2 existing nodes + expectedNodes := make(map[string]*models.SmtNode) + expectedNodes[string(nodesToStore[0].Key)] = nodesToStore[0] + expectedNodes[string(nodesToStore[2].Key)] = nodesToStore[2] + + require.Len(t, retrievedNodes, len(expectedNodes), "Should retrieve only the existing nodes") + + for _, retrieved := range retrievedNodes { + keyStr := string(retrieved.Key) + expected, ok := expectedNodes[keyStr] + require.True(t, ok, "Retrieved a key that was not requested: %s", keyStr) + require.Equal(t, expected.Value, retrieved.Value, "Value for key %s does not match", keyStr) + // Remove from map to ensure no duplicates are retrieved and all were found + delete(expectedNodes, keyStr) + } + + require.Empty(t, expectedNodes, "Not all expected nodes were retrieved") + }) + + t.Run("should return an empty slice for all non-existing keys", func(t *testing.T) { + keysToGet := []api.HexBytes{ + api.HexBytes("facefeedfacefeedfacefeedfacefeed"), + api.HexBytes("badcoffeebadcoffeebadcoffeebadcoffee"), + } + + retrievedNodes, err := storage.GetByKeys(ctx, keysToGet) + require.NoError(t, err, "GetByKeys should not return an error") + require.Len(t, retrievedNodes, 0, "Should return an empty slice for non-existing keys") + }) + + t.Run("should return an empty slice for an empty key list", func(t *testing.T) { + retrievedNodes, err := storage.GetByKeys(ctx, []api.HexBytes{}) + require.NoError(t, err, "GetByKeys should not return an error") + require.NotNil(t, retrievedNodes) + 
require.Len(t, retrievedNodes, 0, "Should return an empty slice for an empty key list") + }) +} + +func TestSmtStorage_DeleteBatch(t *testing.T) { + db := setupSmtTestDB(t) + storage := NewSmtStorage(db) + ctx := context.Background() + + err := storage.CreateIndexes(ctx) + require.NoError(t, err, "CreateIndexes should not return an error") + + t.Run("should delete a batch of existing nodes", func(t *testing.T) { + // 1. Store a known set of nodes for this specific test. + nodesToStore := createTestSmtNodes(10) + err := storage.StoreBatch(ctx, nodesToStore) + require.NoError(t, err, "StoreBatch should not return an error") + + // 2. Define keys to delete and a key to keep. + keysToDelete := []api.HexBytes{ + nodesToStore[2].Key, + nodesToStore[4].Key, + nodesToStore[6].Key, + } + keyToKeep := nodesToStore[0].Key + + // 3. Call DeleteBatch. + err = storage.DeleteBatch(ctx, keysToDelete) + require.NoError(t, err, "DeleteBatch should not return an error") + + // 4. Verify that the deleted keys are gone. + for _, deletedKey := range keysToDelete { + node, err := storage.GetByKey(ctx, deletedKey) + require.NoError(t, err) + require.Nil(t, node, "Deleted node with key %s should not be found", deletedKey) + } + + // 5. Verify a non-deleted node still exists. + stillExistsNode, err := storage.GetByKey(ctx, keyToKeep) + require.NoError(t, err) + require.NotNil(t, stillExistsNode, "Node that was not deleted should still exist") + }) + + t.Run("should handle a mix of existing and non-existing keys", func(t *testing.T) { + // 1. Store a known set of nodes. + nodesToStore := createTestSmtNodes(5) + err := storage.StoreBatch(ctx, nodesToStore) + require.NoError(t, err, "StoreBatch should not return an error") + + // 2. Define keys to delete, including some that don't exist. + keyToDelete1 := nodesToStore[1].Key + keyToDelete2 := nodesToStore[3].Key + nonExistentKey := api.HexBytes("deadbeefdeadbeefdeadbeefdeadbeef") + keysToDelete := []api.HexBytes{ + keyToDelete1, + nonExistentKey, + keyToDelete2, + } + + // 3. Call DeleteBatch - should not error. + err = storage.DeleteBatch(ctx, keysToDelete) + require.NoError(t, err, "DeleteBatch should not error on non-existent keys") + + // 4. Verify the nodes that should have been deleted are gone. + deletedNode1, err := storage.GetByKey(ctx, keyToDelete1) + require.NoError(t, err) + require.Nil(t, deletedNode1, "Node 1 should be deleted") + + deletedNode3, err := storage.GetByKey(ctx, keyToDelete2) + require.NoError(t, err) + require.Nil(t, deletedNode3, "Node 3 should be deleted") + }) + + t.Run("should handle an empty key list", func(t *testing.T) { + // 1. Store a known set of nodes. + nodesToStore := createTestSmtNodes(5) + err := storage.StoreBatch(ctx, nodesToStore) + require.NoError(t, err, "StoreBatch should not return an error") + + // 2. Get the initial state of one node. + nodeBefore, err := storage.GetByKey(ctx, nodesToStore[0].Key) + require.NoError(t, err) + require.NotNil(t, nodeBefore) + + // 3. Call DeleteBatch with an empty slice. + err = storage.DeleteBatch(ctx, []api.HexBytes{}) + require.NoError(t, err, "DeleteBatch with empty slice should not return an error") + + // 4. Verify the node is unchanged by fetching it again. 
+ nodeAfter, err := storage.GetByKey(ctx, nodesToStore[0].Key) + require.NoError(t, err) + require.NotNil(t, nodeAfter, "Node should still exist after empty delete batch") + require.Equal(t, nodeBefore.Value, nodeAfter.Value) + }) +} + +func TestSmtStorage_Delete(t *testing.T) { + db := setupSmtTestDB(t) + storage := NewSmtStorage(db) + ctx := context.Background() + + err := storage.CreateIndexes(ctx) + require.NoError(t, err, "CreateIndexes should not return an error") + + t.Run("should delete an existing node", func(t *testing.T) { + // 1. Store a known set of nodes. + nodesToStore := createTestSmtNodes(3) + err := storage.StoreBatch(ctx, nodesToStore) + require.NoError(t, err, "StoreBatch should not return an error") + + keyToDelete := nodesToStore[1].Key + keyToKeep := nodesToStore[0].Key + + // 2. Call Delete. + err = storage.Delete(ctx, keyToDelete) + require.NoError(t, err, "Delete should not return an error") + + // 3. Verify the node is gone. + deletedNode, err := storage.GetByKey(ctx, keyToDelete) + require.NoError(t, err) + require.Nil(t, deletedNode, "Deleted node should not be found") + + // 4. Verify other nodes are unaffected. + keptNode, err := storage.GetByKey(ctx, keyToKeep) + require.NoError(t, err) + require.NotNil(t, keptNode, "Other nodes should not be deleted") + }) + + t.Run("should not error when deleting a non-existing node", func(t *testing.T) { + // 1. Store a node to ensure the collection is not empty. + nodeToStore := createTestSmtNodes(1)[0] + err := storage.Store(ctx, nodeToStore) + require.NoError(t, err, "Store should not return an error") + + // 2. Attempt to delete a key that does not exist. + nonExistentKey := api.HexBytes("deadbeefdeadbeefdeadbeefdeadbeef") + err = storage.Delete(ctx, nonExistentKey) + require.NoError(t, err, "Delete should not return an error for a non-existing key") + + // 3. Verify the original node is unaffected. + originalNode, err := storage.GetByKey(ctx, nodeToStore.Key) + require.NoError(t, err) + require.NotNil(t, originalNode, "Existing node should not be affected") + }) +} + +func TestSmtStorage_UpsertBatch(t *testing.T) { + db := setupSmtTestDB(t) + storage := NewSmtStorage(db) + ctx := context.Background() + + err := storage.CreateIndexes(ctx) + require.NoError(t, err, "CreateIndexes should not return an error") + + t.Run("should insert new nodes correctly", func(t *testing.T) { + nodesToInsert := createTestSmtNodes(5) + err := storage.UpsertBatch(ctx, nodesToInsert) + require.NoError(t, err, "UpsertBatch should not return an error for new nodes") + + // Verify all nodes were inserted + var keys []api.HexBytes + for _, node := range nodesToInsert { + keys = append(keys, node.Key) + } + retrievedNodes, err := storage.GetByKeys(ctx, keys) + require.NoError(t, err) + require.Len(t, retrievedNodes, 5) + }) + + t.Run("should update existing nodes", func(t *testing.T) { + // 1. Store an initial batch + initialNodes := createTestSmtNodes(3) + err := storage.StoreBatch(ctx, initialNodes) + require.NoError(t, err) + + // 2. Create a new batch with the same keys but different values + nodesToUpdate := []*models.SmtNode{ + models.NewSmtNode(initialNodes[0].Key, []byte("new-value-0")), + models.NewSmtNode(initialNodes[2].Key, []byte("new-value-2")), + } + + // 3. Call UpsertBatch + err = storage.UpsertBatch(ctx, nodesToUpdate) + require.NoError(t, err, "UpsertBatch should not return an error when updating nodes") + + // 4. 
Verify the nodes were updated + updatedNode0, err := storage.GetByKey(ctx, initialNodes[0].Key) + require.NoError(t, err) + require.NotNil(t, updatedNode0) + require.Equal(t, api.HexBytes("new-value-0"), updatedNode0.Value) + + updatedNode2, err := storage.GetByKey(ctx, initialNodes[2].Key) + require.NoError(t, err) + require.NotNil(t, updatedNode2) + require.Equal(t, api.HexBytes("new-value-2"), updatedNode2.Value) + + // 5. Verify the non-updated node is unchanged + unaffectedNode1, err := storage.GetByKey(ctx, initialNodes[1].Key) + require.NoError(t, err) + require.NotNil(t, unaffectedNode1) + require.Equal(t, initialNodes[1].Value, unaffectedNode1.Value) + }) + + t.Run("should handle a mix of new and existing nodes", func(t *testing.T) { + // 1. Store an initial batch + initialNodes := createTestSmtNodes(5) + err := storage.StoreBatch(ctx, initialNodes) + require.NoError(t, err) + countBefore, err := storage.Count(ctx) + require.NoError(t, err) + require.Equal(t, int64(5), countBefore) + + // 2. Create a mixed batch with truly unique new keys + newNode1 := models.NewSmtNode(api.HexBytes("new_key_1"), []byte("new-value-1")) + newNode2 := models.NewSmtNode(api.HexBytes("new_key_2"), []byte("new-value-2")) + + mixedBatch := []*models.SmtNode{ + // Update existing nodes + models.NewSmtNode(initialNodes[0].Key, []byte("updated-value-0")), + models.NewSmtNode(initialNodes[4].Key, []byte("updated-value-4")), + // Insert new nodes + newNode1, + newNode2, + } + + // 3. Call UpsertBatch + err = storage.UpsertBatch(ctx, mixedBatch) + require.NoError(t, err, "UpsertBatch should handle mixed operations") + + // 4. Verify counts + countAfter, err := storage.Count(ctx) + require.NoError(t, err) + require.Equal(t, int64(7), countAfter, "Count should be 7 (5 initial + 2 new)") + + // 5. Verify updated nodes + updatedNode, err := storage.GetByKey(ctx, initialNodes[0].Key) + require.NoError(t, err) + require.Equal(t, api.HexBytes("updated-value-0"), updatedNode.Value) + + // 6. 
Verify inserted nodes + insertedNode, err := storage.GetByKey(ctx, newNode1.Key) + require.NoError(t, err) + require.NotNil(t, insertedNode) + require.Equal(t, newNode1.Value, insertedNode.Value) + }) +} + +func TestSmtStorage_GetAll(t *testing.T) { + t.Run("should return all nodes from a populated collection", func(t *testing.T) { + db := setupSmtTestDB(t) + storage := NewSmtStorage(db) + ctx := context.Background() + err := storage.CreateIndexes(ctx) + require.NoError(t, err) + + nodesToStore := createTestSmtNodes(15) + err = storage.StoreBatch(ctx, nodesToStore) + require.NoError(t, err) + + allNodes, err := storage.GetAll(ctx) + require.NoError(t, err) + require.Len(t, allNodes, 15) + }) + + t.Run("should return an empty slice for an empty collection", func(t *testing.T) { + db := setupSmtTestDB(t) + storage := NewSmtStorage(db) + ctx := context.Background() + err := storage.CreateIndexes(ctx) + require.NoError(t, err) + + allNodes, err := storage.GetAll(ctx) + require.NoError(t, err) + require.Len(t, allNodes, 0) + }) } diff --git a/internal/testutil/storage.go b/internal/testutil/storage.go index 3b4195a..dd356ce 100644 --- a/internal/testutil/storage.go +++ b/internal/testutil/storage.go @@ -14,13 +14,13 @@ import ( ) // SetupTestStorage creates a complete storage instance with MongoDB using testcontainers -func SetupTestStorage(t *testing.T, conf config.Config) (*mongodb.Storage, func()) { +func SetupTestStorage(t *testing.T, conf config.Config) *mongodb.Storage { ctx := context.Background() container, err := mongoContainer.Run(ctx, "mongo:7.0") if err != nil { t.Skipf("Skipping MongoDB tests - cannot start MongoDB container (Docker not available?): %v", err) - return nil, func() {} + return nil } mongoURI, err := container.ConnectionString(ctx) @@ -28,18 +28,24 @@ func SetupTestStorage(t *testing.T, conf config.Config) (*mongodb.Storage, func( t.Fatalf("Failed to get MongoDB connection string: %v", err) } - connectCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + conf.Database.URI = mongoURI + if conf.Database.ConnectTimeout == 0 { + conf.Database.ConnectTimeout = 5 * time.Second + } + if conf.Database.Database == "" { + conf.Database.Database = t.Name() + } + + connectCtx, cancel := context.WithTimeout(ctx, conf.Database.ConnectTimeout) defer cancel() client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(mongoURI)) if err != nil { t.Fatalf("Failed to connect to MongoDB: %v", err) } - if err := client.Ping(connectCtx, nil); err != nil { t.Fatalf("Failed to ping MongoDB: %v", err) } - conf.Database.URI = mongoURI storage, err := mongodb.NewStorage(conf) if err != nil { t.Fatalf("Failed to create storage: %v", err) @@ -59,6 +65,7 @@ func SetupTestStorage(t *testing.T, conf config.Config) (*mongodb.Storage, func( t.Logf("Failed to terminate MongoDB container: %v", err) } } + t.Cleanup(cleanup) - return storage, cleanup + return storage } diff --git a/pkg/api/bigint.go b/pkg/api/bigint.go index c0d21a2..5df3bda 100644 --- a/pkg/api/bigint.go +++ b/pkg/api/bigint.go @@ -4,9 +4,6 @@ import ( "encoding/json" "fmt" "math/big" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsontype" ) // BigInt wraps big.Int for JSON serialization @@ -73,34 +70,6 @@ func (b *BigInt) String() string { return b.Int.String() } -// MarshalBSONValue implements bson.ValueMarshaler -func (b *BigInt) MarshalBSONValue() (bsontype.Type, []byte, error) { - if b == nil || b.Int == nil { - return bson.MarshalValue("0") - } - return bson.MarshalValue(b.Int.String()) -} - 
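
Note the pattern used by the updated test helpers above (`setupSmtTestDB` and `SetupTestStorage`): instead of returning a cleanup function for callers to `defer`, each helper registers container teardown with `t.Cleanup` and returns only the resource. The following is a minimal, illustrative sketch of that pattern only; `Resource`, `startResource`, and `setupResource` are hypothetical names, not part of this change.

```go
package testutil

import "testing"

// Resource stands in for any external dependency (e.g. a container-backed DB).
type Resource struct{}

func (r *Resource) Close() error { return nil }

// startResource is a hypothetical constructor used only for illustration.
func startResource() (*Resource, error) { return &Resource{}, nil }

// setupResource shows the t.Cleanup-based helper pattern: the helper registers
// its own teardown and returns only the resource, so callers can drop the
// `defer cleanup()` boilerplate.
func setupResource(t *testing.T) *Resource {
	t.Helper()

	res, err := startResource()
	if err != nil {
		t.Skipf("Skipping test - resource unavailable: %v", err)
		return nil
	}

	// Teardown runs automatically once the test (and its subtests) finish.
	t.Cleanup(func() {
		if err := res.Close(); err != nil {
			t.Logf("Failed to close resource: %v", err)
		}
	})

	return res
}
```
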
-// UnmarshalBSONValue implements bson.ValueUnmarshaler -func (b *BigInt) UnmarshalBSONValue(bsonType bsontype.Type, data []byte) error { - if bsonType == bson.TypeNull { - b.Int = big.NewInt(0) - return nil - } - var s string - err := bson.UnmarshalValue(bsonType, data, &s) - if err != nil { - return err - } - - i, ok := new(big.Int).SetString(s, 10) - if !ok { - return fmt.Errorf("invalid big int string: %s", s) - } - b.Int = i - return nil -} - // BigintEncode matches TypeScript BigintConverter.encode func BigintEncode(value *big.Int) []byte { if value.Sign() == 0 { diff --git a/pkg/api/bigint_test.go b/pkg/api/bigint_test.go deleted file mode 100644 index 410f3d2..0000000 --- a/pkg/api/bigint_test.go +++ /dev/null @@ -1,481 +0,0 @@ -package api - -import ( - "math/big" - "testing" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsontype" -) - -func TestBigInt_MarshalBSONValue(t *testing.T) { - tests := []struct { - name string - bigInt *BigInt - wantErr bool - }{ - { - name: "marshal valid positive bigint", - bigInt: NewBigInt(big.NewInt(12345)), - wantErr: false, - }, - { - name: "marshal zero bigint", - bigInt: NewBigInt(big.NewInt(0)), - wantErr: false, - }, - { - name: "marshal negative bigint", - bigInt: NewBigInt(big.NewInt(-98765)), - wantErr: false, - }, - { - name: "marshal large bigint", - bigInt: mustNewBigIntFromString("123456789012345678901234567890"), - wantErr: false, - }, - { - name: "marshal nil internal bigint", - bigInt: &BigInt{Int: nil}, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bsonType, data, err := tt.bigInt.MarshalBSONValue() - if (err != nil) != tt.wantErr { - t.Errorf("BigInt.MarshalBSONValue() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr { - if bsonType != bson.TypeString { - t.Errorf("Expected TypeString for BigInt, got %v", bsonType) - } - if len(data) == 0 { - t.Error("BigInt.MarshalBSONValue() returned empty data") - } - } - }) - } -} - -func TestBigInt_UnmarshalBSONValue(t *testing.T) { - tests := []struct { - name string - original *BigInt - wantErr bool - }{ - { - name: "unmarshal valid positive bigint", - original: NewBigInt(big.NewInt(12345)), - wantErr: false, - }, - { - name: "unmarshal zero bigint", - original: NewBigInt(big.NewInt(0)), - wantErr: false, - }, - { - name: "unmarshal negative bigint", - original: NewBigInt(big.NewInt(-98765)), - wantErr: false, - }, - { - name: "unmarshal large bigint", - original: mustNewBigIntFromString("123456789012345678901234567890"), - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Marshal the original to get BSON data - bsonType, data, err := tt.original.MarshalBSONValue() - if err != nil { - t.Fatalf("Failed to marshal original: %v", err) - } - - // Unmarshal the BSON data - var unmarshaled BigInt - err = unmarshaled.UnmarshalBSONValue(bsonType, data) - if (err != nil) != tt.wantErr { - t.Errorf("BigInt.UnmarshalBSONValue() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr && tt.original.String() != unmarshaled.String() { - t.Errorf("BigInt.UnmarshalBSONValue() = %s, want %s", unmarshaled.String(), tt.original.String()) - } - }) - } -} - -func TestBigInt_UnmarshalBSONValue_InvalidData(t *testing.T) { - tests := []struct { - name string - bsonType bsontype.Type - data []byte - wantErr bool - }{ - { - name: "invalid string data", - bsonType: bson.TypeString, - data: []byte("not-a-number"), - wantErr: true, - }, - { - name: 
"invalid bson type", - bsonType: bson.TypeInt32, - data: []byte{0x01, 0x00, 0x00, 0x00}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var bi BigInt - err := bi.UnmarshalBSONValue(tt.bsonType, tt.data) - if (err != nil) != tt.wantErr { - t.Errorf("BigInt.UnmarshalBSONValue() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestBigInt_BSONRoundTrip(t *testing.T) { - testCases := []struct { - name string - original *BigInt - }{ - { - name: "positive number", - original: NewBigInt(big.NewInt(12345)), - }, - { - name: "zero", - original: NewBigInt(big.NewInt(0)), - }, - { - name: "negative number", - original: NewBigInt(big.NewInt(-98765)), - }, - { - name: "large number", - original: mustNewBigIntFromString("123456789012345678901234567890"), - }, - { - name: "very large number", - original: mustNewBigIntFromString("999999999999999999999999999999999999999999999999999999999999999999"), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test struct-based round-trip - type TestStruct struct { - Value *BigInt `bson:"value"` - } - - original := TestStruct{Value: tc.original} - - // Marshal to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Marshal() failed: %v", err) - } - - // Unmarshal from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal() failed: %v", err) - } - - // Compare string representations - if tc.original.String() != unmarshaled.Value.String() { - t.Errorf("Round-trip failed: original %s, unmarshaled %s", - tc.original.String(), unmarshaled.Value.String()) - } - - // Compare using Cmp if both have valid Int values - if tc.original.Int != nil && unmarshaled.Value.Int != nil { - if tc.original.Int.Cmp(unmarshaled.Value.Int) != 0 { - t.Errorf("Round-trip comparison failed: original %s, unmarshaled %s", - tc.original.String(), unmarshaled.Value.String()) - } - } - }) - } -} - -func TestBigInt_BSONWithStruct(t *testing.T) { - // Test marshaling/unmarshaling as part of a struct - type TestStruct struct { - Name string `bson:"name"` - Amount *BigInt `bson:"amount"` - Count *BigInt `bson:"count"` - } - - original := TestStruct{ - Name: "test", - Amount: NewBigInt(big.NewInt(12345)), - Count: mustNewBigIntFromString("999999999999999999999999999999"), - } - - // Marshal struct to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal struct: %v", err) - } - - // Unmarshal struct from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal struct: %v", err) - } - - // Verify data integrity - if original.Name != unmarshaled.Name { - t.Errorf("Name mismatch: got %s, want %s", unmarshaled.Name, original.Name) - } - - if original.Amount.String() != unmarshaled.Amount.String() { - t.Errorf("Amount mismatch: got %s, want %s", - unmarshaled.Amount.String(), original.Amount.String()) - } - - if original.Count.String() != unmarshaled.Count.String() { - t.Errorf("Count mismatch: got %s, want %s", - unmarshaled.Count.String(), original.Count.String()) - } -} - -func TestBigInt_BSONNilHandling(t *testing.T) { - // Test handling of nil BigInt in struct - type TestStruct struct { - Name string `bson:"name"` - Amount *BigInt `bson:"amount,omitempty"` - } - - original := TestStruct{ - Name: "test", - Amount: nil, - } - - // Marshal struct to BSON - bsonData, err := bson.Marshal(original) - if err 
!= nil { - t.Fatalf("Failed to marshal struct with nil BigInt: %v", err) - } - - // Unmarshal struct from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal struct with nil BigInt: %v", err) - } - - // Verify nil is preserved - if unmarshaled.Amount != nil { - t.Error("Expected nil BigInt to remain nil after round-trip") - } -} - -func TestBigInt_BSONNilInternalInt(t *testing.T) { - // Test BigInt with nil internal Int - bigInt := &BigInt{Int: nil} - - // Test struct-based marshaling - type TestStruct struct { - Value *BigInt `bson:"value"` - } - - original := TestStruct{Value: bigInt} - - // Marshal to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Marshal() failed for nil internal Int: %v", err) - } - - // Unmarshal from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal() failed: %v", err) - } - - // Should unmarshal to "0" - if unmarshaled.Value.String() != "0" { - t.Errorf("Expected '0' for nil internal Int, got %s", unmarshaled.Value.String()) - } -} - -func TestBigInt_BSONCompatibility(t *testing.T) { - // Test that BigInt works correctly when embedded in complex structures - type ComplexStruct struct { - ID string `bson:"_id"` - Numbers []*BigInt `bson:"numbers"` - Total *BigInt `bson:"total"` - Nil *BigInt `bson:"nil,omitempty"` - } - - original := ComplexStruct{ - ID: "test-id", - Numbers: []*BigInt{ - NewBigInt(big.NewInt(100)), - NewBigInt(big.NewInt(200)), - mustNewBigIntFromString("999999999999999999999999999999"), - }, - Total: NewBigInt(big.NewInt(300)), - Nil: nil, - } - - // Marshal to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal complex struct: %v", err) - } - - // Unmarshal from BSON - var unmarshaled ComplexStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal complex struct: %v", err) - } - - // Verify all fields - if original.ID != unmarshaled.ID { - t.Errorf("ID mismatch: got %s, want %s", unmarshaled.ID, original.ID) - } - - if len(original.Numbers) != len(unmarshaled.Numbers) { - t.Errorf("Numbers slice length mismatch: got %d, want %d", - len(unmarshaled.Numbers), len(original.Numbers)) - } - - for i, num := range original.Numbers { - if num.String() != unmarshaled.Numbers[i].String() { - t.Errorf("Numbers[%d] mismatch: got %s, want %s", - i, unmarshaled.Numbers[i].String(), num.String()) - } - } - - if original.Total.String() != unmarshaled.Total.String() { - t.Errorf("Total mismatch: got %s, want %s", - unmarshaled.Total.String(), original.Total.String()) - } - - if unmarshaled.Nil != nil { - t.Error("Expected Nil to remain nil") - } -} - -func TestBigInt_BSONComprehensive(t *testing.T) { - // Test comprehensive BigInt BSON marshaling/unmarshaling scenarios - testCases := []struct { - name string - value *BigInt - description string - }{ - { - name: "small positive", - value: NewBigInt(big.NewInt(42)), - description: "Small positive integer", - }, - { - name: "large positive", - value: mustNewBigIntFromString("12345678901234567890123456789012345678901234567890"), - description: "Large positive integer beyond int64 range", - }, - { - name: "negative", - value: NewBigInt(big.NewInt(-999999999999999999)), - description: "Negative integer", - }, - { - name: "zero", - value: NewBigInt(big.NewInt(0)), - description: "Zero value", - }, - { - name: "nil internal int", - value: &BigInt{Int: 
nil}, - description: "BigInt with nil internal Int pointer", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test direct value marshaling - bsonType, data, err := tc.value.MarshalBSONValue() - if err != nil { - t.Fatalf("MarshalBSONValue failed: %v", err) - } - - // Verify the BSON type is string - if bsonType != bson.TypeString { - t.Errorf("Expected BSON type string, got %v", bsonType) - } - - // Test direct value unmarshaling - var unmarshaled BigInt - err = unmarshaled.UnmarshalBSONValue(bsonType, data) - if err != nil { - t.Fatalf("UnmarshalBSONValue failed: %v", err) - } - - // Verify values match - if tc.value.String() != unmarshaled.String() { - t.Errorf("Direct marshaling mismatch: expected %s, got %s", - tc.value.String(), unmarshaled.String()) - } - - // Test struct-based marshaling - type TestDoc struct { - ID string `bson:"_id"` - Value *BigInt `bson:"value"` - } - - doc := TestDoc{ - ID: tc.name, - Value: tc.value, - } - - bsonData, err := bson.Marshal(doc) - if err != nil { - t.Fatalf("Struct marshal failed: %v", err) - } - - var unmarshaledDoc TestDoc - err = bson.Unmarshal(bsonData, &unmarshaledDoc) - if err != nil { - t.Fatalf("Struct unmarshal failed: %v", err) - } - - // Verify struct marshaling preserves values - if doc.ID != unmarshaledDoc.ID { - t.Errorf("ID mismatch: expected %s, got %s", doc.ID, unmarshaledDoc.ID) - } - - if doc.Value.String() != unmarshaledDoc.Value.String() { - t.Errorf("Struct marshaling mismatch: expected %s, got %s", - doc.Value.String(), unmarshaledDoc.Value.String()) - } - - t.Logf("✓ %s: %s -> BSON -> %s", tc.description, tc.value.String(), unmarshaledDoc.Value.String()) - }) - } -} - -// Helper function for test cases -func mustNewBigIntFromString(s string) *BigInt { - bi, err := NewBigIntFromString(s) - if err != nil { - panic(err) - } - return bi -} diff --git a/pkg/api/merkle_tree_path_verify_test.go b/pkg/api/merkle_tree_path_verify_test.go index d1a67bf..b6fbe67 100644 --- a/pkg/api/merkle_tree_path_verify_test.go +++ b/pkg/api/merkle_tree_path_verify_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/unicitynetwork/aggregator-go/internal/smt" "github.com/unicitynetwork/aggregator-go/pkg/api" ) @@ -28,7 +29,8 @@ func TestMerkleTreePathVerify(t *testing.T) { err := tree.AddLeaves([]*smt.Leaf{leaf}) require.NoError(t, err) - path := tree.GetPath(big.NewInt(42)) + path, err := tree.GetPath(big.NewInt(42)) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(big.NewInt(42)) @@ -49,7 +51,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // Verify both paths for _, leafPath := range []int64{10, 12} { - path := tree.GetPath(big.NewInt(leafPath)) + path, err := tree.GetPath(big.NewInt(leafPath)) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(big.NewInt(leafPath)) @@ -75,7 +78,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // Verify each path for _, p := range paths { - path := tree.GetPath(big.NewInt(0x1000000000000 + p)) + path, err := tree.GetPath(big.NewInt(0x1000000000000 + p)) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(big.NewInt(0x1000000000000 + p)) @@ -101,7 +105,8 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) // Verify transfer path - path := tree.GetPath(transferPath) + path, err := tree.GetPath(transferPath) + require.NoError(t, err) require.NotNil(t, path) result, err := path.Verify(transferPath) @@ -110,7 +115,8 @@ func 
TestMerkleTreePathVerify(t *testing.T) { require.True(t, result.PathValid, "Transfer path should be valid") // Verify mint path - pathMint := tree.GetPath(mintPath) + pathMint, err := tree.GetPath(mintPath) + require.NoError(t, err) require.NotNil(t, pathMint) resultMint, err := pathMint.Verify(mintPath) @@ -132,7 +138,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // a valid path showing where that leaf would be inserted. Since leaf 1000 goes // left (bit 0 = 0) and 999 would go right (bit 0 = 1), we get a path to the // empty right branch with the left subtree as sibling. - path := tree.GetPath(big.NewInt(999)) + path, err := tree.GetPath(big.NewInt(999)) + require.NoError(t, err) require.NotNil(t, path) // When we verify this path with requestId 999: @@ -155,7 +162,8 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) // Get path for 5 - path5 := tree.GetPath(big.NewInt(0x1000 + 5)) + path5, err := tree.GetPath(big.NewInt(0x1000 + 5)) + require.NoError(t, err) // Try to verify with wrong requestId result, err := path5.Verify(big.NewInt(0x1000 + 15)) @@ -192,7 +200,8 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) for _, p := range tc.paths { - path := tree.GetPath(big.NewInt(p)) + path, err := tree.GetPath(big.NewInt(p)) + require.NoError(t, err) result, err := path.Verify(big.NewInt(p)) require.NoError(t, err) require.True(t, result.PathIncluded && result.PathValid, @@ -229,13 +238,15 @@ func TestMerkleTreePathVerify(t *testing.T) { require.NoError(t, err) // Verify paths - treePath1 := tree.GetPath(path1) + treePath1, err := tree.GetPath(path1) + require.NoError(t, err) result1, err := treePath1.Verify(path1) require.NoError(t, err) require.True(t, result1.PathIncluded && result1.PathValid, "RequestID1 path should be valid") - treePath2 := tree.GetPath(path2) + treePath2, err := tree.GetPath(path2) + require.NoError(t, err) result2, err := treePath2.Verify(path2) require.NoError(t, err) require.True(t, result2.PathIncluded && result2.PathValid, @@ -253,7 +264,8 @@ func TestMerkleTreePathVerify(t *testing.T) { // Verify all previously added leaves still work for j := int64(1); j <= i; j++ { - path := tree.GetPath(big.NewInt(0x100000 + j*100)) + path, err := tree.GetPath(big.NewInt(0x100000 + j*100)) + require.NoError(t, err) result, err := path.Verify(big.NewInt(0x100000 + j*100)) require.NoError(t, err) require.True(t, result.PathIncluded && result.PathValid, @@ -317,7 +329,8 @@ func TestMerkleTreePathVerifyDuplicates(t *testing.T) { require.Error(t, err) // Verify the original value is still there - path := tree.GetPath(big.NewInt(100)) + path, err := tree.GetPath(big.NewInt(100)) + require.NoError(t, err) result, err := path.Verify(big.NewInt(100)) require.NoError(t, err) require.True(t, result.PathIncluded && result.PathValid, @@ -343,7 +356,8 @@ func TestMerkleTreePathVerifyAlternateAlgorithm(t *testing.T) { require.Equal(t, root[:4], fmt.Sprintf("%04x", algo)) for _, leaf := range leaves { - path := tree.GetPath(leaf.Path) + path, err := tree.GetPath(leaf.Path) + require.NoError(t, err) require.Equal(t, root, path.Root) res, err := path.Verify(leaf.Path) require.NoError(t, err) diff --git a/pkg/api/request_id.go b/pkg/api/request_id.go index 631da4b..38dcc7f 100644 --- a/pkg/api/request_id.go +++ b/pkg/api/request_id.go @@ -6,9 +6,6 @@ import ( "encoding/json" "fmt" "math/big" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsontype" ) type RequestID = ImprintHexString @@ -121,24 +118,3 @@ func 
ValidateRequestID(requestID RequestID, publicKey []byte, stateHashBytes []b return requestID == expectedRequestID, nil } - -// MarshalBSONValue implements bson.ValueMarshaler for ImprintHexString -func (r ImprintHexString) MarshalBSONValue() (bsontype.Type, []byte, error) { - return bson.MarshalValue(string(r)) -} - -// UnmarshalBSONValue implements bson.ValueUnmarshaler for ImprintHexString -func (r *ImprintHexString) UnmarshalBSONValue(bsonType bsontype.Type, data []byte) error { - var s string - err := bson.UnmarshalValue(bsonType, data, &s) - if err != nil { - return err - } - - id, err := NewImprintHexString(s) - if err != nil { - return err - } - *r = id - return nil -} diff --git a/pkg/api/request_id_test.go b/pkg/api/request_id_test.go index 1f63ac0..9585c25 100644 --- a/pkg/api/request_id_test.go +++ b/pkg/api/request_id_test.go @@ -8,16 +8,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsontype" ) func TestRequestID_CreateAndSerialize(t *testing.T) { - // Test data that should produce the exact same result as TypeScript - // From RequestIdTest.ts: - // RequestId.create(new Uint8Array(20), DataHash.fromImprint(new Uint8Array(34))) - // Expected JSON: '0000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40' - // Expected CBOR: '58220000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40' t.Run("should create RequestID with exact TypeScript compatibility", func(t *testing.T) { // Create 20-byte public key (all zeros) @@ -86,520 +79,3 @@ func TestRequestID_CreateAndSerialize(t *testing.T) { assert.Equal(t, requestIDStr, hexStr) }) } - -func TestImprintHexString_MarshalBSONValue(t *testing.T) { - tests := []struct { - name string - hexString ImprintHexString - wantErr bool - }{ - { - name: "marshal valid request id", - hexString: ImprintHexString("0000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40"), - wantErr: false, - }, - { - name: "marshal short hex string", - hexString: ImprintHexString("0000abcd"), - wantErr: false, - }, - { - name: "marshal long hex string", - hexString: ImprintHexString("0000" + "abcdef1234567890" + "abcdef1234567890" + "abcdef1234567890" + "abcdef1234567890"), - wantErr: false, - }, - { - name: "marshal empty string", - hexString: ImprintHexString(""), - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bsonType, data, err := tt.hexString.MarshalBSONValue() - if (err != nil) != tt.wantErr { - t.Errorf("ImprintHexString.MarshalBSONValue() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr { - if bsonType != bson.TypeString { - t.Errorf("Expected TypeString for ImprintHexString, got %v", bsonType) - } - if len(data) == 0 { - t.Error("ImprintHexString.MarshalBSONValue() returned empty data") - } - } - }) - } -} - -func TestImprintHexString_UnmarshalBSONValue(t *testing.T) { - tests := []struct { - name string - original ImprintHexString - wantErr bool - }{ - { - name: "unmarshal valid request id", - original: ImprintHexString("0000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40"), - wantErr: false, - }, - { - name: "unmarshal short hex string", - original: ImprintHexString("0000abcd"), - wantErr: false, - }, - { - name: "unmarshal with algorithm prefix", - original: ImprintHexString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12"), - wantErr: false, - }, - } - - for _, tt := range tests { - 
t.Run(tt.name, func(t *testing.T) { - // Marshal the original to get BSON data - bsonType, data, err := tt.original.MarshalBSONValue() - if err != nil { - t.Fatalf("Failed to marshal original: %v", err) - } - - // Unmarshal the BSON data - var unmarshaled ImprintHexString - err = unmarshaled.UnmarshalBSONValue(bsonType, data) - if (err != nil) != tt.wantErr { - t.Errorf("ImprintHexString.UnmarshalBSONValue() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr && string(tt.original) != string(unmarshaled) { - t.Errorf("ImprintHexString.UnmarshalBSONValue() = %s, want %s", string(unmarshaled), string(tt.original)) - } - }) - } -} - -func TestImprintHexString_UnmarshalBSONValue_InvalidData(t *testing.T) { - tests := []struct { - name string - bsonType bsontype.Type - data []byte - wantErr bool - }{ - { - name: "invalid hex string", - bsonType: bson.TypeString, - data: []byte("xyz"), - wantErr: true, - }, - { - name: "too short string", - bsonType: bson.TypeString, - data: []byte("ab"), - wantErr: true, - }, - { - name: "invalid bson type", - bsonType: bson.TypeInt32, - data: []byte{0x01, 0x00, 0x00, 0x00}, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var h ImprintHexString - err := h.UnmarshalBSONValue(tt.bsonType, tt.data) - if (err != nil) != tt.wantErr { - t.Errorf("ImprintHexString.UnmarshalBSONValue() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestImprintHexString_BSONRoundTrip(t *testing.T) { - testCases := []struct { - name string - original ImprintHexString - }{ - { - name: "typical request id", - original: ImprintHexString("0000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40"), - }, - { - name: "algorithm prefix 1234", - original: ImprintHexString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12"), - }, - { - name: "algorithm prefix ffff", - original: ImprintHexString("ffff123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde00"), - }, - { - name: "short hex string", - original: ImprintHexString("0000abcd"), - }, - { - name: "long hex string", - original: ImprintHexString("0000abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890"), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test struct-based round-trip - type TestStruct struct { - ID string `bson:"_id"` - Value ImprintHexString `bson:"value"` - } - - original := TestStruct{ - ID: tc.name, - Value: tc.original, - } - - // Marshal to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Marshal() failed: %v", err) - } - - // Unmarshal from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal() failed: %v", err) - } - - // Compare string representations - if string(tc.original) != string(unmarshaled.Value) { - t.Errorf("Round-trip failed: original %s, unmarshaled %s", - string(tc.original), string(unmarshaled.Value)) - } - - // Test that we can still call methods on the unmarshaled value - originalBytes, err := tc.original.Imprint() - if err == nil { // Only test if original is valid - unmarshaledBytes, err := unmarshaled.Value.Imprint() - if err != nil { - t.Errorf("Unmarshaled value failed Imprint(): %v", err) - } else if !assert.Equal(t, originalBytes, unmarshaledBytes) { - t.Errorf("Imprint() bytes don't match after round-trip") - } - } - }) - } -} - -func TestImprintHexString_BSONWithStruct(t *testing.T) { - // Test 
marshaling/unmarshaling as part of a struct - type TestStruct struct { - RequestID RequestID `bson:"request_id"` - StateHash ImprintHexString `bson:"state_hash"` - Name string `bson:"name"` - } - - original := TestStruct{ - RequestID: RequestID("0000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40"), - StateHash: ImprintHexString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12"), - Name: "test-struct", - } - - // Marshal struct to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal struct: %v", err) - } - - // Unmarshal struct from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal struct: %v", err) - } - - // Verify data integrity - if original.Name != unmarshaled.Name { - t.Errorf("Name mismatch: got %s, want %s", unmarshaled.Name, original.Name) - } - - if string(original.RequestID) != string(unmarshaled.RequestID) { - t.Errorf("RequestID mismatch: got %s, want %s", - string(unmarshaled.RequestID), string(original.RequestID)) - } - - if string(original.StateHash) != string(unmarshaled.StateHash) { - t.Errorf("StateHash mismatch: got %s, want %s", - string(unmarshaled.StateHash), string(original.StateHash)) - } -} - -func TestImprintHexString_BSONNilHandling(t *testing.T) { - // Test handling of nil ImprintHexString pointer in struct - type TestStruct struct { - Name string `bson:"name"` - RequestID *ImprintHexString `bson:"request_id,omitempty"` - } - - original := TestStruct{ - Name: "test", - RequestID: nil, - } - - // Marshal struct to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal struct with nil ImprintHexString: %v", err) - } - - // Unmarshal struct from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal struct with nil ImprintHexString: %v", err) - } - - // Verify nil is preserved - if unmarshaled.RequestID != nil { - t.Error("Expected nil ImprintHexString to remain nil after round-trip") - } -} - -func TestImprintHexString_BSONCompatibility(t *testing.T) { - // Test that ImprintHexString works correctly when embedded in complex structures - type ComplexStruct struct { - ID string `bson:"_id"` - RequestIDs []ImprintHexString `bson:"request_ids"` - MainID ImprintHexString `bson:"main_id"` - OptionalID *ImprintHexString `bson:"optional_id,omitempty"` - Metadata map[string]string `bson:"metadata"` - } - - optionalID := ImprintHexString("ffff123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde00") - original := ComplexStruct{ - ID: "test-complex", - RequestIDs: []ImprintHexString{ - ImprintHexString("0000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40"), - ImprintHexString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12"), - ImprintHexString("ffff000000000000000000000000000000000000000000000000000000000000"), - }, - MainID: ImprintHexString("0000abc000000000000000000000000000000000000000000000000000000000"), - OptionalID: &optionalID, - Metadata: map[string]string{ - "version": "1.0", - "type": "test", - }, - } - - // Marshal to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal complex struct: %v", err) - } - - // Unmarshal from BSON - var unmarshaled ComplexStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal complex struct: %v", err) - } - - // Verify all 
fields - if original.ID != unmarshaled.ID { - t.Errorf("ID mismatch: got %s, want %s", unmarshaled.ID, original.ID) - } - - if len(original.RequestIDs) != len(unmarshaled.RequestIDs) { - t.Errorf("RequestIDs slice length mismatch: got %d, want %d", - len(unmarshaled.RequestIDs), len(original.RequestIDs)) - } - - for i, id := range original.RequestIDs { - if string(id) != string(unmarshaled.RequestIDs[i]) { - t.Errorf("RequestIDs[%d] mismatch: got %s, want %s", - i, string(unmarshaled.RequestIDs[i]), string(id)) - } - } - - if string(original.MainID) != string(unmarshaled.MainID) { - t.Errorf("MainID mismatch: got %s, want %s", - string(unmarshaled.MainID), string(original.MainID)) - } - - if original.OptionalID == nil && unmarshaled.OptionalID != nil { - t.Error("Expected nil OptionalID to remain nil") - } else if original.OptionalID != nil && unmarshaled.OptionalID == nil { - t.Error("Expected non-nil OptionalID to remain non-nil") - } else if original.OptionalID != nil && unmarshaled.OptionalID != nil { - if string(*original.OptionalID) != string(*unmarshaled.OptionalID) { - t.Errorf("OptionalID mismatch: got %s, want %s", - string(*unmarshaled.OptionalID), string(*original.OptionalID)) - } - } - - // Verify metadata - for key, value := range original.Metadata { - if unmarshaled.Metadata[key] != value { - t.Errorf("Metadata[%s] mismatch: got %s, want %s", - key, unmarshaled.Metadata[key], value) - } - } -} - -func TestRequestID_BSONFunctionality(t *testing.T) { - // Test that RequestID (which is an alias for ImprintHexString) works with BSON - t.Run("RequestID as alias works with BSON", func(t *testing.T) { - // Create a RequestID using the standard creation function - publicKey := make([]byte, 20) - stateHashBytes := make([]byte, 34) - stateHash, err := NewImprintHexString(fmt.Sprintf("%x", stateHashBytes)) - require.NoError(t, err) - - requestID, err := CreateRequestID(publicKey, stateHash) - require.NoError(t, err) - - // Test struct-based BSON marshaling - type TestDoc struct { - ID string `bson:"_id"` - RequestID RequestID `bson:"request_id"` - } - - doc := TestDoc{ - ID: "test-doc", - RequestID: requestID, - } - - // Marshal to BSON - bsonData, err := bson.Marshal(doc) - require.NoError(t, err) - - // Unmarshal from BSON - var unmarshaled TestDoc - err = bson.Unmarshal(bsonData, &unmarshaled) - require.NoError(t, err) - - // Verify RequestID is preserved - assert.Equal(t, string(requestID), string(unmarshaled.RequestID)) - assert.Equal(t, doc.ID, unmarshaled.ID) - - // Verify the RequestID still functions correctly - originalBytes, err := requestID.Imprint() - require.NoError(t, err) - - unmarshaledBytes, err := unmarshaled.RequestID.Imprint() - require.NoError(t, err) - - assert.Equal(t, originalBytes, unmarshaledBytes) - }) -} - -func TestImprintHexString_BSONComprehensive(t *testing.T) { - // Test comprehensive ImprintHexString BSON marshaling/unmarshaling scenarios - testCases := []struct { - name string - value ImprintHexString - description string - }{ - { - name: "standard request id", - value: ImprintHexString("0000ea659cdc838619b3767c057fdf8e6d99fde2680c5d8517eb06761c0878d40c40"), - description: "Standard RequestID format with 0000 algorithm prefix", - }, - { - name: "different algorithm", - value: ImprintHexString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12"), - description: "Different algorithm prefix (1234)", - }, - { - name: "max algorithm", - value: ImprintHexString("ffff123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde00"), - 
description: "Maximum algorithm prefix (ffff)", - }, - { - name: "minimal valid", - value: ImprintHexString("0000abcd"), - description: "Minimal valid hex string (3+ bytes)", - }, - { - name: "extended length", - value: ImprintHexString("0000abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890"), - description: "Extended length hex string", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test direct value marshaling - bsonType, data, err := tc.value.MarshalBSONValue() - if err != nil { - t.Fatalf("MarshalBSONValue failed: %v", err) - } - - // Verify the BSON type is string - if bsonType != bson.TypeString { - t.Errorf("Expected BSON type string, got %v", bsonType) - } - - // Test direct value unmarshaling - var unmarshaled ImprintHexString - err = unmarshaled.UnmarshalBSONValue(bsonType, data) - if err != nil { - t.Fatalf("UnmarshalBSONValue failed: %v", err) - } - - // Verify values match - if string(tc.value) != string(unmarshaled) { - t.Errorf("Direct marshaling mismatch: expected %s, got %s", - string(tc.value), string(unmarshaled)) - } - - // Test struct-based marshaling - type TestDoc struct { - ID string `bson:"_id"` - Value ImprintHexString `bson:"value"` - } - - doc := TestDoc{ - ID: tc.name, - Value: tc.value, - } - - bsonData, err := bson.Marshal(doc) - if err != nil { - t.Fatalf("Struct marshal failed: %v", err) - } - - var unmarshaledDoc TestDoc - err = bson.Unmarshal(bsonData, &unmarshaledDoc) - if err != nil { - t.Fatalf("Struct unmarshal failed: %v", err) - } - - // Verify struct marshaling preserves values - if doc.ID != unmarshaledDoc.ID { - t.Errorf("ID mismatch: expected %s, got %s", doc.ID, unmarshaledDoc.ID) - } - - if string(doc.Value) != string(unmarshaledDoc.Value) { - t.Errorf("Struct marshaling mismatch: expected %s, got %s", - string(doc.Value), string(unmarshaledDoc.Value)) - } - - // Test that methods still work on unmarshaled value - if len(string(tc.value)) >= 4 { // Only test if long enough for algorithm - originalBytes, err := tc.value.Imprint() - if err == nil { // Only test if original is valid - unmarshaledBytes, err := unmarshaledDoc.Value.Imprint() - if err != nil { - t.Errorf("Unmarshaled value failed Imprint(): %v", err) - } else if !assert.Equal(t, originalBytes, unmarshaledBytes) { - t.Errorf("Imprint() bytes don't match after round-trip") - } - } - } - - t.Logf("✓ %s: %s -> BSON -> %s", tc.description, string(tc.value), string(unmarshaledDoc.Value)) - }) - } -} diff --git a/pkg/api/shard_types_test.go b/pkg/api/shard_types_test.go new file mode 100644 index 0000000..402867a --- /dev/null +++ b/pkg/api/shard_types_test.go @@ -0,0 +1,49 @@ +package api + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetShardProofRequest_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + wantShardID ShardID + wantErr bool + }{ + { + name: "valid JSON without 0x prefix", + json: `{"shardId":2}`, + wantShardID: 2, + wantErr: false, + }, + { + name: "valid JSON with 0x prefix", + json: `{"shardId":3}`, + wantShardID: 3, + wantErr: false, + }, + { + name: "invalid hex", + json: `{"shardId":"GGGG"}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req GetShardProofRequest + err := json.Unmarshal([]byte(tt.json), &req) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, 
tt.wantShardID, req.ShardID) + } + }) + } +} diff --git a/pkg/api/smt.go b/pkg/api/smt.go index cb7ace0..5f9d446 100644 --- a/pkg/api/smt.go +++ b/pkg/api/smt.go @@ -42,7 +42,7 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er } // The "running totals" as we go through the hashing steps - currentPath := big.NewInt(1) + var currentPath *big.Int var currentData *[]byte for i, step := range m.Steps { @@ -61,7 +61,7 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er } if i == 0 { - if stepPath.Sign() > 0 { + if stepPath.BitLen() >= 2 { // First step, normal case: data is the value in the leaf, apply the leaf hashing rule hasher.Reset().AddData(CborArray(2)) hasher.AddCborBytes(BigintEncode(stepPath)) @@ -73,8 +73,10 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er currentData = &hasher.GetHash().RawHash } else { // First step, special case: data is the "our branch" hash value for the next step + // Note that in this case stepPath is a "naked" direction bit currentData = stepData } + currentPath = stepPath } else { // All subsequent steps: apply the non-leaf hashing rule var left, right *[]byte @@ -101,19 +103,21 @@ func (m *MerkleTreePath) Verify(requestId *big.Int) (*PathVerificationResult, er hasher.AddCborBytes(*right) } currentData = &hasher.GetHash().RawHash - } - if stepPath.Sign() > 0 { + // Initialization for when currentPath is a "naked" direction bit + if currentPath.BitLen() < 2 { + currentPath = big.NewInt(1) + } // Append step path bits to current path pathLen := stepPath.BitLen() - 1 - stepPath.SetBit(stepPath, pathLen, 0) + mask := new(big.Int).SetBit(stepPath, pathLen, 0) currentPath.Lsh(currentPath, uint(pathLen)) - currentPath.Or(currentPath, stepPath) + currentPath.Or(currentPath, mask) } } pathValid := currentData != nil && m.Root == NewDataHash(hasher.algorithm, *currentData).ToHex() - pathIncluded := requestId.Cmp(currentPath) == 0 + pathIncluded := currentPath != nil && requestId.Cmp(currentPath) == 0 return &PathVerificationResult{ PathValid: pathValid, diff --git a/pkg/api/types.go b/pkg/api/types.go index a1e20c4..78382dd 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -7,14 +7,12 @@ import ( "fmt" "strconv" "time" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsontype" ) // Basic types for API type StateHash = ImprintHexString type TransactionHash = ImprintHexString +type ShardID = int // Authenticator represents the authentication data for a commitment type Authenticator struct { @@ -94,33 +92,6 @@ func (t *Timestamp) UnmarshalJSON(data []byte) error { return nil } -// MarshalBSONValue implements bson.ValueMarshaler -func (t *Timestamp) MarshalBSONValue() (bsontype.Type, []byte, error) { - if t == nil { - return bson.TypeNull, nil, nil - } - // Use Unix timestamp in milliseconds as int64 for BSON - millis := t.Time.UnixMilli() - return bson.MarshalValue(millis) -} - -// UnmarshalBSONValue implements bson.ValueUnmarshaler -func (t *Timestamp) UnmarshalBSONValue(bsonType bsontype.Type, data []byte) error { - if bsonType == bson.TypeNull { - t.Time = time.Time{} - return nil - } - - var millis int64 - err := bson.UnmarshalValue(bsonType, data, &millis) - if err != nil { - return err - } - - t.Time = time.UnixMilli(millis) - return nil -} - // AggregatorRecord represents a finalized commitment with proof data type AggregatorRecord struct { RequestID RequestID `json:"requestId"` @@ -149,28 +120,32 @@ func 
NewAggregatorRecord(commitment *Commitment, blockNumber, leafIndex *BigInt) // Block represents a blockchain block type Block struct { - Index *BigInt `json:"index"` - ChainID string `json:"chainId"` - Version string `json:"version"` - ForkID string `json:"forkId"` - RootHash HexBytes `json:"rootHash"` - PreviousBlockHash HexBytes `json:"previousBlockHash"` - NoDeletionProofHash HexBytes `json:"noDeletionProofHash"` - CreatedAt *Timestamp `json:"createdAt"` - UnicityCertificate HexBytes `json:"unicityCertificate"` + Index *BigInt `json:"index"` + ChainID string `json:"chainId"` + ShardID ShardID `json:"shardId"` + Version string `json:"version"` + ForkID string `json:"forkId"` + RootHash HexBytes `json:"rootHash"` + PreviousBlockHash HexBytes `json:"previousBlockHash"` + NoDeletionProofHash HexBytes `json:"noDeletionProofHash"` + CreatedAt *Timestamp `json:"createdAt"` + UnicityCertificate HexBytes `json:"unicityCertificate"` + ParentMerkleTreePath *MerkleTreePath `json:"parentMerkleTreePath,omitempty"` // child mode only } // NewBlock creates a new block -func NewBlock(index *BigInt, chainID, version, forkID string, rootHash, previousBlockHash, unicityCertificate HexBytes) *Block { +func NewBlock(index *BigInt, chainID string, shardID ShardID, version, forkID string, rootHash, previousBlockHash, uc HexBytes, parentMerkleTreePath *MerkleTreePath) *Block { return &Block{ - Index: index, - ChainID: chainID, - Version: version, - ForkID: forkID, - RootHash: rootHash, - PreviousBlockHash: previousBlockHash, - CreatedAt: Now(), - UnicityCertificate: unicityCertificate, + Index: index, + ChainID: chainID, + ShardID: shardID, + Version: version, + ForkID: forkID, + RootHash: rootHash, + PreviousBlockHash: previousBlockHash, + CreatedAt: Now(), + UnicityCertificate: uc, + ParentMerkleTreePath: parentMerkleTreePath, } } @@ -255,6 +230,11 @@ type InclusionProof struct { UnicityCertificate HexBytes `json:"unicityCertificate"` } +type RootShardInclusionProof struct { + MerkleTreePath *MerkleTreePath `json:"merkleTreePath"` + UnicityCertificate HexBytes `json:"unicityCertificate"` +} + // GetNoDeletionProofResponse represents the get_no_deletion_proof JSON-RPC response type GetNoDeletionProofResponse struct { NoDeletionProof *NoDeletionProof `json:"noDeletionProof"` @@ -286,11 +266,43 @@ type GetBlockCommitmentsResponse struct { Commitments []*AggregatorRecord `json:"commitments"` } +// Status constants for SubmitShardRootResponse +const ( + ShardRootStatusSuccess = "SUCCESS" + ShardRootStatusInvalidShardID = "INVALID_SHARD_ID" + ShardRootStatusInvalidRootHash = "INVALID_ROOT_HASH" + ShardRootStatusInternalError = "INTERNAL_ERROR" + ShardRootStatusNotLeader = "NOT_LEADER" +) + +// SubmitShardRootRequest represents the submit_shard_root JSON-RPC request +type SubmitShardRootRequest struct { + ShardID ShardID `json:"shardId"` + RootHash HexBytes `json:"rootHash"` // Raw root hash from child SMT +} + +// SubmitShardRootResponse represents the submit_shard_root JSON-RPC response +type SubmitShardRootResponse struct { + Status string `json:"status"` // "SUCCESS", "INVALID_SHARD_ID", "INVALID_ROOT_HASH", etc. 
+} + +// GetShardProofRequest represents the get_shard_proof JSON-RPC request +type GetShardProofRequest struct { + ShardID ShardID `json:"shardId"` +} + +// GetShardProofResponse represents the get_shard_proof JSON-RPC response +type GetShardProofResponse struct { + MerkleTreePath *MerkleTreePath `json:"merkleTreePath"` // Proof path for the shard + UnicityCertificate HexBytes `json:"unicityCertificate"` // Unicity Certificate from the finalized block +} + // HealthStatus represents the health status of the service type HealthStatus struct { Status string `json:"status"` Role string `json:"role"` ServerID string `json:"serverId"` + Sharding Sharding `json:"sharding"` Details map[string]string `json:"details,omitempty"` } @@ -311,3 +323,14 @@ func (h *HealthStatus) AddDetail(key, value string) { } h.Details[key] = value } + +func (r *RootShardInclusionProof) IsValid(shardRootHash string) bool { + return r.MerkleTreePath != nil && len(r.UnicityCertificate) > 0 && + len(r.MerkleTreePath.Steps) > 0 && r.MerkleTreePath.Steps[0].Data != nil && *r.MerkleTreePath.Steps[0].Data == shardRootHash +} + +type Sharding struct { + Mode string `json:"mode"` + ShardIDLen int `json:"shardIdLen"` + ShardID int `json:"shardId"` +} diff --git a/pkg/api/types_test.go b/pkg/api/types_test.go index b1dbafc..bce1484 100644 --- a/pkg/api/types_test.go +++ b/pkg/api/types_test.go @@ -6,7 +6,6 @@ import ( "time" "github.com/stretchr/testify/require" - "go.mongodb.org/mongo-driver/bson" ) func TestRequestIDMarshalJSON(t *testing.T) { @@ -65,74 +64,6 @@ func TestSubmitCommitmentRequestJSON(t *testing.T) { require.Equal(t, req.Authenticator.Algorithm, unmarshaledReq.Authenticator.Algorithm, "Algorithm mismatch") } -func TestRequestIDMarshalBSON(t *testing.T) { - requestID := RequestID("0000cfe84a1828e2edd0a7d9533b23e519f746069a938d549a150e07e14dc0f9cf00") - - // Wrap in a struct for BSON compatibility - type wrapper struct { - ID RequestID `bson:"id"` - } - w := wrapper{ID: requestID} - - data, err := bson.Marshal(w) - require.NoError(t, err, "Failed to marshal RequestID to BSON") - - var unmarshaled wrapper - err = bson.Unmarshal(data, &unmarshaled) - require.NoError(t, err, "Failed to unmarshal RequestID from BSON") - - require.Equal(t, requestID, unmarshaled.ID, "RequestID mismatch") -} - -func TestHexBytesMarshalBSON(t *testing.T) { - hexBytes := HexBytes{0x01, 0x02, 0x03, 0x04} - - // Wrap HexBytes in a struct for BSON compatibility - type wrapper struct { - Data HexBytes `bson:"data"` - } - w := wrapper{Data: hexBytes} - - data, err := bson.Marshal(w) - require.NoError(t, err, "Failed to marshal HexBytes to BSON") - - var unmarshaled wrapper - err = bson.Unmarshal(data, &unmarshaled) - require.NoError(t, err, "Failed to unmarshal HexBytes from BSON") - - require.Equal(t, len(hexBytes), len(unmarshaled.Data), "HexBytes length mismatch") - - for i, b := range hexBytes { - require.Equal(t, b, unmarshaled.Data[i], "HexBytes byte mismatch at index %d", i) - } -} - -func TestSubmitCommitmentRequestBSON(t *testing.T) { - req := &SubmitCommitmentRequest{ - RequestID: "0000cfe84a1828e2edd0a7d9533b23e519f746069a938d549a150e07e14dc0f9cf00", - TransactionHash: "00008a51b5b84171e6c7c345bf3610cc18fa1b61bad33908e1522520c001b0e7fd1d", - Authenticator: Authenticator{ - Algorithm: "secp256k1", - PublicKey: HexBytes{0x03, 0x20, 0x44, 0xf2}, - Signature: HexBytes{0x41, 0x67, 0x51, 0xe8}, - StateHash: ImprintHexString("0000cd60"), - }, - } - - data, err := bson.Marshal(req) - require.NoError(t, err, "Failed to marshal 
SubmitCommitmentRequest to BSON") - - var unmarshaledReq SubmitCommitmentRequest - err = bson.Unmarshal(data, &unmarshaledReq) - require.NoError(t, err, "Failed to unmarshal SubmitCommitmentRequest from BSON") - - require.Equal(t, req.RequestID, unmarshaledReq.RequestID, "RequestID mismatch") - - require.Equal(t, req.TransactionHash, unmarshaledReq.TransactionHash, "TransactionHash mismatch") - - require.Equal(t, req.Authenticator.Algorithm, unmarshaledReq.Authenticator.Algorithm, "Algorithm mismatch") -} - func TestImprintHexStringMarshalJSON(t *testing.T) { imprint := ImprintHexString("0000cd60") @@ -146,25 +77,6 @@ func TestImprintHexStringMarshalJSON(t *testing.T) { require.Equal(t, imprint, unmarshaledImprint, "ImprintHexString mismatch") } -func TestImprintHexStringMarshalBSON(t *testing.T) { - imprint := ImprintHexString("0000cd60") - - // Wrap ImprintHexString in a struct for BSON compatibility - type wrapper struct { - Imprint ImprintHexString `bson:"imprint"` - } - w := wrapper{Imprint: imprint} - - data, err := bson.Marshal(w) - require.NoError(t, err, "Failed to marshal ImprintHexString to BSON") - - var unmarshaled wrapper - err = bson.Unmarshal(data, &unmarshaled) - require.NoError(t, err, "Failed to unmarshal ImprintHexString from BSON") - - require.Equal(t, imprint, unmarshaled.Imprint, "ImprintHexString mismatch") -} - func TestTimeNanoMarshalJSON(t *testing.T) { now := time.Now() @@ -177,262 +89,3 @@ func TestTimeNanoMarshalJSON(t *testing.T) { require.True(t, now.Equal(unmarshaledTime) || now.Sub(unmarshaledTime) < time.Millisecond, "time.Time mismatch") } - -func TestTimeNanoMarshalBSON(t *testing.T) { - now := time.Now() - - // Wrap time.Time in a struct for BSON compatibility - type wrapper struct { - Time time.Time `bson:"time"` - } - w := wrapper{Time: now} - - data, err := bson.Marshal(w) - require.NoError(t, err, "Failed to marshal time.Time to BSON") - - var unmarshaled wrapper - err = bson.Unmarshal(data, &unmarshaled) - require.NoError(t, err, "Failed to unmarshal time.Time from BSON") - - require.True(t, now.Equal(unmarshaled.Time) || now.Sub(unmarshaled.Time) < time.Millisecond, "time.Time mismatch") -} - -func TestTimestamp_MarshalBSONValue(t *testing.T) { - tests := []struct { - name string - timestamp *Timestamp - wantErr bool - }{ - { - name: "marshal valid timestamp", - timestamp: NewTimestamp(time.Unix(1640995200, 0)), // 2022-01-01 00:00:00 UTC - wantErr: false, - }, - { - name: "marshal current time", - timestamp: Now(), - wantErr: false, - }, - { - name: "marshal nil timestamp", - timestamp: nil, - wantErr: false, - }, - { - name: "marshal zero timestamp", - timestamp: NewTimestamp(time.Time{}), - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bsonType, data, err := tt.timestamp.MarshalBSONValue() - if (err != nil) != tt.wantErr { - t.Errorf("Timestamp.MarshalBSONValue() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr { - if tt.timestamp == nil { - if bsonType != bson.TypeNull { - t.Errorf("Expected TypeNull for nil timestamp, got %v", bsonType) - } - } else { - if bsonType != bson.TypeInt64 { - t.Errorf("Expected TypeInt64 for valid timestamp, got %v", bsonType) - } - if len(data) == 0 { - t.Error("Timestamp.MarshalBSONValue() returned empty data for valid timestamp") - } - } - } - }) - } -} - -func TestTimestamp_UnmarshalBSONValue(t *testing.T) { - // Test valid timestamp - testTime := time.Unix(1640995200, 0) // 2022-01-01 00:00:00 UTC - expectedMillis := 
testTime.UnixMilli() - - // Marshal the expected value as int64 - bsonType, bsonData, err := bson.MarshalValue(expectedMillis) - if err != nil { - t.Fatalf("Failed to marshal test data: %v", err) - } - - var ts Timestamp - err = ts.UnmarshalBSONValue(bsonType, bsonData) - if err != nil { - t.Errorf("Timestamp.UnmarshalBSONValue() error = %v", err) - return - } - - if ts.UnixMilli() != expectedMillis { - t.Errorf("Timestamp.UnmarshalBSONValue() = %d, want %d", ts.UnixMilli(), expectedMillis) - } - - // Test null value - var nullTs Timestamp - err = nullTs.UnmarshalBSONValue(bson.TypeNull, nil) - if err != nil { - t.Errorf("Timestamp.UnmarshalBSONValue() error for null = %v", err) - } - - if !nullTs.Time.IsZero() { - t.Error("Expected zero time for null BSON value") - } -} - -func TestTimestamp_BSONRoundTrip(t *testing.T) { - // Test round-trip marshaling and unmarshaling - originalTime := time.Unix(1640995200, 123000000) // 2022-01-01 00:00:00.123 UTC - original := NewTimestamp(originalTime) - - // Marshal to BSON value - bsonType, bsonData, err := original.MarshalBSONValue() - if err != nil { - t.Fatalf("MarshalBSONValue() failed: %v", err) - } - - // Unmarshal from BSON value - var unmarshaled Timestamp - err = unmarshaled.UnmarshalBSONValue(bsonType, bsonData) - if err != nil { - t.Fatalf("UnmarshalBSONValue() failed: %v", err) - } - - // Compare milliseconds (BSON stores as milliseconds) - if original.UnixMilli() != unmarshaled.UnixMilli() { - t.Errorf("Round-trip failed: original %d ms, unmarshaled %d ms", - original.UnixMilli(), unmarshaled.UnixMilli()) - } -} - -func TestTimestamp_BSONWithStruct(t *testing.T) { - // Test marshaling/unmarshaling as part of a struct - type TestStruct struct { - Name string `bson:"name"` - CreatedAt *Timestamp `bson:"createdAt"` - } - - original := TestStruct{ - Name: "test", - CreatedAt: NewTimestamp(time.Unix(1640995200, 0)), - } - - // Marshal struct to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal struct: %v", err) - } - - // Unmarshal struct from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal struct: %v", err) - } - - // Verify data integrity - if original.Name != unmarshaled.Name { - t.Errorf("Name mismatch: got %s, want %s", unmarshaled.Name, original.Name) - } - - if original.CreatedAt.UnixMilli() != unmarshaled.CreatedAt.UnixMilli() { - t.Errorf("Timestamp mismatch: got %d, want %d", - unmarshaled.CreatedAt.UnixMilli(), original.CreatedAt.UnixMilli()) - } -} - -func TestTimestamp_BSONNilHandling(t *testing.T) { - // Test handling of nil timestamp in struct - type TestStruct struct { - Name string `bson:"name"` - CreatedAt *Timestamp `bson:"createdAt,omitempty"` - } - - original := TestStruct{ - Name: "test", - CreatedAt: nil, - } - - // Marshal struct to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal struct with nil timestamp: %v", err) - } - - // Unmarshal struct from BSON - var unmarshaled TestStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal struct with nil timestamp: %v", err) - } - - // Verify nil is preserved - if unmarshaled.CreatedAt != nil { - t.Error("Expected nil timestamp to remain nil after round-trip") - } -} - -func TestTimestamp_BSONCompatibility(t *testing.T) { - // Test that Timestamp works correctly when embedded in complex structures - type ComplexStruct struct { - ID string `bson:"_id"` - 
Timestamps []Timestamp `bson:"timestamps"` - Created *Timestamp `bson:"created"` - Updated *Timestamp `bson:"updated,omitempty"` - } - - now := time.Now() - original := ComplexStruct{ - ID: "test-id", - Timestamps: []Timestamp{ - *NewTimestamp(now.Add(-time.Hour)), - *NewTimestamp(now.Add(-time.Minute)), - }, - Created: NewTimestamp(now), - Updated: nil, - } - - // Marshal to BSON - bsonData, err := bson.Marshal(original) - if err != nil { - t.Fatalf("Failed to marshal complex struct: %v", err) - } - - // Unmarshal from BSON - var unmarshaled ComplexStruct - err = bson.Unmarshal(bsonData, &unmarshaled) - if err != nil { - t.Fatalf("Failed to unmarshal complex struct: %v", err) - } - - // Verify all fields - if original.ID != unmarshaled.ID { - t.Errorf("ID mismatch: got %s, want %s", unmarshaled.ID, original.ID) - } - - if len(original.Timestamps) != len(unmarshaled.Timestamps) { - t.Errorf("Timestamps slice length mismatch: got %d, want %d", - len(unmarshaled.Timestamps), len(original.Timestamps)) - } - - for i, ts := range original.Timestamps { - if ts.UnixMilli() != unmarshaled.Timestamps[i].UnixMilli() { - t.Errorf("Timestamp[%d] mismatch: got %d, want %d", - i, unmarshaled.Timestamps[i].UnixMilli(), ts.UnixMilli()) - } - } - - if original.Created.UnixMilli() != unmarshaled.Created.UnixMilli() { - t.Errorf("Created timestamp mismatch: got %d, want %d", - unmarshaled.Created.UnixMilli(), original.Created.UnixMilli()) - } - - if unmarshaled.Updated != nil { - t.Error("Expected Updated to remain nil") - } -} diff --git a/sharding-compose.yml b/sharding-compose.yml new file mode 100644 index 0000000..03faff4 --- /dev/null +++ b/sharding-compose.yml @@ -0,0 +1,288 @@ +services: + x-bft: + &bft-base + platform: linux/amd64 + user: "${USER_UID:-1001}:${USER_GID:-1001}" + # https://github.com/unicitynetwork/bft-core/pkgs/container/bft-core + image: ghcr.io/unicitynetwork/bft-core:49d48f8fd3686aff1066d98eff2512e5fc71713c + bft-root: + <<: *bft-base + volumes: + - ./data/genesis-root:/genesis/root + - ./data/genesis:/genesis + healthcheck: + test: [ "CMD", "nc", "-zv", "bft-root", "8000" ] + interval: 5s + networks: + - default + entrypoint: ["/busybox/sh", "-c"] + command: + - | + if [ -f /genesis/root/node-info.json ] && [ -f /genesis/trust-base.json ] && [ -f /genesis/root/trust-base-signed.json ]; then + echo "Genesis files already exist, skipping initialization." + else + echo "Creating root genesis..." && + ubft root-node init --home /genesis/root -g && + echo "Creating root trust base..." && + ubft trust-base generate --home /genesis --network-id 3 --node-info /genesis/root/node-info.json && + echo "Signing root trust base..." && + ubft trust-base sign --home /genesis/root --trust-base /genesis/trust-base.json + fi + echo "Starting root node..." && + ubft root-node run --home /genesis/root --address "/ip4/$(hostname -i)/tcp/8000" --trust-base /genesis/trust-base.json --rpc-server-address "$(hostname -i):8002" && + ls -l /genesis/root + echo "Root node started successfully." + + bft-aggregator-genesis-gen: + <<: *bft-base + volumes: + - ./data/genesis-root:/genesis/root + - ./data/genesis:/genesis + depends_on: + bft-root: + condition: service_healthy + ports: + - "11003:11003" + networks: + - default + entrypoint: ["/busybox/sh", "-c"] + command: + - | + if [ -f /genesis/aggregator/node-info.json ] && [ -f /genesis/shard-conf-7_0.json ]; then + echo "Aggregator genesis and config already exist, skipping initialization." + else + echo "Creating aggregator genesis..." 
&& + ubft shard-node init --home /genesis/aggregator --generate && + echo "Creating aggregator partition configuration..." && + ubft shard-conf generate --home /genesis --t2-timeout 5000 --network-id 3 --partition-id 7 --partition-type-id 7 --epoch-start 10 --node-info=/genesis/aggregator/node-info.json && + echo "Creating aggregator partition state..." && + ubft shard-conf genesis --home "/genesis/aggregator" --shard-conf /genesis/shard-conf-7_0.json + fi + chmod -R 755 /genesis/aggregator + chmod 644 /genesis/shard-conf-7_0.json + chmod 644 /genesis/trust-base.json + chmod -R 755 /genesis/root + echo "Permissions fixed." + ls -l /genesis/aggregator && + ls -l /genesis/ + + upload-configurations: + image: curlimages/curl:8.13.0 + user: "${USER_UID:-1001}:${USER_GID:-1001}" + depends_on: + bft-root: + condition: service_healthy + bft-aggregator-genesis-gen: + condition: service_completed_successfully + restart: on-failure + volumes: + - ./data/genesis:/genesis + command: | + /bin/sh -c " + echo Uploading aggregator configuration && + curl -X PUT -H 'Content-Type: application/json' -d @/genesis/shard-conf-7_0.json http://bft-root:8002/api/v1/configurations + " + + mongodb-shard1: + image: mongo:7.0 + container_name: aggregator-shard1-mongodb + user: "${USER_UID:-1001}:${USER_GID:-1001}" + restart: unless-stopped + ports: + - "27017:27017" + networks: + - default + environment: + MONGO_INITDB_ROOT_USERNAME: admin + MONGO_INITDB_ROOT_PASSWORD: password + MONGO_INITDB_DATABASE: aggregator + volumes: + - ./data/mongodb_shard1_data:/data/db + - ./scripts/mongo-init.js:/docker-entrypoint-initdb.d/mongo-init.js:ro + healthcheck: + test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"] + interval: 10s + timeout: 5s + retries: 5 + mongodb-shard2: + image: mongo:7.0 + container_name: aggregator-shard2-mongodb + user: "${USER_UID:-1001}:${USER_GID:-1001}" + restart: unless-stopped + ports: + - "27018:27017" + networks: + - default + environment: + MONGO_INITDB_ROOT_USERNAME: admin + MONGO_INITDB_ROOT_PASSWORD: password + MONGO_INITDB_DATABASE: aggregator + volumes: + - ./data/mongodb_shard2_data:/data/db + - ./scripts/mongo-init.js:/docker-entrypoint-initdb.d/mongo-init.js:ro + healthcheck: + test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"] + interval: 10s + timeout: 5s + retries: 5 + mongodb-root: + image: mongo:7.0 + container_name: aggregator-root-mongodb + user: "${USER_UID:-1001}:${USER_GID:-1001}" + restart: unless-stopped + ports: + - "27019:27017" + networks: + - default + environment: + MONGO_INITDB_ROOT_USERNAME: admin + MONGO_INITDB_ROOT_PASSWORD: password + MONGO_INITDB_DATABASE: aggregator + volumes: + - ./data/mongodb_root_data:/data/db + - ./scripts/mongo-init.js:/docker-entrypoint-initdb.d/mongo-init.js:ro + healthcheck: + test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"] + interval: 10s + timeout: 5s + retries: 5 + aggregator-shard1: &aggregator-base + build: + context: . 
+ dockerfile: Dockerfile + container_name: aggregator-shard1 + restart: unless-stopped + ports: + - "3001:3000" + networks: + - default + volumes: + - ./data/genesis:/app/bft-config + environment: &environment-base + # Server Configuration + PORT: "3000" + HOST: "0.0.0.0" + CONCURRENCY_LIMIT: "1000" + ENABLE_DOCS: "true" + ENABLE_CORS: "true" + + # Database Configuration + MONGODB_URI: "mongodb://admin:password@mongodb-shard1:27017/aggregator?authSource=admin" + MONGODB_DATABASE: "aggregator" + MONGODB_CONNECT_TIMEOUT: "10s" + MONGODB_SERVER_SELECTION_TIMEOUT: "5s" + + # High Availability Configuration + DISABLE_HIGH_AVAILABILITY: "true" + LOCK_TTL_SECONDS: "30" + LEADER_HEARTBEAT_INTERVAL: "10s" + LEADER_ELECTION_POLLING_INTERVAL: "5s" + + # Logging Configuration + LOG_LEVEL: "debug" + LOG_FORMAT: "json" + LOG_ENABLE_JSON: "true" + + # Processing Configuration + BATCH_LIMIT: "1000" + + # BFT Configuration + # Enable BFT support and specify configuration file paths + BFT_ENABLED: "true" + BFT_KEY_CONF_FILE: "/app/bft-config/aggregator/keys.json" + BFT_SHARD_CONF_FILE: "/app/bft-config/shard-conf-7_0.json" + BFT_TRUST_BASE_FILE: "/app/bft-config/trust-base.json" + # BFT_BOOTSTRAP_ADDRESSES will be set dynamically by the entrypoint script + + # Redis Configuration + REDIS_HOST: "redis" + REDIS_PORT: "6379" + REDIS_PASSWORD: "" + REDIS_DB: "0" + REDIS_POOL_SIZE: "100" + REDIS_MIN_IDLE_CONNS: "10" + + # Storage Configuration + USE_REDIS_FOR_COMMITMENTS: "false" + REDIS_FLUSH_INTERVAL: "50ms" + REDIS_MAX_BATCH_SIZE: "2000" + + # Sharding configuration + SHARDING_MODE: "child" + SHARD_ID_LENGTH: 1 + SHARDING_CHILD_PARENT_RPC_ADDR: http://aggregator-root:3000 + SHARDING_CHILD_SHARD_ID: 3 # 0b11 + SHARDING_CHILD_ROUND_DURATION: 1s + SHARDING_CHILD_PARENT_POLL_TIMEOUT: 5s + SHARDING_CHILD_PARENT_POLL_INTERVAL: 100ms + entrypoint: ["/bin/sh", "-c"] + command: + - | + # Extract the first root node's nodeId from trust-base.json + if [ -f /app/bft-config/trust-base.json ]; then + ROOT_NODE_ID=$$(cat /app/bft-config/trust-base.json | grep -o '"nodeId": "[^"]*"' | head -1 | cut -d'"' -f4) + if [ -n "$$ROOT_NODE_ID" ]; then + export BFT_BOOTSTRAP_ADDRESSES="/dns4/bft-root/tcp/8000/p2p/$$ROOT_NODE_ID" + echo "Set BFT_BOOTSTRAP_ADDRESSES to: $$BFT_BOOTSTRAP_ADDRESSES" + else + echo "Warning: Could not extract nodeId from trust-base.json" + exit 1 + fi + else + echo "Error: trust-base.json not found at /app/bft-config/trust-base.json" + exit 1 + fi + + # Start the aggregator application + exec /app/aggregator + + depends_on: + bft-aggregator-genesis-gen: + condition: service_completed_successfully + mongodb-shard1: + condition: service_healthy + healthcheck: + test: [ "CMD", "nc", "-zv", "localhost", "3000" ] + interval: 30s + timeout: 10s + retries: 3 + aggregator-shard2: + <<: *aggregator-base + container_name: aggregator-shard2 + ports: + - "3002:3000" + environment: + <<: *environment-base + MONGODB_URI: "mongodb://admin:password@mongodb-shard2:27017/aggregator?authSource=admin" + SHARDING_CHILD_SHARD_ID: 2 # 0b10 + depends_on: + bft-aggregator-genesis-gen: + condition: service_completed_successfully + mongodb-shard2: + condition: service_healthy + healthcheck: + test: [ "CMD", "nc", "-zv", "localhost", "3000" ] + interval: 30s + timeout: 10s + retries: 3 + aggregator-root: + <<: *aggregator-base + container_name: aggregator-root + ports: + - "3009:3000" + environment: + <<: *environment-base + MONGODB_URI: "mongodb://admin:password@mongodb-root:27017/aggregator?authSource=admin" + SHARDING_MODE: 
"parent" + depends_on: + bft-aggregator-genesis-gen: + condition: service_completed_successfully + mongodb-root: + condition: service_healthy + healthcheck: + test: [ "CMD", "nc", "-zv", "localhost", "3000" ] + interval: 30s + timeout: 10s + retries: 3 +networks: + default: \ No newline at end of file diff --git a/sharding-ha-compose.yml b/sharding-ha-compose.yml new file mode 100644 index 0000000..74b8ed4 --- /dev/null +++ b/sharding-ha-compose.yml @@ -0,0 +1,311 @@ +services: + x-bft: + &bft-base + platform: linux/amd64 + user: "${USER_UID:-1001}:${USER_GID:-1001}" + image: ghcr.io/unicitynetwork/bft-core:49d48f8fd3686aff1066d98eff2512e5fc71713c + bft-root: + <<: *bft-base + volumes: + - ./data/genesis-root:/genesis/root + - ./data/genesis:/genesis + healthcheck: + test: [ "CMD", "nc", "-zv", "bft-root", "8000" ] + interval: 5s + networks: + - default + entrypoint: ["/busybox/sh", "-c"] + command: + - | + if [ -f /genesis/root/node-info.json ] && [ -f /genesis/trust-base.json ] && [ -f /genesis/root/trust-base-signed.json ]; then + echo "Genesis files already exist, skipping initialization." + else + echo "Creating root genesis..." && + ubft root-node init --home /genesis/root -g && + echo "Creating root trust base..." && + ubft trust-base generate --home /genesis --network-id 3 --node-info /genesis/root/node-info.json && + echo "Signing root trust base..." && + ubft trust-base sign --home /genesis/root --trust-base /genesis/trust-base.json + fi + echo "Starting root node..." && + ubft root-node run --home /genesis/root --address "/ip4/$(hostname -i)/tcp/8000" --trust-base /genesis/trust-base.json --rpc-server-address "$(hostname -i):8002" && + ls -l /genesis/root + echo "Root node started successfully." + + bft-aggregator-genesis-gen: + <<: *bft-base + volumes: + - ./data/genesis-root:/genesis/root + - ./data/genesis:/genesis + depends_on: + bft-root: + condition: service_healthy + ports: + - "11003:11003" + networks: + - default + entrypoint: ["/busybox/sh", "-c"] + command: + - | + if [ -f /genesis/aggregator/node-info.json ] && [ -f /genesis/shard-conf-7_0.json ]; then + echo "Aggregator genesis and config already exist, skipping initialization." + else + echo "Creating aggregator genesis..." && + ubft shard-node init --home /genesis/aggregator --generate && + echo "Creating aggregator partition configuration..." && + ubft shard-conf generate --home /genesis --t2-timeout 5000 --network-id 3 --partition-id 7 --partition-type-id 7 --epoch-start 10 --node-info=/genesis/aggregator/node-info.json && + echo "Creating aggregator partition state..." && + ubft shard-conf genesis --home "/genesis/aggregator" --shard-conf /genesis/shard-conf-7_0.json + fi + chmod -R 755 /genesis/aggregator + chmod 644 /genesis/shard-conf-7_0.json + chmod 644 /genesis/trust-base.json + chmod -R 755 /genesis/root + echo "Permissions fixed." 
+ ls -l /genesis/aggregator && + ls -l /genesis/ + + upload-configurations: + image: curlimages/curl:8.13.0 + user: "${USER_UID:-1001}:${USER_GID:-1001}" + depends_on: + bft-root: + condition: service_healthy + bft-aggregator-genesis-gen: + condition: service_completed_successfully + restart: on-failure + volumes: + - ./data/genesis:/genesis + command: | + /bin/sh -c " + echo Uploading aggregator configuration && + curl -X PUT -H 'Content-Type: application/json' -d @/genesis/shard-conf-7_0.json http://bft-root:8002/api/v1/configurations + " + + + + + # Shard 1 MongoDB Replica Set + mongodb-shard1-1: &mongo-base + image: mongo:7.0 + command: ["--replSet", "rs1", "--bind_ip_all", "--noauth"] + networks: + - default + restart: unless-stopped + healthcheck: + test: ["CMD", "mongosh", "--eval", "db.adminCommand('ping')"] + interval: 1s + timeout: 1s + retries: 30 + mongodb-shard1-2: + <<: *mongo-base + container_name: mongodb-shard1-2 + volumes: + - ./data/mongodb_shard1_data_2:/data/db + mongo-setup-shard1: + image: mongo:7.0 + networks: + - default + depends_on: + mongodb-shard1-1: + condition: service_healthy + mongodb-shard1-2: + condition: service_healthy + entrypoint: + - mongosh + - "--host" + - "mongodb-shard1-1:27017" + - "--eval" + - "try { rs.initiate({_id: 'rs1', members: [{_id: 0, host: 'mongodb-shard1-1:27017'}, {_id: 1, host: 'mongodb-shard1-2:27017'}]}) } catch(e) { print('Replica set already initialized or error:', e.message) }" + + # Shard 2 MongoDB Replica Set + mongodb-shard2-1: + <<: *mongo-base + command: ["--replSet", "rs2", "--bind_ip_all", "--noauth"] + container_name: mongodb-shard2-1 + volumes: + - ./data/mongodb_shard2_data_1:/data/db + mongodb-shard2-2: + <<: *mongo-base + command: ["--replSet", "rs2", "--bind_ip_all", "--noauth"] + container_name: mongodb-shard2-2 + volumes: + - ./data/mongodb_shard2_data_2:/data/db + mongo-setup-shard2: + image: mongo:7.0 + networks: + - default + depends_on: + mongodb-shard2-1: + condition: service_healthy + mongodb-shard2-2: + condition: service_healthy + entrypoint: + - mongosh + - "--host" + - "mongodb-shard2-1:27017" + - "--eval" + - "try { rs.initiate({_id: 'rs2', members: [{_id: 0, host: 'mongodb-shard2-1:27017'}, {_id: 1, host: 'mongodb-shard2-2:27017'}]}) } catch(e) { print('Replica set already initialized or error:', e.message) }" + + # Root MongoDB Replica Set + mongodb-root-1: + <<: *mongo-base + command: ["--replSet", "rs-root", "--bind_ip_all", "--noauth"] + container_name: mongodb-root-1 + volumes: + - ./data/mongodb_root_data_1:/data/db + mongodb-root-2: + <<: *mongo-base + command: ["--replSet", "rs-root", "--bind_ip_all", "--noauth"] + container_name: mongodb-root-2 + volumes: + - ./data/mongodb_root_data_2:/data/db + mongo-setup-root: + image: mongo:7.0 + networks: + - default + depends_on: + mongodb-root-1: + condition: service_healthy + mongodb-root-2: + condition: service_healthy + entrypoint: + - mongosh + - "--host" + - "mongodb-root-1:27017" + - "--eval" + - "try { rs.initiate({_id: 'rs-root', members: [{_id: 0, host: 'mongodb-root-1:27017'}, {_id: 1, host: 'mongodb-root-2:27017'}]}) } catch(e) { print('Replica set already initialized or error:', e.message) }" + + # Aggregator Shard 1 Cluster + aggregator-shard1-1: &aggregator-base + build: + context: . 
+ dockerfile: Dockerfile + container_name: aggregator-shard1-1 + restart: unless-stopped + ports: + - "3001:3000" + networks: + - default + volumes: + - ./data/genesis:/app/bft-config + environment: &environment-base + MONGODB_URI: "mongodb://mongodb-shard1-1:27017,mongodb-shard1-2:27017/aggregator?replicaSet=rs1" + PORT: "3000" + HOST: "0.0.0.0" + CONCURRENCY_LIMIT: "1000" + ENABLE_DOCS: "true" + ENABLE_CORS: "true" + MONGODB_DATABASE: "aggregator" + MONGODB_CONNECT_TIMEOUT: "10s" + MONGODB_SERVER_SELECTION_TIMEOUT: "5s" + DISABLE_HIGH_AVAILABILITY: "false" + LOCK_TTL_SECONDS: "30" + LEADER_HEARTBEAT_INTERVAL: "10s" + LEADER_ELECTION_POLLING_INTERVAL: "5s" + LOG_LEVEL: "debug" + LOG_FORMAT: "json" + LOG_ENABLE_JSON: "true" + BATCH_LIMIT: "1000" + BFT_ENABLED: "false" + BFT_KEY_CONF_FILE: "/app/bft-config/aggregator/keys.json" + BFT_SHARD_CONF_FILE: "/app/bft-config/shard-conf-7_0.json" + BFT_TRUST_BASE_FILE: "/app/bft-config/trust-base.json" + REDIS_HOST: "redis" + REDIS_PORT: "6379" + REDIS_PASSWORD: "" + REDIS_DB: "0" + REDIS_POOL_SIZE: "100" + REDIS_MIN_IDLE_CONNS: "10" + USE_REDIS_FOR_COMMITMENTS: "false" + REDIS_FLUSH_INTERVAL: "50ms" + REDIS_MAX_BATCH_SIZE: "2000" + + # Sharding configuration + SHARDING_MODE: "child" + SHARD_ID_LENGTH: 1 + SHARDING_CHILD_PARENT_RPC_ADDR: http://aggregator-root-1:3000 + SHARDING_CHILD_SHARD_ID: 3 # 0b11 + SHARDING_CHILD_ROUND_DURATION: 1s + SHARDING_CHILD_PARENT_POLL_TIMEOUT: 5s + SHARDING_CHILD_PARENT_POLL_INTERVAL: 100ms + entrypoint: ["/bin/sh", "-c"] + command: + - | + if [ -f /app/bft-config/trust-base.json ]; then + ROOT_NODE_ID=$$(cat /app/bft-config/trust-base.json | grep -o '"nodeId": "[^"]*"' | head -1 | cut -d'"' -f4) + if [ -n "$$ROOT_NODE_ID" ]; then + export BFT_BOOTSTRAP_ADDRESSES="/dns4/bft-root/tcp/8000/p2p/$$ROOT_NODE_ID" + echo "Set BFT_BOOTSTRAP_ADDRESSES to: $$BFT_BOOTSTRAP_ADDRESSES" + else + echo "Warning: Could not extract nodeId from trust-base.json" + exit 1 + fi + else + echo "Error: trust-base.json not found at /app/bft-config/trust-base.json" + exit 1 + fi + exec /app/aggregator + healthcheck: + test: [ "CMD", "nc", "-zv", "localhost", "3000" ] + interval: 30s + timeout: 10s + retries: 3 + aggregator-shard1-2: + <<: *aggregator-base + container_name: aggregator-shard1-2 + ports: + - "3002:3000" + environment: + <<: *environment-base + MONGODB_URI: "mongodb://mongodb-shard1-1:27017,mongodb-shard1-2:27017/aggregator?replicaSet=rs1" + depends_on: + - mongo-setup-shard1 + - bft-aggregator-genesis-gen + + # Aggregator Shard 2 Cluster + aggregator-shard2-1: + <<: *aggregator-base + container_name: aggregator-shard2-1 + ports: + - "3003:3000" + environment: + <<: *environment-base + MONGODB_URI: "mongodb://mongodb-shard2-1:27017,mongodb-shard2-2:27017/aggregator?replicaSet=rs2" + SHARDING_MODE: "child" + SHARDING_CHILD_SHARD_ID: 2 # 0b10 + depends_on: + - mongo-setup-shard2 + - bft-aggregator-genesis-gen + aggregator-shard2-2: + <<: *aggregator-base + container_name: aggregator-shard2-2 + ports: + - "3004:3000" + environment: + <<: *environment-base + MONGODB_URI: "mongodb://mongodb-shard2-1:27017,mongodb-shard2-2:27017/aggregator?replicaSet=rs2" + SHARDING_MODE: "child" + SHARDING_CHILD_SHARD_ID: 2 # 0b10 + depends_on: + - mongo-setup-shard2 + - bft-aggregator-genesis-gen + + # Aggregator Root Cluster + # currently only 1 instance as the child must send only to the parent which is not implemented currently + aggregator-root-1: + <<: *aggregator-base + container_name: aggregator-root-1 + ports: + - "3009:3000" + environment: + 
<<: *environment-base + MONGODB_URI: "mongodb://mongodb-root-1:27017,mongodb-root-2:27017/aggregator?replicaSet=rs-root" + SHARDING_MODE: "parent" + BFT_ENABLED: "true" + depends_on: + - mongo-setup-root + - bft-aggregator-genesis-gen + +networks: + default: diff --git a/test-docs-example.sh b/test-docs-example.sh old mode 100755 new mode 100644 diff --git a/test/integration/sharding_e2e_test.go b/test/integration/sharding_e2e_test.go new file mode 100644 index 0000000..e3ae55c --- /dev/null +++ b/test/integration/sharding_e2e_test.go @@ -0,0 +1,564 @@ +package integration + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "math/big" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/unicitynetwork/aggregator-go/internal/config" + "github.com/unicitynetwork/aggregator-go/internal/gateway" + "github.com/unicitynetwork/aggregator-go/internal/ha" + "github.com/unicitynetwork/aggregator-go/internal/ha/state" + "github.com/unicitynetwork/aggregator-go/internal/logger" + "github.com/unicitynetwork/aggregator-go/internal/round" + "github.com/unicitynetwork/aggregator-go/internal/service" + "github.com/unicitynetwork/aggregator-go/internal/storage" + "github.com/unicitynetwork/aggregator-go/internal/storage/interfaces" + "github.com/unicitynetwork/aggregator-go/internal/testutil" + "github.com/unicitynetwork/aggregator-go/pkg/api" +) + +type ShardingE2ETestSuite struct { + suite.Suite + mongoContainer testcontainers.Container + mongoURI string + instances []*aggregatorInstance +} + +type aggregatorInstance struct { + name string + cfg *config.Config + logger *logger.Logger + commitmentQueue interfaces.CommitmentQueue + storage interfaces.Storage + manager round.Manager + service gateway.Service + server *gateway.Server + leaderElection *ha.LeaderElection + cleanup func() +} + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params interface{} `json:"params"` + ID int `json:"id"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + Data string `json:"data,omitempty"` + } `json:"error,omitempty"` + ID int `json:"id"` +} + +func (suite *ShardingE2ETestSuite) SetupSuite() { + ctx := context.Background() + + req := testcontainers.ContainerRequest{ + Image: "mongo:7.0", + ExposedPorts: []string{"27017/tcp"}, + WaitingFor: wait.ForLog("Waiting for connections").WithStartupTimeout(60 * time.Second), + } + + mongoContainer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + suite.Require().NoError(err) + + host, err := mongoContainer.Host(ctx) + suite.Require().NoError(err) + + port, err := mongoContainer.MappedPort(ctx, "27017") + suite.Require().NoError(err) + + suite.mongoContainer = mongoContainer + suite.mongoURI = fmt.Sprintf("mongodb://%s:%s", host, port.Port()) + suite.instances = make([]*aggregatorInstance, 0) + + suite.T().Logf("MongoDB container started at %s", suite.mongoURI) +} + +func (suite *ShardingE2ETestSuite) TearDownSuite() { + ctx := context.Background() + + for _, inst := range suite.instances { + if inst.cleanup != nil { + inst.cleanup() + } + } + + if suite.mongoContainer != nil { + suite.mongoContainer.Terminate(ctx) + } +} + +func (suite *ShardingE2ETestSuite) 
buildConfig(mode config.ShardingMode, port, dbName string, shardID api.ShardID) *config.Config { + cfg := &config.Config{ + Server: config.ServerConfig{ + Host: "localhost", + Port: port, + EnableCORS: true, + EnableDocs: false, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + ConcurrencyLimit: 100, + }, + Database: config.DatabaseConfig{ + URI: suite.mongoURI, + Database: dbName, + ConnectTimeout: 10 * time.Second, + ServerSelectionTimeout: 10 * time.Second, + SocketTimeout: 10 * time.Second, + MaxPoolSize: 10, + MinPoolSize: 2, + }, + HA: config.HAConfig{ + Enabled: false, + }, + Logging: config.LoggingConfig{ + Level: "info", + Format: "json", + }, + BFT: config.BFTConfig{ + Enabled: false, + }, + Processing: config.ProcessingConfig{ + RoundDuration: 100 * time.Millisecond, + BatchLimit: 1000, + MaxCommitmentsPerRound: 1000, + }, + Storage: config.StorageConfig{ + UseRedisForCommitments: false, + }, + Sharding: config.ShardingConfig{ + Mode: mode, + ShardIDLength: 1, // 2 shards: IDs 2-3 + }, + } + + // Child-specific configuration + if mode == config.ShardingModeChild { + cfg.Sharding.Child = config.ChildConfig{ + ParentRpcAddr: "http://localhost:9000", + ShardID: shardID, + ParentPollTimeout: 5 * time.Second, + ParentPollInterval: 100 * time.Millisecond, + } + } + + return cfg +} + +func (suite *ShardingE2ETestSuite) startAggregatorInstance(name string, cfg *config.Config) *aggregatorInstance { + ctx := context.Background() + + log, err := logger.New(cfg.Logging.Level, cfg.Logging.Format, cfg.Logging.Output, cfg.Logging.EnableJSON) + suite.Require().NoError(err) + + commitmentQueue, storageInstance, err := storage.NewStorage(cfg, log) + suite.Require().NoError(err) + + err = commitmentQueue.Initialize(ctx) + suite.Require().NoError(err) + + stateTracker := state.NewSyncStateTracker() + + manager, err := round.NewManager(ctx, cfg, log, commitmentQueue, storageInstance, stateTracker) + suite.Require().NoError(err) + + err = manager.Start(ctx) + suite.Require().NoError(err) + + var leaderElection *ha.LeaderElection + var leaderSelector service.LeaderSelector + var haManager *ha.HAManager + + if cfg.HA.Enabled { + leaderElection = ha.NewLeaderElection(log, cfg.HA, storageInstance.LeadershipStorage()) + leaderElection.Start(ctx) + leaderSelector = leaderElection + + time.Sleep(100 * time.Millisecond) + + disableBlockSync := cfg.Sharding.Mode == config.ShardingModeParent + haManager = ha.NewHAManager(log, manager, leaderElection, storageInstance, manager.GetSMT(), cfg.Sharding.Child.ShardID, stateTracker, cfg.Processing.RoundDuration, disableBlockSync) + haManager.Start(ctx) + } else { + leaderSelector = nil + err = manager.Activate(ctx) + suite.Require().NoError(err) + } + + svc, err := service.NewService(ctx, cfg, log, manager, commitmentQueue, storageInstance, leaderSelector) + suite.Require().NoError(err) + + server := gateway.NewServer(cfg, log, svc) + + go func() { + if err := server.Start(); err != nil && err != http.ErrServerClosed { + log.Error("Server error", "error", err.Error()) + } + }() + + time.Sleep(200 * time.Millisecond) + + inst := &aggregatorInstance{ + name: name, + cfg: cfg, + logger: log, + commitmentQueue: commitmentQueue, + storage: storageInstance, + manager: manager, + service: svc, + server: server, + leaderElection: leaderElection, + cleanup: func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + server.Stop(shutdownCtx) + if haManager != nil { + 
haManager.Stop() + } + if leaderElection != nil { + leaderElection.Stop(context.Background()) + } + manager.Stop(context.Background()) + storageInstance.Close(context.Background()) + }, + } + + suite.instances = append(suite.instances, inst) + suite.T().Logf("✓ Started %s on :%s", name, cfg.Server.Port) + + return inst +} + +func (suite *ShardingE2ETestSuite) rpcCall(url string, method string, params interface{}) (json.RawMessage, error) { + reqBody := jsonRPCRequest{ + JSONRPC: "2.0", + Method: method, + Params: params, + ID: 1, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + resp, err := http.Post(url, "application/json", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() + + var rpcResp jsonRPCResponse + if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if rpcResp.Error != nil { + return nil, fmt.Errorf("RPC error: %s", rpcResp.Error.Message) + } + + return rpcResp.Result, nil +} + +func (suite *ShardingE2ETestSuite) submitCommitment(url string, commitment *api.SubmitCommitmentRequest) (*api.SubmitCommitmentResponse, error) { + result, err := suite.rpcCall(url, "submit_commitment", commitment) + if err != nil { + return nil, err + } + + var response api.SubmitCommitmentResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &response, nil +} + +func (suite *ShardingE2ETestSuite) getInclusionProof(url string, requestID string) (*api.GetInclusionProofResponse, error) { + params := map[string]string{"requestId": requestID} + result, err := suite.rpcCall(url, "get_inclusion_proof", params) + if err != nil { + return nil, err + } + + var response api.GetInclusionProofResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &response, nil +} + +func (suite *ShardingE2ETestSuite) getBlockHeight(url string) (*api.GetBlockHeightResponse, error) { + result, err := suite.rpcCall(url, "get_block_height", nil) + if err != nil { + return nil, err + } + + var response api.GetBlockHeightResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &response, nil +} + +func (suite *ShardingE2ETestSuite) createCommitmentForShard(shardID api.ShardID, shardIDLength int) (*api.SubmitCommitmentRequest, string) { + // Shard ID is encoded in the LSBs of the requestID (see commitment_validator.go verifyShardID) + msbPos := shardIDLength + compareMask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(msbPos)), big.NewInt(1)) + expectedLSBs := new(big.Int).And(big.NewInt(int64(shardID)), compareMask) + + for attempts := 0; attempts < 1000; attempts++ { + baseData := fmt.Sprintf("shard_%d_attempt_%d", shardID, attempts) + commitment := testutil.CreateTestCommitment(suite.T(), baseData) + + requestIDBytes, err := commitment.RequestID.Bytes() + suite.Require().NoError(err) + requestIDBigInt := new(big.Int).SetBytes(requestIDBytes) + requestIDLSBs := new(big.Int).And(requestIDBigInt, compareMask) + + if requestIDLSBs.Cmp(expectedLSBs) == 0 { + receipt := true + apiCommitment := &api.SubmitCommitmentRequest{ + RequestID: commitment.RequestID, + TransactionHash: 
api.TransactionHash(commitment.TransactionHash), + Authenticator: api.Authenticator{ + Algorithm: commitment.Authenticator.Algorithm, + PublicKey: api.HexBytes(commitment.Authenticator.PublicKey), + Signature: api.HexBytes(commitment.Authenticator.Signature), + StateHash: api.StateHash(commitment.Authenticator.StateHash), + }, + Receipt: &receipt, + } + + return apiCommitment, commitment.RequestID.String() + } + } + + suite.FailNow("Failed to generate commitment for shard after 1000 attempts") + return nil, "" +} + +func (suite *ShardingE2ETestSuite) waitForBlock(url string, blockNumber int64, timeout time.Duration) { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + resp, err := suite.getBlockHeight(url) + if err == nil && resp.BlockNumber != nil && resp.BlockNumber.Cmp(big.NewInt(blockNumber)) >= 0 { + return + } + time.Sleep(50 * time.Millisecond) + } + suite.FailNow(fmt.Sprintf("Timeout waiting for block %d at %s", blockNumber, url)) +} + +// waitForProofAvailable waits for a VALID inclusion proof to become available +// This includes waiting for the parent proof to be received and joined +func (suite *ShardingE2ETestSuite) waitForProofAvailable(url, requestID string, timeout time.Duration) *api.GetInclusionProofResponse { + deadline := time.Now().Add(timeout) + reqID := api.RequestID(requestID) + reqIDPath, err := reqID.GetPath() + suite.Require().NoError(err) + + for time.Now().Before(deadline) { + resp, err := suite.getInclusionProof(url, requestID) + if err == nil && resp.InclusionProof != nil && resp.InclusionProof.MerkleTreePath != nil { + // Also verify that the proof is valid (includes parent proof) + result, verifyErr := resp.InclusionProof.MerkleTreePath.Verify(reqIDPath) + if verifyErr == nil && result != nil && result.Result { + return resp + } + // Proof exists but not valid yet (probably waiting for parent proof), keep retrying + } + time.Sleep(50 * time.Millisecond) + } + suite.FailNow(fmt.Sprintf("Timeout waiting for valid proof for requestID %s at %s", requestID, url)) + return nil +} + +// TestShardingE2E tests hierarchical sharding with parent and child aggregators. +// Verifies that commitments submitted to children are included in child blocks, +// child root hashes are aggregated by the parent, and clients can retrieve +// valid joined inclusion proofs that chain child and parent merkle paths. 
+func (suite *ShardingE2ETestSuite) TestShardingE2E() { + ctx := context.Background() + _ = ctx + + parentCfg := suite.buildConfig(config.ShardingModeParent, "9000", "aggregator_test_parent", 0) + suite.startAggregatorInstance("parent aggregator", parentCfg) + parentURL := "http://localhost:9000" + + child0Cfg := suite.buildConfig(config.ShardingModeChild, "9001", "aggregator_test_child_0", 2) + suite.startAggregatorInstance("child aggregator 0 (shard 2)", child0Cfg) + child0URL := "http://localhost:9001" + + child1Cfg := suite.buildConfig(config.ShardingModeChild, "9002", "aggregator_test_child_1", 3) + suite.startAggregatorInstance("child aggregator 1 (shard 3)", child1Cfg) + child1URL := "http://localhost:9002" + + time.Sleep(500 * time.Millisecond) + + suite.T().Log("Phase 1: Submitting commitments...") + + commitment1, reqID1 := suite.createCommitmentForShard(2, 1) + submitTime1 := time.Now() + resp1, err := suite.submitCommitment(child0URL, commitment1) + suite.Require().NoError(err) + suite.Require().Equal("SUCCESS", resp1.Status) + suite.T().Logf(" Submitted commitment 1 to child 0: %s", reqID1) + + commitment2, reqID2 := suite.createCommitmentForShard(2, 1) + submitTime2 := time.Now() + resp2, err := suite.submitCommitment(child0URL, commitment2) + suite.Require().NoError(err) + suite.Require().Equal("SUCCESS", resp2.Status) + suite.T().Logf(" Submitted commitment 2 to child 0: %s", reqID2) + + commitment3, reqID3 := suite.createCommitmentForShard(3, 1) + submitTime3 := time.Now() + resp3, err := suite.submitCommitment(child1URL, commitment3) + suite.Require().NoError(err) + suite.Require().Equal("SUCCESS", resp3.Status) + suite.T().Logf(" Submitted commitment 3 to child 1: %s", reqID3) + + commitment4, reqID4 := suite.createCommitmentForShard(3, 1) + submitTime4 := time.Now() + resp4, err := suite.submitCommitment(child1URL, commitment4) + suite.Require().NoError(err) + suite.Require().Equal("SUCCESS", resp4.Status) + suite.T().Logf(" Submitted commitment 4 to child 1: %s", reqID4) + + suite.T().Log("✓ Submitted 2 commitments to child 0") + suite.T().Log("✓ Submitted 2 commitments to child 1") + + suite.T().Log("Phase 2: Waiting for parent block...") + suite.waitForBlock(parentURL, 1, 3*time.Second) + suite.T().Log("✓ Parent created block 1 (children submitted roots)") + + suite.T().Log("Phase 3: Verifying joined proofs...") + + testCases := []struct { + requestID string + childURL string + shardID int + name string + submitTime time.Time + }{ + {reqID1, child0URL, 2, "commitment 1 (child 0)", submitTime1}, + {reqID2, child0URL, 2, "commitment 2 (child 0)", submitTime2}, + {reqID3, child1URL, 3, "commitment 3 (child 1)", submitTime3}, + {reqID4, child1URL, 3, "commitment 4 (child 1)", submitTime4}, + } + + for _, tc := range testCases { + proofAvailableStart := time.Now() + childProofResp := suite.waitForProofAvailable(tc.childURL, tc.requestID, 500*time.Millisecond) + totalLatency := time.Since(tc.submitTime) + suite.T().Logf("%s: proof available after %v (total from submit: %v)", + tc.name, time.Since(proofAvailableStart), totalLatency) + suite.Require().NotNil(childProofResp.InclusionProof, "Inclusion proof is nil for %s", tc.name) + suite.Require().NotNil(childProofResp.InclusionProof.MerkleTreePath, "Merkle path is nil for %s", tc.name) + joinedProof := childProofResp.InclusionProof.MerkleTreePath + + reqID := api.RequestID(tc.requestID) + reqIDPath, err := reqID.GetPath() + suite.Require().NoError(err, "Failed to get path from requestID for %s", tc.name) + + result, err := 
joinedProof.Verify(reqIDPath) + suite.Require().NoError(err, "Proof verification failed for %s", tc.name) + suite.Require().NotNil(result, "Verification result is nil for %s", tc.name) + + suite.Require().True(result.PathValid, "Path not valid for %s", tc.name) + suite.Require().True(result.PathIncluded, "Path not included for %s", tc.name) + suite.Require().True(result.Result, "Overall verification failed for %s", tc.name) + + suite.T().Logf("✓ Verified joined proof for %s", tc.name) + } + + suite.T().Log("✓ All initial proofs verified successfully!") + + suite.T().Log("Phase 4: Testing with additional blocks...") + + commitment5, reqID5 := suite.createCommitmentForShard(2, 1) + submitTime5 := time.Now() + suite.submitCommitment(child0URL, commitment5) + + commitment6, reqID6 := suite.createCommitmentForShard(3, 1) + submitTime6 := time.Now() + suite.submitCommitment(child1URL, commitment6) + + suite.T().Log("✓ Submitted additional commitments") + + suite.T().Log("Verifying old commitments still work...") + for _, tc := range testCases { + childProofResp, err := suite.getInclusionProof(tc.childURL, tc.requestID) + suite.Require().NoError(err, "Failed to get proof for old %s", tc.name) + suite.Require().NotNil(childProofResp.InclusionProof) + + reqID := api.RequestID(tc.requestID) + reqIDPath, err := reqID.GetPath() + suite.Require().NoError(err) + + result, err := childProofResp.InclusionProof.MerkleTreePath.Verify(reqIDPath) + suite.Require().NoError(err, "Verification failed for old %s", tc.name) + suite.Require().True(result.Result, "Old commitment proof invalid for %s", tc.name) + } + suite.T().Log("✓ All old commitments still verify correctly") + + suite.T().Log("Verifying new commitments...") + newTestCases := []struct { + requestID string + childURL string + name string + submitTime time.Time + }{ + {reqID5, child0URL, "new commitment (child 0)", submitTime5}, + {reqID6, child1URL, "new commitment (child 1)", submitTime6}, + } + + for _, tc := range newTestCases { + proofAvailableStart := time.Now() + childProofResp := suite.waitForProofAvailable(tc.childURL, tc.requestID, 10*time.Second) + totalLatency := time.Since(tc.submitTime) + suite.T().Logf("%s: proof available after %v (total from submit: %v)", + tc.name, time.Since(proofAvailableStart), totalLatency) + suite.Require().NotNil(childProofResp.InclusionProof) + + reqID := api.RequestID(tc.requestID) + reqIDPath, err := reqID.GetPath() + suite.Require().NoError(err) + + result, err := childProofResp.InclusionProof.MerkleTreePath.Verify(reqIDPath) + suite.Require().NoError(err, "Verification failed for %s", tc.name) + suite.Require().True(result.Result, "New commitment proof invalid for %s", tc.name) + + suite.T().Logf("✓ Verified %s", tc.name) + } + + suite.T().Log("✓ All new commitments verify correctly!") + suite.T().Log("✓ E2E sharding test completed successfully - old and new proofs work!") +} + +func TestShardingE2E(t *testing.T) { + suite.Run(t, new(ShardingE2ETestSuite)) +}
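
For readers who want to reproduce the shard-routing check outside the test suite, the following standalone sketch mirrors the least-significant-bit mask arithmetic used by the `createCommitmentForShard` helper above (the same check that the test comment attributes to `verifyShardID` in `commitment_validator.go`). The `shardLSBs` function, the `main` driver, and the sample request ID are illustrative only and are not part of this change set:

```go
package main

import (
	"fmt"
	"math/big"
)

// shardLSBs extracts the least significant shardIDLength bits of a request ID,
// mirroring the mask arithmetic in createCommitmentForShard.
// requestIDHex is assumed to be the hex-encoded request ID; names here are illustrative.
func shardLSBs(requestIDHex string, shardIDLength int) (*big.Int, error) {
	id, ok := new(big.Int).SetString(requestIDHex, 16)
	if !ok {
		return nil, fmt.Errorf("invalid hex request id: %q", requestIDHex)
	}
	// mask = (1 << shardIDLength) - 1
	mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(shardIDLength)), big.NewInt(1))
	return new(big.Int).And(id, mask), nil
}

func main() {
	const shardIDLength = 1 // matches SHARD_ID_LENGTH: 1 in the compose files

	lsb, err := shardLSBs("0000cfe84a1828e2edd0a7d9533b23e519f746069a938d549a150e07e14dc0f9cf00", shardIDLength)
	if err != nil {
		panic(err)
	}

	// A child aggregator with a given shardID accepts the commitment when
	// lsb == shardID & ((1 << shardIDLength) - 1), as checked in createCommitmentForShard.
	for _, shardID := range []int64{2, 3} { // 0b10 and 0b11
		expected := new(big.Int).And(big.NewInt(shardID), big.NewInt((1<<shardIDLength)-1))
		fmt.Printf("shardID %d accepts this request ID: %v\n", shardID, lsb.Cmp(expected) == 0)
	}
}
```

Running the sketch shows that a request ID whose low bit is 0 is accepted only by the child whose shard ID has a matching low bit, which is exactly the property the e2e test relies on when mining request IDs for a target shard.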