Skip to content

Commit

Permalink
feat(coordinator): add blob bytes in batch proving task
Browse files Browse the repository at this point in the history
  • Loading branch information
colinlyguo committed Aug 18, 2024
1 parent 0983b9a commit 7cb9c06
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 15 deletions.
8 changes: 4 additions & 4 deletions common/types/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"

"github.com/scroll-tech/da-codec/encoding/codecv3"
"github.com/scroll-tech/go-ethereum/common"
)

Expand Down Expand Up @@ -52,9 +51,10 @@ type ChunkTaskDetail struct {

// BatchTaskDetail is a type containing BatchTask detail.
type BatchTaskDetail struct {
ChunkInfos []*ChunkInfo `json:"chunk_infos"`
ChunkProofs []*ChunkProof `json:"chunk_proofs"`
BatchHeader *codecv3.DABatch `json:"batch_header"`
ChunkInfos []*ChunkInfo `json:"chunk_infos"`
ChunkProofs []*ChunkProof `json:"chunk_proofs"`
BatchHeader string `json:"batch_header"`
BlobBytes []byte `json:"blob_bytes"`
}

// BundleTaskDetail consists of all the information required to describe the task to generate a proof for a bundle of batches.
Expand Down
91 changes: 80 additions & 11 deletions coordinator/internal/logic/provertask/batch_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
"github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/scroll-tech/da-codec/encoding"
"github.com/scroll-tech/da-codec/encoding/codecv3"
"github.com/scroll-tech/da-codec/encoding/codecv4"
"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/log"
"github.com/scroll-tech/go-ethereum/params"
Expand Down Expand Up @@ -208,17 +210,9 @@ func (bp *BatchProverTask) formatProverTask(ctx context.Context, task *orm.Prove
chunkInfos = append(chunkInfos, &chunkInfo)
}

taskDetail := message.BatchTaskDetail{
ChunkInfos: chunkInfos,
ChunkProofs: chunkProofs,
}

if hardForkName == "darwin" {
batchHeader, decodeErr := codecv3.NewDABatchFromBytes(batch.BatchHeader)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode batch header, taskID:%s err:%w", task.TaskID, decodeErr)
}
taskDetail.BatchHeader = batchHeader
taskDetail, err := bp.getBatchTaskDetail(ctx, hardForkName, batch, chunkInfos, chunkProofs)
if err != nil {
return nil, fmt.Errorf("failed to get batch task detail, taskID:%s err:%w", task.TaskID, err)
}

chunkProofsBytes, err := json.Marshal(taskDetail)
Expand All @@ -241,3 +235,78 @@ func (bp *BatchProverTask) recoverActiveAttempts(ctx *gin.Context, batchTask *or
log.Error("failed to recover batch active attempts", "hash", batchTask.Hash, "error", err)
}
}

func (bp *BatchProverTask) getBatchTaskDetail(ctx context.Context, hardForkName string, batch *orm.Batch, chunkInfos []*message.ChunkInfo, chunkProofs []*message.ChunkProof) (*message.BatchTaskDetail, error) {
taskDetail := &message.BatchTaskDetail{
ChunkInfos: chunkInfos,
ChunkProofs: chunkProofs,
}

if hardForkName != "darwin" {
return taskDetail, nil
}

dbChunks, err := bp.chunkOrm.GetChunksInRange(ctx, batch.StartChunkIndex, batch.EndChunkIndex)
if err != nil {
return nil, fmt.Errorf("failed to get chunks in range for batch %d (start: %d, end: %d): %w", batch.Index, batch.StartChunkIndex, batch.EndChunkIndex, err)
}

chunks := make([]*encoding.Chunk, len(dbChunks))
for i, c := range dbChunks {
blocks, getErr := bp.blockOrm.GetL2BlocksInRange(ctx, c.StartBlockNumber, c.EndBlockNumber)
if getErr != nil {
return nil, fmt.Errorf("failed to get blocks in range for batch %d, chunk %d, (start: %d, end: %d): %w", batch.Index, c.Index, c.StartBlockNumber, c.EndBlockNumber, getErr)
}
chunks[i] = &encoding.Chunk{Blocks: blocks}
}

dbParentBatch, getErr := bp.batchOrm.GetBatchByIndex(ctx, batch.Index-1)
if getErr != nil {
return nil, fmt.Errorf("failed to get parent batch header for batch %d: %w", batch.Index, getErr)
}

batchEncoding := &encoding.Batch{
Index: batch.Index,
TotalL1MessagePoppedBefore: dbChunks[0].TotalL1MessagesPoppedBefore,
ParentBatchHash: common.HexToHash(dbParentBatch.Hash),
Chunks: chunks,
}

if !batch.EnableEncode {
daBatch, createErr := codecv3.NewDABatch(batchEncoding)
if createErr != nil {
return nil, fmt.Errorf("failed to create DA batch (v3) for batch %d: %w", batch.Index, createErr)
}
taskDetail.BlobBytes = daBatch.Blob()[:]

batchHeader, decodeErr := codecv3.NewDABatchFromBytes(batch.BatchHeader)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode batch header (v3) for batch %d: %w", batch.Index, decodeErr)
}

jsonData, marshalErr := json.Marshal(batchHeader)
if marshalErr != nil {
return nil, fmt.Errorf("failed to marshal batch header (v3) for batch %d: %w", batch.Index, marshalErr)
}
taskDetail.BatchHeader = string(jsonData)
} else {
daBatch, createErr := codecv4.NewDABatch(batchEncoding, batch.EnableEncode)
if createErr != nil {
return nil, fmt.Errorf("failed to create DA batch (v4) for batch %d: %w", batch.Index, createErr)
}
taskDetail.BlobBytes = daBatch.Blob()[:]

batchHeader, decodeErr := codecv4.NewDABatchFromBytes(batch.BatchHeader)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode batch header (v4) for batch %d: %w", batch.Index, decodeErr)
}

