diff --git a/common/types/message/message.go b/common/types/message/message.go index 00dd867dce..9e912c5ab6 100644 --- a/common/types/message/message.go +++ b/common/types/message/message.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" - "github.com/scroll-tech/da-codec/encoding/codecv3" "github.com/scroll-tech/go-ethereum/common" ) @@ -52,9 +51,10 @@ type ChunkTaskDetail struct { // BatchTaskDetail is a type containing BatchTask detail. type BatchTaskDetail struct { - ChunkInfos []*ChunkInfo `json:"chunk_infos"` - ChunkProofs []*ChunkProof `json:"chunk_proofs"` - BatchHeader *codecv3.DABatch `json:"batch_header"` + ChunkInfos []*ChunkInfo `json:"chunk_infos"` + ChunkProofs []*ChunkProof `json:"chunk_proofs"` + BatchHeader string `json:"batch_header"` + BlobBytes []byte `json:"blob_bytes"` } // BundleTaskDetail consists of all the information required to describe the task to generate a proof for a bundle of batches. diff --git a/coordinator/internal/logic/provertask/batch_prover_task.go b/coordinator/internal/logic/provertask/batch_prover_task.go index cfc649c029..97d59ddb72 100644 --- a/coordinator/internal/logic/provertask/batch_prover_task.go +++ b/coordinator/internal/logic/provertask/batch_prover_task.go @@ -9,7 +9,9 @@ import ( "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/scroll-tech/da-codec/encoding" "github.com/scroll-tech/da-codec/encoding/codecv3" + "github.com/scroll-tech/da-codec/encoding/codecv4" "github.com/scroll-tech/go-ethereum/common" "github.com/scroll-tech/go-ethereum/log" "github.com/scroll-tech/go-ethereum/params" @@ -208,17 +210,9 @@ func (bp *BatchProverTask) formatProverTask(ctx context.Context, task *orm.Prove chunkInfos = append(chunkInfos, &chunkInfo) } - taskDetail := message.BatchTaskDetail{ - ChunkInfos: chunkInfos, - ChunkProofs: chunkProofs, - } - - if hardForkName == "darwin" { - batchHeader, decodeErr := codecv3.NewDABatchFromBytes(batch.BatchHeader) - if decodeErr != nil { - return nil, fmt.Errorf("failed to decode batch header, taskID:%s err:%w", task.TaskID, decodeErr) - } - taskDetail.BatchHeader = batchHeader + taskDetail, err := bp.getBatchTaskDetail(ctx, hardForkName, batch, chunkInfos, chunkProofs) + if err != nil { + return nil, fmt.Errorf("failed to get batch task detail, taskID:%s err:%w", task.TaskID, err) } chunkProofsBytes, err := json.Marshal(taskDetail) @@ -241,3 +235,78 @@ func (bp *BatchProverTask) recoverActiveAttempts(ctx *gin.Context, batchTask *or log.Error("failed to recover batch active attempts", "hash", batchTask.Hash, "error", err) } } + +func (bp *BatchProverTask) getBatchTaskDetail(ctx context.Context, hardForkName string, batch *orm.Batch, chunkInfos []*message.ChunkInfo, chunkProofs []*message.ChunkProof) (*message.BatchTaskDetail, error) { + taskDetail := &message.BatchTaskDetail{ + ChunkInfos: chunkInfos, + ChunkProofs: chunkProofs, + } + + if hardForkName != "darwin" { + return taskDetail, nil + } + + dbChunks, err := bp.chunkOrm.GetChunksInRange(ctx, batch.StartChunkIndex, batch.EndChunkIndex) + if err != nil { + return nil, fmt.Errorf("failed to get chunks in range for batch %d (start: %d, end: %d): %w", batch.Index, batch.StartChunkIndex, batch.EndChunkIndex, err) + } + + chunks := make([]*encoding.Chunk, len(dbChunks)) + for i, c := range dbChunks { + blocks, getErr := bp.blockOrm.GetL2BlocksInRange(ctx, c.StartBlockNumber, c.EndBlockNumber) + if getErr != nil { + return nil, fmt.Errorf("failed to get blocks in range for batch %d, chunk %d, (start: %d, end: %d): %w", batch.Index, c.Index, c.StartBlockNumber, c.EndBlockNumber, getErr) + } + chunks[i] = &encoding.Chunk{Blocks: blocks} + } + + dbParentBatch, getErr := bp.batchOrm.GetBatchByIndex(ctx, batch.Index-1) + if getErr != nil { + return nil, fmt.Errorf("failed to get parent batch header for batch %d: %w", batch.Index, getErr) + } + + batchEncoding := &encoding.Batch{ + Index: batch.Index, + TotalL1MessagePoppedBefore: dbChunks[0].TotalL1MessagesPoppedBefore, + ParentBatchHash: common.HexToHash(dbParentBatch.Hash), + Chunks: chunks, + } + + if !batch.EnableEncode { + daBatch, createErr := codecv3.NewDABatch(batchEncoding) + if createErr != nil { + return nil, fmt.Errorf("failed to create DA batch (v3) for batch %d: %w", batch.Index, createErr) + } + taskDetail.BlobBytes = daBatch.Blob()[:] + + batchHeader, decodeErr := codecv3.NewDABatchFromBytes(batch.BatchHeader) + if decodeErr != nil { + return nil, fmt.Errorf("failed to decode batch header (v3) for batch %d: %w", batch.Index, decodeErr) + } + + jsonData, marshalErr := json.Marshal(batchHeader) + if marshalErr != nil { + return nil, fmt.Errorf("failed to marshal batch header (v3) for batch %d: %w", batch.Index, marshalErr) + } + taskDetail.BatchHeader = string(jsonData) + } else { + daBatch, createErr := codecv4.NewDABatch(batchEncoding, batch.EnableEncode) + if createErr != nil { + return nil, fmt.Errorf("failed to create DA batch (v4) for batch %d: %w", batch.Index, createErr) + } + taskDetail.BlobBytes = daBatch.Blob()[:] + + batchHeader, decodeErr := codecv4.NewDABatchFromBytes(batch.BatchHeader) + if decodeErr != nil { + return nil, fmt.Errorf("failed to decode batch header (v4) for batch %d: %w", batch.Index, decodeErr) + } + + jsonData, marshalErr := json.Marshal(batchHeader) + if marshalErr != nil { + return nil, fmt.Errorf("failed to marshal batch header (v4) for batch %d: %w", batch.Index, marshalErr) + } + taskDetail.BatchHeader = string(jsonData) + } + + return taskDetail, nil +} diff --git a/coordinator/internal/orm/batch.go b/coordinator/internal/orm/batch.go index a4f8bd77dc..12dc0338d1 100644 --- a/coordinator/internal/orm/batch.go +++ b/coordinator/internal/orm/batch.go @@ -31,6 +31,7 @@ type Batch struct { WithdrawRoot string `json:"withdraw_root" gorm:"column:withdraw_root"` ParentBatchHash string `json:"parent_batch_hash" gorm:"column:parent_batch_hash"` BatchHeader []byte `json:"batch_header" gorm:"column:batch_header"` + EnableEncode bool `json:"enable_encode" gorm:"column:enable_encode"` // proof ChunkProofsStatus int16 `json:"chunk_proofs_status" gorm:"column:chunk_proofs_status;default:1"` @@ -225,6 +226,19 @@ func (o *Batch) GetBatchesByBundleHash(ctx context.Context, bundleHash string) ( return batches, nil } +// GetBatchByIndex retrieves the batch by the given index. +func (o *Batch) GetBatchByIndex(ctx context.Context, index uint64) (*Batch, error) { + db := o.db.WithContext(ctx) + db = db.Model(&Batch{}) + db = db.Where("index = ?", index) + + var batch Batch + if err := db.First(&batch).Error; err != nil { + return nil, fmt.Errorf("Batch.GetBatchByIndex error: %w, index: %v", err, index) + } + return &batch, nil +} + // InsertBatch inserts a new batch into the database. func (o *Batch) InsertBatch(ctx context.Context, batch *encoding.Batch, dbTX ...*gorm.DB) (*Batch, error) { if batch == nil { diff --git a/coordinator/internal/orm/chunk.go b/coordinator/internal/orm/chunk.go index a0d701b937..58b4b477ba 100644 --- a/coordinator/internal/orm/chunk.go +++ b/coordinator/internal/orm/chunk.go @@ -231,6 +231,32 @@ func (o *Chunk) GetAttemptsByHash(ctx context.Context, hash string) (int16, int1 return chunk.ActiveAttempts, chunk.TotalAttempts, nil } +// GetChunksInRange retrieves chunks within a given range (inclusive) from the database. +// The range is closed, i.e., it includes both start and end indices. +// The returned chunks are sorted in ascending order by their index. +func (o *Chunk) GetChunksInRange(ctx context.Context, startIndex uint64, endIndex uint64) ([]*Chunk, error) { + if startIndex > endIndex { + return nil, fmt.Errorf("Chunk.GetChunksInRange: start index should be less than or equal to end index, start index: %v, end index: %v", startIndex, endIndex) + } + + db := o.db.WithContext(ctx) + db = db.Model(&Chunk{}) + db = db.Where("index >= ? AND index <= ?", startIndex, endIndex) + db = db.Order("index ASC") + + var chunks []*Chunk + if err := db.Find(&chunks).Error; err != nil { + return nil, fmt.Errorf("Chunk.GetChunksInRange error: %w, start index: %v, end index: %v", err, startIndex, endIndex) + } + + // sanity check + if uint64(len(chunks)) != endIndex-startIndex+1 { + return nil, fmt.Errorf("Chunk.GetChunksInRange: incorrect number of chunks, expected: %v, got: %v, start index: %v, end index: %v", endIndex-startIndex+1, len(chunks), startIndex, endIndex) + } + + return chunks, nil +} + // InsertChunk inserts a new chunk into the database. // for unit test func (o *Chunk) InsertChunk(ctx context.Context, chunk *encoding.Chunk, dbTX ...*gorm.DB) (*Chunk, error) { diff --git a/coordinator/internal/orm/l2_block.go b/coordinator/internal/orm/l2_block.go index f3790c879c..0990494f8c 100644 --- a/coordinator/internal/orm/l2_block.go +++ b/coordinator/internal/orm/l2_block.go @@ -87,6 +87,55 @@ func (o *L2Block) GetL2BlockByNumber(ctx context.Context, blockNumber uint64) (* return &l2Block, nil } +// GetL2BlocksInRange retrieves the L2 blocks within the specified range (inclusive). +// The range is closed, i.e., it includes both start and end block numbers. +// The returned blocks are sorted in ascending order by their block number. +func (o *L2Block) GetL2BlocksInRange(ctx context.Context, startBlockNumber uint64, endBlockNumber uint64) ([]*encoding.Block, error) { + if startBlockNumber > endBlockNumber { + return nil, fmt.Errorf("L2Block.GetL2BlocksInRange: start block number should be less than or equal to end block number, start block: %v, end block: %v", startBlockNumber, endBlockNumber) + } + + db := o.db.WithContext(ctx) + db = db.Model(&L2Block{}) + db = db.Select("header, transactions, withdraw_root, row_consumption") + db = db.Where("number >= ? AND number <= ?", startBlockNumber, endBlockNumber) + db = db.Order("number ASC") + + var l2Blocks []L2Block + if err := db.Find(&l2Blocks).Error; err != nil { + return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber) + } + + // sanity check + if uint64(len(l2Blocks)) != endBlockNumber-startBlockNumber+1 { + return nil, fmt.Errorf("L2Block.GetL2BlocksInRange: unexpected number of results, expected: %v, got: %v", endBlockNumber-startBlockNumber+1, len(l2Blocks)) + } + + var blocks []*encoding.Block + for _, v := range l2Blocks { + var block encoding.Block + + if err := json.Unmarshal([]byte(v.Transactions), &block.Transactions); err != nil { + return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber) + } + + block.Header = &gethTypes.Header{} + if err := json.Unmarshal([]byte(v.Header), block.Header); err != nil { + return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber) + } + + block.WithdrawRoot = common.HexToHash(v.WithdrawRoot) + + if err := json.Unmarshal([]byte(v.RowConsumption), &block.RowConsumption); err != nil { + return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber) + } + + blocks = append(blocks, &block) + } + + return blocks, nil +} + // InsertL2Blocks inserts l2 blocks into the "l2_block" table. // for unit test func (o *L2Block) InsertL2Blocks(ctx context.Context, blocks []*encoding.Block) error {