Skip to content

Commit

Permalink
simplify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
colinlyguo committed Aug 19, 2024
1 parent 873821e commit 2fb7ba3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 47 deletions.
37 changes: 16 additions & 21 deletions coordinator/internal/logic/provertask/batch_prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (bp *BatchProverTask) formatProverTask(ctx context.Context, task *orm.Prove
chunkInfos = append(chunkInfos, &chunkInfo)
}

taskDetail, err := bp.getBatchTaskDetail(ctx, hardForkName, batch, chunkInfos, chunkProofs)
taskDetail, err := bp.getBatchTaskDetail(ctx, hardForkName, batch, chunks, chunkInfos, chunkProofs)
if err != nil {
return nil, fmt.Errorf("failed to get batch task detail, taskID:%s err:%w", task.TaskID, err)
}
Expand All @@ -236,7 +236,7 @@ func (bp *BatchProverTask) recoverActiveAttempts(ctx *gin.Context, batchTask *or
}
}

func (bp *BatchProverTask) getBatchTaskDetail(ctx context.Context, hardForkName string, batch *orm.Batch, chunkInfos []*message.ChunkInfo, chunkProofs []*message.ChunkProof) (*message.BatchTaskDetail, error) {
func (bp *BatchProverTask) getBatchTaskDetail(ctx context.Context, hardForkName string, dbBatch *orm.Batch, dbChunks []*orm.Chunk, chunkInfos []*message.ChunkInfo, chunkProofs []*message.ChunkProof) (*message.BatchTaskDetail, error) {
taskDetail := &message.BatchTaskDetail{
ChunkInfos: chunkInfos,
ChunkProofs: chunkProofs,
Expand All @@ -246,64 +246,59 @@ func (bp *BatchProverTask) getBatchTaskDetail(ctx context.Context, hardForkName
return taskDetail, nil
}

dbChunks, err := bp.chunkOrm.GetChunksInRange(ctx, batch.StartChunkIndex, batch.EndChunkIndex)
if err != nil {
return nil, fmt.Errorf("failed to get chunks in range for batch %d (start: %d, end: %d): %w", batch.Index, batch.StartChunkIndex, batch.EndChunkIndex, err)
}

chunks := make([]*encoding.Chunk, len(dbChunks))
for i, c := range dbChunks {
blocks, getErr := bp.blockOrm.GetL2BlocksInRange(ctx, c.StartBlockNumber, c.EndBlockNumber)
if getErr != nil {
return nil, fmt.Errorf("failed to get blocks in range for batch %d, chunk %d, (start: %d, end: %d): %w", batch.Index, c.Index, c.StartBlockNumber, c.EndBlockNumber, getErr)
return nil, fmt.Errorf("failed to get blocks in range for batch %d, chunk %d, (start: %d, end: %d): %w", dbBatch.Index, c.Index, c.StartBlockNumber, c.EndBlockNumber, getErr)
}
chunks[i] = &encoding.Chunk{Blocks: blocks}
}

dbParentBatch, getErr := bp.batchOrm.GetBatchByIndex(ctx, batch.Index-1)
dbParentBatch, getErr := bp.batchOrm.GetBatchByIndex(ctx, dbBatch.Index-1)
if getErr != nil {
return nil, fmt.Errorf("failed to get parent batch header for batch %d: %w", batch.Index, getErr)
return nil, fmt.Errorf("failed to get parent batch header for batch %d: %w", dbBatch.Index, getErr)
}

batchEncoding := &encoding.Batch{
Index: batch.Index,
Index: dbBatch.Index,
TotalL1MessagePoppedBefore: dbChunks[0].TotalL1MessagesPoppedBefore,
ParentBatchHash: common.HexToHash(dbParentBatch.Hash),
Chunks: chunks,
}

if !batch.EnableEncode {
if !dbBatch.EnableEncode {
daBatch, createErr := codecv3.NewDABatch(batchEncoding)
if createErr != nil {
return nil, fmt.Errorf("failed to create DA batch (v3) for batch %d: %w", batch.Index, createErr)
return nil, fmt.Errorf("failed to create DA batch (v3) for batch %d: %w", dbBatch.Index, createErr)
}
taskDetail.BlobBytes = daBatch.Blob()[:]

batchHeader, decodeErr := codecv3.NewDABatchFromBytes(batch.BatchHeader)
batchHeader, decodeErr := codecv3.NewDABatchFromBytes(dbBatch.BatchHeader)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode batch header (v3) for batch %d: %w", batch.Index, decodeErr)
return nil, fmt.Errorf("failed to decode batch header (v3) for batch %d: %w", dbBatch.Index, decodeErr)
}

jsonData, marshalErr := json.Marshal(batchHeader)
if marshalErr != nil {
return nil, fmt.Errorf("failed to marshal batch header (v3) for batch %d: %w", batch.Index, marshalErr)
return nil, fmt.Errorf("failed to marshal batch header (v3) for batch %d: %w", dbBatch.Index, marshalErr)
}
taskDetail.BatchHeader = string(jsonData)
} else {
daBatch, createErr := codecv4.NewDABatch(batchEncoding, batch.EnableEncode)
daBatch, createErr := codecv4.NewDABatch(batchEncoding, dbBatch.EnableEncode)
if createErr != nil {
return nil, fmt.Errorf("failed to create DA batch (v4) for batch %d: %w", batch.Index, createErr)
return nil, fmt.Errorf("failed to create DA batch (v4) for batch %d: %w", dbBatch.Index, createErr)
}
taskDetail.BlobBytes = daBatch.Blob()[:]

batchHeader, decodeErr := codecv4.NewDABatchFromBytes(batch.BatchHeader)
batchHeader, decodeErr := codecv4.NewDABatchFromBytes(dbBatch.BatchHeader)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode batch header (v4) for batch %d: %w", batch.Index, decodeErr)
return nil, fmt.Errorf("failed to decode batch header (v4) for batch %d: %w", dbBatch.Index, decodeErr)
}

jsonData, marshalErr := json.Marshal(batchHeader)
if marshalErr != nil {
return nil, fmt.Errorf("failed to marshal batch header (v4) for batch %d: %w", batch.Index, marshalErr)
return nil, fmt.Errorf("failed to marshal batch header (v4) for batch %d: %w", dbBatch.Index, marshalErr)
}
taskDetail.BatchHeader = string(jsonData)
}
Expand Down
26 changes: 0 additions & 26 deletions coordinator/internal/orm/chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,32 +231,6 @@ func (o *Chunk) GetAttemptsByHash(ctx context.Context, hash string) (int16, int1
return chunk.ActiveAttempts, chunk.TotalAttempts, nil
}

// GetChunksInRange retrieves chunks within a given range (inclusive) from the database.
// The range is closed, i.e., it includes both start and end indices.
// The returned chunks are sorted in ascending order by their index.
func (o *Chunk) GetChunksInRange(ctx context.Context, startIndex uint64, endIndex uint64) ([]*Chunk, error) {
if startIndex > endIndex {
return nil, fmt.Errorf("Chunk.GetChunksInRange: start index should be less than or equal to end index, start index: %v, end index: %v", startIndex, endIndex)
}

db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("index >= ? AND index <= ?", startIndex, endIndex)
db = db.Order("index ASC")

var chunks []*Chunk
if err := db.Find(&chunks).Error; err != nil {
return nil, fmt.Errorf("Chunk.GetChunksInRange error: %w, start index: %v, end index: %v", err, startIndex, endIndex)
}

// sanity check
if uint64(len(chunks)) != endIndex-startIndex+1 {
return nil, fmt.Errorf("Chunk.GetChunksInRange: incorrect number of chunks, expected: %v, got: %v, start index: %v, end index: %v", endIndex-startIndex+1, len(chunks), startIndex, endIndex)
}

return chunks, nil
}

// InsertChunk inserts a new chunk into the database.
// for unit test
func (o *Chunk) InsertChunk(ctx context.Context, chunk *encoding.Chunk, dbTX ...*gorm.DB) (*Chunk, error) {
Expand Down

0 comments on commit 2fb7ba3

Please sign in to comment.