jsonData, marshalErr := json.Marshal(batchHeader)
if marshalErr != nil {
return nil, fmt.Errorf("failed to marshal batch header (v4) for batch %d: %w", batch.Index, marshalErr)
}
taskDetail.BatchHeader = string(jsonData)
}

return taskDetail, nil
}
14 changes: 14 additions & 0 deletions coordinator/internal/orm/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Batch struct {
WithdrawRoot string `json:"withdraw_root" gorm:"column:withdraw_root"`
ParentBatchHash string `json:"parent_batch_hash" gorm:"column:parent_batch_hash"`
BatchHeader []byte `json:"batch_header" gorm:"column:batch_header"`
EnableEncode bool `json:"enable_encode" gorm:"column:enable_encode"`

// proof
ChunkProofsStatus int16 `json:"chunk_proofs_status" gorm:"column:chunk_proofs_status;default:1"`
Expand Down Expand Up @@ -225,6 +226,19 @@ func (o *Batch) GetBatchesByBundleHash(ctx context.Context, bundleHash string) (
return batches, nil
}

// GetBatchByIndex retrieves the batch by the given index.
func (o *Batch) GetBatchByIndex(ctx context.Context, index uint64) (*Batch, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("index = ?", index)

var batch Batch
if err := db.First(&batch).Error; err != nil {
return nil, fmt.Errorf("Batch.GetBatchByIndex error: %w, index: %v", err, index)
}
return &batch, nil
}

// InsertBatch inserts a new batch into the database.
func (o *Batch) InsertBatch(ctx context.Context, batch *encoding.Batch, dbTX ...*gorm.DB) (*Batch, error) {
if batch == nil {
Expand Down
26 changes: 26 additions & 0 deletions coordinator/internal/orm/chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,32 @@ func (o *Chunk) GetAttemptsByHash(ctx context.Context, hash string) (int16, int1
return chunk.ActiveAttempts, chunk.TotalAttempts, nil
}

// GetChunksInRange retrieves chunks within a given range (inclusive) from the database.
// The range is closed, i.e., it includes both start and end indices.
// The returned chunks are sorted in ascending order by their index.
func (o *Chunk) GetChunksInRange(ctx context.Context, startIndex uint64, endIndex uint64) ([]*Chunk, error) {
if startIndex > endIndex {
return nil, fmt.Errorf("Chunk.GetChunksInRange: start index should be less than or equal to end index, start index: %v, end index: %v", startIndex, endIndex)
}

db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("index >= ? AND index <= ?", startIndex, endIndex)
db = db.Order("index ASC")

var chunks []*Chunk
if err := db.Find(&chunks).Error; err != nil {
return nil, fmt.Errorf("Chunk.GetChunksInRange error: %w, start index: %v, end index: %v", err, startIndex, endIndex)
}

// sanity check
if uint64(len(chunks)) != endIndex-startIndex+1 {
return nil, fmt.Errorf("Chunk.GetChunksInRange: incorrect number of chunks, expected: %v, got: %v, start index: %v, end index: %v", endIndex-startIndex+1, len(chunks), startIndex, endIndex)
}

return chunks, nil
}

// InsertChunk inserts a new chunk into the database.
// for unit test
func (o *Chunk) InsertChunk(ctx context.Context, chunk *encoding.Chunk, dbTX ...*gorm.DB) (*Chunk, error) {
Expand Down
49 changes: 49 additions & 0 deletions coordinator/internal/orm/l2_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,55 @@ func (o *L2Block) GetL2BlockByNumber(ctx context.Context, blockNumber uint64) (*
return &l2Block, nil
}

// GetL2BlocksInRange retrieves the L2 blocks within the specified range (inclusive).
// The range is closed, i.e., it includes both start and end block numbers.
// The returned blocks are sorted in ascending order by their block number.
func (o *L2Block) GetL2BlocksInRange(ctx context.Context, startBlockNumber uint64, endBlockNumber uint64) ([]*encoding.Block, error) {
if startBlockNumber > endBlockNumber {
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange: start block number should be less than or equal to end block number, start block: %v, end block: %v", startBlockNumber, endBlockNumber)
}

db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
db = db.Select("header, transactions, withdraw_root, row_consumption")
db = db.Where("number >= ? AND number <= ?", startBlockNumber, endBlockNumber)
db = db.Order("number ASC")

var l2Blocks []L2Block
if err := db.Find(&l2Blocks).Error; err != nil {
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber)
}

// sanity check
if uint64(len(l2Blocks)) != endBlockNumber-startBlockNumber+1 {
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange: unexpected number of results, expected: %v, got: %v", endBlockNumber-startBlockNumber+1, len(l2Blocks))
}

var blocks []*encoding.Block
for _, v := range l2Blocks {
var block encoding.Block

if err := json.Unmarshal([]byte(v.Transactions), &block.Transactions); err != nil {
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber)
}

block.Header = &gethTypes.Header{}
if err := json.Unmarshal([]byte(v.Header), block.Header); err != nil {
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber)
}

block.WithdrawRoot = common.HexToHash(v.WithdrawRoot)

if err := json.Unmarshal([]byte(v.RowConsumption), &block.RowConsumption); err != nil {
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber)
}

blocks = append(blocks, &block)
}

return blocks, nil
}

// InsertL2Blocks inserts l2 blocks into the "l2_block" table.
// for unit test
func (o *L2Block) InsertL2Blocks(ctx context.Context, blocks []*encoding.Block) error {
Expand Down

0 comments on commit 7cb9c06

Please sign in to comment.