From 9bd1fb2020c4aa2b0696fb37a9d9a5477357b415 Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Wed, 8 Oct 2025 20:38:00 -0700 Subject: [PATCH] [ENH]: Add operator to finalize a task's completion --- Cargo.lock | 3 + Cargo.toml | 2 +- examples/task_api_example.py | 3 + go/pkg/sysdb/coordinator/coordinator.go | 12 + go/pkg/sysdb/coordinator/model/collection.go | 5 + go/pkg/sysdb/coordinator/table_catalog.go | 56 +- go/pkg/sysdb/coordinator/task.go | 245 ++++- go/pkg/sysdb/grpc/collection_service.go | 104 +++ go/pkg/sysdb/grpc/task_service.go | 39 + go/pkg/sysdb/metastore/db/dao/collection.go | 34 + go/pkg/sysdb/metastore/db/dao/database.go | 17 + go/pkg/sysdb/metastore/db/dao/task.go | 148 ++- go/pkg/sysdb/metastore/db/dao/task_test.go | 71 +- .../sysdb/metastore/db/dbmodel/collection.go | 2 + go/pkg/sysdb/metastore/db/dbmodel/database.go | 1 + go/pkg/sysdb/metastore/db/dbmodel/task.go | 52 +- .../db/migrations/20251013000000.sql | 7 + .../sysdb/metastore/db/migrations/atlas.sum | 3 +- idl/chromadb/proto/coordinator.proto | 88 +- idl/chromadb/proto/heapservice.proto | 12 + rust/log-service/src/lib.rs | 5 +- rust/log/src/in_memory_log.rs | 2 +- rust/s3heap-service/Cargo.toml | 1 + rust/s3heap-service/src/lib.rs | 183 +++- rust/s3heap-service/src/scheduler.rs | 4 +- rust/sysdb/src/bin/chroma-task-manager.rs | 10 +- rust/sysdb/src/sysdb.rs | 552 +++++++++++- rust/sysdb/src/test_sysdb.rs | 11 + rust/types/src/execution/operator.rs | 2 +- rust/types/src/flush.rs | 108 ++- rust/types/src/task.rs | 124 ++- rust/worker/Cargo.toml | 2 + .../src/compactor/compaction_manager.rs | 268 +++++- rust/worker/src/compactor/config.rs | 77 ++ rust/worker/src/compactor/scheduler.rs | 165 ++-- rust/worker/src/compactor/tasks.rs | 13 +- rust/worker/src/compactor/types.rs | 4 +- rust/worker/src/config.rs | 2 + .../src/execution/operators/execute_task.rs | 192 ++++ .../src/execution/operators/finish_task.rs | 138 +++ .../operators/get_collection_and_segments.rs | 44 +- rust/worker/src/execution/operators/mod.rs | 7 +- .../src/execution/operators/prepare_task.rs | 257 ++++++ .../src/execution/operators/register.rs | 133 ++- .../src/execution/orchestration/compact.rs | 851 ++++++++++++++++-- rust/worker/src/lib.rs | 34 + 46 files changed, 3754 insertions(+), 339 deletions(-) create mode 100644 go/pkg/sysdb/metastore/db/migrations/20251013000000.sql create mode 100644 rust/worker/src/execution/operators/execute_task.rs create mode 100644 rust/worker/src/execution/operators/finish_task.rs create mode 100644 rust/worker/src/execution/operators/prepare_task.rs diff --git a/Cargo.lock b/Cargo.lock index bcb12e5b477..ca386a36f8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7176,6 +7176,7 @@ dependencies = [ "tokio", "tonic", "tonic-health", + "tower 0.4.13", "tracing", "uuid", "wal3", @@ -9994,6 +9995,7 @@ dependencies = [ "chroma-config", "chroma-distance", "chroma-error", + "chroma-frontend", "chroma-index", "chroma-jemalloc-pprof-server", "chroma-log", @@ -10024,6 +10026,7 @@ dependencies = [ "rand_xorshift", "random-port", "regex", + "reqwest", "roaring", "s3heap", "s3heap-service", diff --git a/Cargo.toml b/Cargo.toml index fdde38326e4..75fc8d1baf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,7 @@ tracing = { version = "0.1" } tracing-bunyan-formatter = "0.3" tracing-opentelemetry = "0.28.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -uuid = { version = "1.11.0", features = ["v4", "fast-rng", "macro-diagnostics", "serde"] } +uuid = { version = "1.11.0", features = ["v4", "v7", 
"fast-rng", "macro-diagnostics", "serde"] } utoipa = { version = "5.0.0", features = ["macros", "axum_extras", "debug", "uuid"] } sqlx = { version = "0.8.3", features = ["runtime-tokio", "sqlite", "postgres", "chrono"] } sha2 = "0.10.8" diff --git a/examples/task_api_example.py b/examples/task_api_example.py index f71fd911d35..a5f8aeec395 100644 --- a/examples/task_api_example.py +++ b/examples/task_api_example.py @@ -7,6 +7,7 @@ """ import chromadb +import time # Connect to Chroma server client = chromadb.HttpClient(host="localhost", port=8000) @@ -60,6 +61,8 @@ print("Task is now registered and will run on new data!") print("=" * 60) +time.sleep(10) + # Add more documents to trigger task execution print("\nAdding more documents...") collection.add( diff --git a/go/pkg/sysdb/coordinator/coordinator.go b/go/pkg/sysdb/coordinator/coordinator.go index e60198edf44..dc123a01d0f 100644 --- a/go/pkg/sysdb/coordinator/coordinator.go +++ b/go/pkg/sysdb/coordinator/coordinator.go @@ -12,6 +12,7 @@ import ( "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" s3metastore "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/s3" "github.com/chroma-core/chroma/go/pkg/types" + "github.com/google/uuid" "github.com/pingcap/log" "go.uber.org/zap" ) @@ -243,6 +244,17 @@ func (s *Coordinator) FlushCollectionCompaction(ctx context.Context, flushCollec return s.catalog.FlushCollectionCompaction(ctx, flushCollectionCompaction) } +func (s *Coordinator) FlushCollectionCompactionAndTask( + ctx context.Context, + flushCollectionCompaction *model.FlushCollectionCompaction, + taskID uuid.UUID, + taskRunNonce uuid.UUID, + completionOffset int64, + nextRunDelaySecs uint64, +) (*model.FlushCollectionInfo, error) { + return s.catalog.FlushCollectionCompactionAndTask(ctx, flushCollectionCompaction, taskID, taskRunNonce, completionOffset, nextRunDelaySecs) +} + func (s *Coordinator) ListCollectionsToGc(ctx context.Context, cutoffTimeSecs *uint64, limit *uint64, tenantID *string, minVersionsIfAlive *uint64) ([]*model.CollectionToGc, error) { return s.catalog.ListCollectionsToGc(ctx, cutoffTimeSecs, limit, tenantID, minVersionsIfAlive) } diff --git a/go/pkg/sysdb/coordinator/model/collection.go b/go/pkg/sysdb/coordinator/model/collection.go index 4cf8a543897..49dd42a80d5 100644 --- a/go/pkg/sysdb/coordinator/model/collection.go +++ b/go/pkg/sysdb/coordinator/model/collection.go @@ -4,6 +4,7 @@ import ( "time" "github.com/chroma-core/chroma/go/pkg/types" + "github.com/google/uuid" ) type Collection struct { @@ -98,6 +99,10 @@ type FlushCollectionInfo struct { ID string CollectionVersion int32 TenantLastCompactionTime int64 + // Optional task fields (only populated for task-based compactions) + TaskNextNonce *uuid.UUID + TaskNextRun *time.Time + TaskCompletionOffset *int64 } func FilterCollection(collection *Collection, collectionID types.UniqueID, collectionName *string) bool { diff --git a/go/pkg/sysdb/coordinator/table_catalog.go b/go/pkg/sysdb/coordinator/table_catalog.go index 99ff7a4e3ba..d4216077ed8 100644 --- a/go/pkg/sysdb/coordinator/table_catalog.go +++ b/go/pkg/sysdb/coordinator/table_catalog.go @@ -280,7 +280,7 @@ func (tc *Catalog) createCollectionImpl(txCtx context.Context, createCollection return nil, false, err } if len(databases) == 0 { - log.Error("database not found", zap.Error(err)) + log.Error("database not found for database", zap.String("database_name", databaseName), zap.String("tenant_id", tenantID)) return nil, false, common.ErrDatabaseNotFound } @@ -1343,7 +1343,7 @@ func (tc *Catalog) 
CreateCollectionAndSegments(ctx context.Context, createCollec return nil, false, err } if len(databases) == 0 { - log.Error("database not found", zap.Error(err)) + log.Error("database not found for database", zap.String("database_name", createCollection.DatabaseName), zap.String("tenant_id", createCollection.TenantID)) return nil, false, common.ErrDatabaseNotFound } @@ -1719,6 +1719,58 @@ func (tc *Catalog) FlushCollectionCompaction(ctx context.Context, flushCollectio return flushCollectionInfo, nil } +// FlushCollectionCompactionAndTask atomically updates collection compaction data and task completion offset. +// NOTE: This does NOT advance next_nonce - that is done separately by AdvanceTask in PrepareTask. +// This only updates the completion_offset to record how far we've processed. +// This is only supported for versioned collections (the modern/default path). +func (tc *Catalog) FlushCollectionCompactionAndTask( + ctx context.Context, + flushCollectionCompaction *model.FlushCollectionCompaction, + taskID uuid.UUID, + taskRunNonce uuid.UUID, + completionOffset int64, + nextRunDelaySecs uint64, +) (*model.FlushCollectionInfo, error) { + if !tc.versionFileEnabled { + // Task-based compactions are only supported with versioned collections + log.Error("FlushCollectionCompactionAndTask is only supported for versioned collections") + return nil, errors.New("task-based compaction requires versioned collections") + } + + var flushCollectionInfo *model.FlushCollectionInfo + + err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error { + var err error + flushCollectionInfo, err = tc.FlushCollectionCompactionForVersionedCollection(txCtx, flushCollectionCompaction) + if err != nil { + return err + } + + // Update ONLY completion_offset - next_nonce was already advanced in PrepareTask + // We still validate taskRunNonce to ensure we're updating the correct epoch + err = tc.metaDomain.TaskDb(txCtx).UpdateCompletionOffset(taskID, taskRunNonce, completionOffset) + if err != nil { + return err + } + + return nil + }) + + if err != nil { + return nil, err + } + + // Populate task fields with authoritative values from database + flushCollectionInfo.TaskCompletionOffset = &completionOffset + + log.Info("FlushCollectionCompactionAndTask", + zap.String("collection_id", flushCollectionCompaction.ID.String()), + zap.String("task_id", taskID.String()), + zap.Int64("completion_offset", completionOffset)) + + return flushCollectionInfo, nil +} + func (tc *Catalog) validateVersionFile(versionFile *coordinatorpb.CollectionVersionFile, collectionID string, version int64) error { if versionFile.GetCollectionInfoImmutable().GetCollectionId() != collectionID { log.Error("collection id mismatch", zap.String("collection_id", collectionID), zap.String("version_file_collection_id", versionFile.GetCollectionInfoImmutable().GetCollectionId())) diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index eb6586a93fb..ac0518780d2 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -123,12 +123,13 @@ func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateT OperatorParams: paramsJSON, CompletionOffset: 0, LastRun: nil, - NextRun: &now, + NextRun: now, // Initialize to current time, will be scheduled by task scheduler MinRecordsForTask: int64(req.MinRecordsForTask), CurrentAttempts: 0, CreatedAt: now, UpdatedAt: now, NextNonce: nextNonce, + LowestLiveNonce: &nextNonce, // Initialize to same value as NextNonce OldestWrittenNonce: nil, 
} @@ -194,24 +195,214 @@ func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetT } } - // Convert task to response - response := &coordinatorpb.GetTaskByNameResponse{ - TaskId: proto.String(task.ID.String()), - Name: proto.String(task.Name), - OperatorName: proto.String(operator.OperatorName), - InputCollectionId: proto.String(task.InputCollectionID), - OutputCollectionName: proto.String(task.OutputCollectionName), + // Convert task to response with nested Task message + taskProto := &coordinatorpb.Task{ + TaskId: task.ID.String(), + Name: task.Name, + OperatorName: operator.OperatorName, + InputCollectionId: task.InputCollectionID, + OutputCollectionName: task.OutputCollectionName, Params: paramsStruct, - CompletionOffset: proto.Int64(task.CompletionOffset), - MinRecordsForTask: proto.Uint64(uint64(task.MinRecordsForTask)), - TenantId: proto.String(task.TenantID), - DatabaseId: proto.String(task.DatabaseID), + CompletionOffset: task.CompletionOffset, + MinRecordsForTask: uint64(task.MinRecordsForTask), + TenantId: task.TenantID, + DatabaseId: task.DatabaseID, + NextRunAt: uint64(task.NextRun.UnixMicro()), + LowestLiveNonce: "", + NextNonce: task.NextNonce.String(), + } + // Add lowest_live_nonce if it's set + if task.LowestLiveNonce != nil { + taskProto.LowestLiveNonce = task.LowestLiveNonce.String() + } + // Add output_collection_id if it's set + if task.OutputCollectionID != nil { + taskProto.OutputCollectionId = task.OutputCollectionID + } + + return &coordinatorpb.GetTaskByNameResponse{ + Task: taskProto, + }, nil +} + +// GetTaskByUuid retrieves a task by UUID from the database +func (s *Coordinator) GetTaskByUuid(ctx context.Context, req *coordinatorpb.GetTaskByUuidRequest) (*coordinatorpb.GetTaskByUuidResponse, error) { + // Parse the task UUID + taskID, err := uuid.Parse(req.TaskId) + if err != nil { + log.Error("GetTaskByUuid: invalid task_id", zap.Error(err)) + return nil, status.Errorf(codes.InvalidArgument, "invalid task_id: %v", err) + } + + // Fetch task by ID + task, err := s.catalog.metaDomain.TaskDb(ctx).GetByID(taskID) + if err != nil { + return nil, err + } + + // If task not found, return error + if task == nil { + return nil, common.ErrTaskNotFound + } + + // Look up operator name from operators table + operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByID(task.OperatorID) + if err != nil { + log.Error("GetTaskByUuid: failed to get operator", zap.Error(err)) + return nil, err + } + if operator == nil { + log.Error("GetTaskByUuid: operator not found", zap.String("operator_id", task.OperatorID.String())) + return nil, common.ErrOperatorNotFound + } + + // Debug logging + log.Info("Found task by UUID", zap.String("task_id", task.ID.String()), zap.String("name", task.Name), zap.String("input_collection_id", task.InputCollectionID), zap.String("output_collection_name", task.OutputCollectionName)) + + // Deserialize params from JSON string to protobuf Struct + var paramsStruct *structpb.Struct + if task.OperatorParams != "" { + paramsStruct = &structpb.Struct{} + if err := paramsStruct.UnmarshalJSON([]byte(task.OperatorParams)); err != nil { + log.Error("GetTaskByUuid: failed to unmarshal params", zap.Error(err)) + return nil, err + } + } + + // Convert task to response with nested Task message + taskProto := &coordinatorpb.Task{ + TaskId: task.ID.String(), + Name: task.Name, + OperatorName: operator.OperatorName, + InputCollectionId: task.InputCollectionID, + OutputCollectionName: task.OutputCollectionName, + Params: paramsStruct, + 
CompletionOffset: task.CompletionOffset, + MinRecordsForTask: uint64(task.MinRecordsForTask), + TenantId: task.TenantID, + DatabaseId: task.DatabaseID, + NextRunAt: uint64(task.NextRun.UnixMicro()), + LowestLiveNonce: "", + NextNonce: task.NextNonce.String(), + } + // Add lowest_live_nonce if it's set + if task.LowestLiveNonce != nil { + taskProto.LowestLiveNonce = task.LowestLiveNonce.String() } // Add output_collection_id if it's set if task.OutputCollectionID != nil { - response.OutputCollectionId = task.OutputCollectionID + taskProto.OutputCollectionId = task.OutputCollectionID } - return response, nil + + return &coordinatorpb.GetTaskByUuidResponse{ + Task: taskProto, + }, nil +} + +// CreateOutputCollectionForTask atomically creates an output collection and updates the task's output_collection_id +func (s *Coordinator) CreateOutputCollectionForTask(ctx context.Context, req *coordinatorpb.CreateOutputCollectionForTaskRequest) (*coordinatorpb.CreateOutputCollectionForTaskResponse, error) { + var collectionID types.UniqueID + + // Execute all operations in a transaction for atomicity + err := s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error { + // 1. Parse task ID + taskID, err := uuid.Parse(req.TaskId) + if err != nil { + log.Error("CreateOutputCollectionForTask: invalid task_id", zap.Error(err)) + return status.Errorf(codes.InvalidArgument, "invalid task_id: %v", err) + } + + // 2. Get the task to verify it exists and doesn't already have an output collection + task, err := s.catalog.metaDomain.TaskDb(txCtx).GetByID(taskID) + if err != nil { + log.Error("CreateOutputCollectionForTask: failed to get task", zap.Error(err)) + return err + } + if task == nil { + log.Error("CreateOutputCollectionForTask: task not found") + return status.Errorf(codes.NotFound, "task not found") + } + + // Check if output collection already exists + if task.OutputCollectionID != nil && *task.OutputCollectionID != "" { + log.Error("CreateOutputCollectionForTask: output collection already exists", + zap.String("existing_collection_id", *task.OutputCollectionID)) + return status.Errorf(codes.AlreadyExists, "output collection already exists") + } + + // 3. Generate new collection UUID + collectionID = types.NewUniqueID() + + // 4. Look up database by ID to get its name + database, err := s.catalog.metaDomain.DatabaseDb(txCtx).GetByID(req.DatabaseId) + if err != nil { + log.Error("CreateOutputCollectionForTask: failed to get database", zap.Error(err)) + return err + } + if database == nil { + log.Error("CreateOutputCollectionForTask: database not found", zap.String("database_id", req.DatabaseId), zap.String("tenant_id", req.TenantId)) + return common.ErrDatabaseNotFound + } + + // 5. 
Create the collection with segments
+ // Set a default dimension to ensure segment writers can be initialized
+ dimension := int32(1) // Default dimension for task output collections
+ collection := &model.CreateCollection{
+ ID: collectionID,
+ Name: req.CollectionName,
+ ConfigurationJsonStr: "{}", // Empty JSON object for default config
+ TenantID: req.TenantId,
+ DatabaseName: database.Name,
+ Dimension: &dimension,
+ Metadata: nil,
+ }
+
+ // Create segments for the collection (distributed setup)
+ segments := []*model.Segment{
+ {
+ ID: types.NewUniqueID(),
+ Type: "urn:chroma:segment/vector/hnsw-distributed",
+ Scope: "VECTOR",
+ CollectionID: collectionID,
+ },
+ {
+ ID: types.NewUniqueID(),
+ Type: "urn:chroma:segment/metadata/blockfile",
+ Scope: "METADATA",
+ CollectionID: collectionID,
+ },
+ {
+ ID: types.NewUniqueID(),
+ Type: "urn:chroma:segment/record/blockfile",
+ Scope: "RECORD",
+ CollectionID: collectionID,
+ },
+ }
+
+ _, _, err = s.catalog.CreateCollectionAndSegments(txCtx, collection, segments, 0)
+ if err != nil {
+ log.Error("CreateOutputCollectionForTask: failed to create collection", zap.Error(err))
+ return err
+ }
+
+ // 6. Update task with output_collection_id
+ collectionIDStr := collectionID.String()
+ err = s.catalog.metaDomain.TaskDb(txCtx).UpdateOutputCollectionID(taskID, &collectionIDStr)
+ if err != nil {
+ log.Error("CreateOutputCollectionForTask: failed to update task", zap.Error(err))
+ return err
+ }
+
+ return nil
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ return &coordinatorpb.CreateOutputCollectionForTaskResponse{
+ CollectionId: collectionID.String(),
+ }, nil
}

// DeleteTask soft deletes a task by name
@@ -289,13 +480,17 @@ func (s *Coordinator) AdvanceTask(ctx context.Context, req *coordinatorpb.Advanc
return nil, status.Errorf(codes.InvalidArgument, "invalid task_run_nonce: %v", err)
}

- err = s.catalog.metaDomain.TaskDb(ctx).AdvanceTask(taskID, taskRunNonce)
+ advanceTask, err := s.catalog.metaDomain.TaskDb(ctx).AdvanceTask(taskID, taskRunNonce, *req.CompletionOffset, *req.NextRunDelaySecs)
if err != nil {
log.Error("AdvanceTask failed", zap.Error(err), zap.String("task_id", taskID.String()))
return nil, err
}

- return &coordinatorpb.AdvanceTaskResponse{}, nil
+ return &coordinatorpb.AdvanceTaskResponse{
+ NextRunNonce: advanceTask.NextNonce.String(),
+ NextRunAt: uint64(advanceTask.NextRun.UnixMilli()),
+ CompletionOffset: advanceTask.CompletionOffset,
+ }, nil
}

// GetOperators retrieves all operators from the database
@@ -340,7 +535,7 @@ func (s *Coordinator) PeekScheduleByCollectionId(ctx context.Context, req *coord
TaskRunNonce: proto.String(task.NextNonce.String()),
WhenToRun: nil,
}
- if task.NextRun != nil {
+ if !task.NextRun.IsZero() {
whenToRun := uint64(task.NextRun.UnixMilli())
entry.WhenToRun = &whenToRun
}
@@ -351,3 +546,19 @@ func (s *Coordinator) PeekScheduleByCollectionId(ctx context.Context, req *coord
Schedule: scheduleEntries,
}, nil
}
+
+func (s *Coordinator) FinishTask(ctx context.Context, req *coordinatorpb.FinishTaskRequest) (*coordinatorpb.FinishTaskResponse, error) {
+ taskID, err := uuid.Parse(req.TaskId)
+ if err != nil {
+ log.Error("FinishTask: invalid task_id", zap.Error(err))
+ return nil, err
+ }
+
+ err = s.catalog.metaDomain.TaskDb(ctx).FinishTask(taskID)
+ if err != nil {
+ log.Error("FinishTask: failed to finish task", zap.Error(err))
+ return nil, err
+ }
+
+ return &coordinatorpb.FinishTaskResponse{}, nil
+}
diff --git a/go/pkg/sysdb/grpc/collection_service.go 
b/go/pkg/sysdb/grpc/collection_service.go index e3017b0d5e1..506e1fdd9f9 100644 --- a/go/pkg/sysdb/grpc/collection_service.go +++ b/go/pkg/sysdb/grpc/collection_service.go @@ -10,9 +10,11 @@ import ( "github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb" "github.com/chroma-core/chroma/go/pkg/sysdb/coordinator/model" "github.com/chroma-core/chroma/go/pkg/types" + "github.com/google/uuid" "github.com/pingcap/log" "go.uber.org/zap" "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" ) func (s *Server) ResetState(context.Context, *emptypb.Empty) (*coordinatorpb.ResetStateResponse, error) { @@ -570,6 +572,108 @@ func (s *Server) FlushCollectionCompaction(ctx context.Context, req *coordinator return res, nil } +func (s *Server) FlushCollectionCompactionAndTask(ctx context.Context, req *coordinatorpb.FlushCollectionCompactionAndTaskRequest) (*coordinatorpb.FlushCollectionCompactionAndTaskResponse, error) { + // Parse the flush compaction request (nested message) + flushReq := req.GetFlushCompaction() + if flushReq == nil { + log.Error("FlushCollectionCompactionAndTask failed. flush_compaction is nil") + return nil, grpcutils.BuildInternalGrpcError("flush_compaction is required") + } + + // Parse task update info + taskUpdate := req.GetTaskUpdate() + if taskUpdate == nil { + log.Error("FlushCollectionCompactionAndTask failed. task_update is nil") + return nil, grpcutils.BuildInternalGrpcError("task_update is required") + } + + taskID, err := uuid.Parse(taskUpdate.TaskId) + if err != nil { + log.Error("FlushCollectionCompactionAndTask failed. error parsing task id", zap.Error(err), zap.String("task_id", taskUpdate.TaskId)) + return nil, grpcutils.BuildInternalGrpcError("invalid task_id: " + err.Error()) + } + + taskRunNonce, err := uuid.Parse(taskUpdate.TaskRunNonce) + if err != nil { + log.Error("FlushCollectionCompactionAndTask failed. error parsing task run nonce", zap.Error(err), zap.String("task_run_nonce", taskUpdate.TaskRunNonce)) + return nil, grpcutils.BuildInternalGrpcError("invalid task_run_nonce: " + err.Error()) + } + + // Parse collection and segment info (reuse logic from FlushCollectionCompaction) + collectionID, err := types.ToUniqueID(&flushReq.CollectionId) + err = grpcutils.BuildErrorForUUID(collectionID, "collection", err) + if err != nil { + log.Error("FlushCollectionCompactionAndTask failed. error parsing collection id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + + segmentCompactionInfo := make([]*model.FlushSegmentCompaction, 0, len(flushReq.SegmentCompactionInfo)) + for _, flushSegmentCompaction := range flushReq.SegmentCompactionInfo { + segmentID, err := types.ToUniqueID(&flushSegmentCompaction.SegmentId) + err = grpcutils.BuildErrorForUUID(segmentID, "segment", err) + if err != nil { + log.Error("FlushCollectionCompactionAndTask failed. 
error parsing segment id", zap.Error(err), zap.String("collection_id", flushReq.CollectionId)) + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + filePaths := make(map[string][]string) + for key, filePath := range flushSegmentCompaction.FilePaths { + filePaths[key] = filePath.Paths + } + segmentCompactionInfo = append(segmentCompactionInfo, &model.FlushSegmentCompaction{ + ID: segmentID, + FilePaths: filePaths, + }) + } + + flushCollectionCompaction := &model.FlushCollectionCompaction{ + ID: collectionID, + TenantID: flushReq.TenantId, + LogPosition: flushReq.LogPosition, + CurrentCollectionVersion: flushReq.CollectionVersion, + FlushSegmentCompactions: segmentCompactionInfo, + TotalRecordsPostCompaction: flushReq.TotalRecordsPostCompaction, + SizeBytesPostCompaction: flushReq.SizeBytesPostCompaction, + } + + flushCollectionInfo, err := s.coordinator.FlushCollectionCompactionAndTask( + ctx, + flushCollectionCompaction, + taskID, + taskRunNonce, + taskUpdate.CompletionOffset, + taskUpdate.NextRunDelaySecs, + ) + if err != nil { + log.Error("FlushCollectionCompactionAndTask failed", zap.Error(err), zap.String("collection_id", flushReq.CollectionId), zap.String("task_id", taskUpdate.TaskId)) + if err == common.ErrCollectionSoftDeleted { + return nil, grpcutils.BuildFailedPreconditionGrpcError(err.Error()) + } + if err == common.ErrTaskNotFound { + return nil, grpcutils.BuildNotFoundGrpcError(err.Error()) + } + return nil, grpcutils.BuildInternalGrpcError(err.Error()) + } + + res := &coordinatorpb.FlushCollectionCompactionAndTaskResponse{ + CollectionId: flushCollectionInfo.ID, + CollectionVersion: flushCollectionInfo.CollectionVersion, + LastCompactionTime: flushCollectionInfo.TenantLastCompactionTime, + } + + // Populate task fields with authoritative values from database + if flushCollectionInfo.TaskNextNonce != nil { + res.NextNonce = flushCollectionInfo.TaskNextNonce.String() + } + if flushCollectionInfo.TaskNextRun != nil { + res.NextRun = timestamppb.New(*flushCollectionInfo.TaskNextRun) + } + if flushCollectionInfo.TaskCompletionOffset != nil { + res.CompletionOffset = *flushCollectionInfo.TaskCompletionOffset + } + + return res, nil +} + func (s *Server) ListCollectionsToGc(ctx context.Context, req *coordinatorpb.ListCollectionsToGcRequest) (*coordinatorpb.ListCollectionsToGcResponse, error) { absoluteCutoffTimeSecs := (*uint64)(nil) if req.CutoffTime != nil { diff --git a/go/pkg/sysdb/grpc/task_service.go b/go/pkg/sysdb/grpc/task_service.go index 0b2415493df..bb5a6a4a0a4 100644 --- a/go/pkg/sysdb/grpc/task_service.go +++ b/go/pkg/sysdb/grpc/task_service.go @@ -40,6 +40,33 @@ func (s *Server) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskBy return res, nil } +func (s *Server) GetTaskByUuid(ctx context.Context, req *coordinatorpb.GetTaskByUuidRequest) (*coordinatorpb.GetTaskByUuidResponse, error) { + log.Info("GetTaskByUuid", zap.String("task_id", req.TaskId)) + + res, err := s.coordinator.GetTaskByUuid(ctx, req) + if err != nil { + log.Error("GetTaskByUuid failed", zap.Error(err)) + if err == common.ErrTaskNotFound { + return nil, grpcutils.BuildNotFoundGrpcError(err.Error()) + } + return nil, err + } + + return res, nil +} + +func (s *Server) CreateOutputCollectionForTask(ctx context.Context, req *coordinatorpb.CreateOutputCollectionForTaskRequest) (*coordinatorpb.CreateOutputCollectionForTaskResponse, error) { + log.Info("CreateOutputCollectionForTask", zap.String("task_id", req.TaskId), zap.String("collection_name", req.CollectionName)) + + res, err := 
s.coordinator.CreateOutputCollectionForTask(ctx, req) + if err != nil { + log.Error("CreateOutputCollectionForTask failed", zap.Error(err)) + return nil, err + } + + return res, nil +} + func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) { log.Info("DeleteTask", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName)) @@ -64,6 +91,18 @@ func (s *Server) AdvanceTask(ctx context.Context, req *coordinatorpb.AdvanceTask return res, nil } +func (s *Server) FinishTask(ctx context.Context, req *coordinatorpb.FinishTaskRequest) (*coordinatorpb.FinishTaskResponse, error) { + log.Info("FinishTask", zap.String("task_id", req.TaskId)) + + res, err := s.coordinator.FinishTask(ctx, req) + if err != nil { + log.Error("FinishTask failed", zap.Error(err)) + return nil, err + } + + return res, nil +} + func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) { log.Info("GetOperators") diff --git a/go/pkg/sysdb/metastore/db/dao/collection.go b/go/pkg/sysdb/metastore/db/dao/collection.go index 8654e54716a..bf7ab31e491 100644 --- a/go/pkg/sysdb/metastore/db/dao/collection.go +++ b/go/pkg/sysdb/metastore/db/dao/collection.go @@ -566,6 +566,40 @@ func (s *collectionDb) UpdateLogPositionAndVersionInfo( return result.RowsAffected, nil } +func (s *collectionDb) UpdateVersionInfo( + collectionID string, + currentCollectionVersion int32, + currentVersionFileName string, + newCollectionVersion int32, + newVersionFileName string, + totalRecordsPostCompaction uint64, + sizeBytesPostCompaction uint64, + lastCompactionTimeSecs uint64, + numVersions uint64, +) (int64, error) { + // Similar to UpdateLogPositionAndVersionInfo but does NOT update log_position + // Used for task-based flushes where the collection's log position should remain unchanged + result := s.db.Model(&dbmodel.Collection{}). + Clauses(clause.Locking{Strength: "UPDATE"}). + Where("id = ? AND version = ? AND (version_file_name IS NULL OR version_file_name = ?)", + collectionID, + currentCollectionVersion, + currentVersionFileName). 
+ Updates(map[string]interface{}{ + // NOTE: log_position is NOT updated here + "version": newCollectionVersion, + "version_file_name": newVersionFileName, + "total_records_post_compaction": totalRecordsPostCompaction, + "size_bytes_post_compaction": sizeBytesPostCompaction, + "last_compaction_time_secs": lastCompactionTimeSecs, + "num_versions": numVersions, + }) + if result.Error != nil { + return 0, result.Error + } + return result.RowsAffected, nil +} + func (s *collectionDb) UpdateLogPositionVersionTotalRecordsAndLogicalSize(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64, sizeBytesPostCompaction uint64, lastCompactionTimeSecs uint64, tenant string, schemaStr *string) (int32, error) { log.Info("update log position, version, and total records post compaction", zap.String("collectionID", collectionID), zap.Int64("logPosition", logPosition), zap.Int32("currentCollectionVersion", currentCollectionVersion), zap.Uint64("totalRecords", totalRecordsPostCompaction)) var collection dbmodel.Collection diff --git a/go/pkg/sysdb/metastore/db/dao/database.go b/go/pkg/sysdb/metastore/db/dao/database.go index 01a1f71f811..2e9076e2f8a 100644 --- a/go/pkg/sysdb/metastore/db/dao/database.go +++ b/go/pkg/sysdb/metastore/db/dao/database.go @@ -67,6 +67,23 @@ func (s *databaseDb) GetDatabases(tenantID string, databaseName string) ([]*dbmo return databases, nil } +func (s *databaseDb) GetByID(databaseID string) (*dbmodel.Database, error) { + var database dbmodel.Database + query := s.db.Table("databases"). + Select("databases.id, databases.name, databases.tenant_id"). + Where("databases.id = ?", databaseID). + Where("databases.is_deleted = ?", false) + + if err := query.First(&database).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + log.Error("GetByID", zap.Error(err)) + return nil, err + } + return &database, nil +} + func (s *databaseDb) Insert(database *dbmodel.Database) error { err := s.db.Create(database).Error if err != nil { diff --git a/go/pkg/sysdb/metastore/db/dao/task.go b/go/pkg/sysdb/metastore/db/dao/task.go index 1c0b07d7ffe..2626af4c0f4 100644 --- a/go/pkg/sysdb/metastore/db/dao/task.go +++ b/go/pkg/sysdb/metastore/db/dao/task.go @@ -77,32 +77,23 @@ func (s *taskDb) GetByID(taskID uuid.UUID) (*dbmodel.Task, error) { return &task, nil } -func (s *taskDb) AdvanceTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error { - nextNonce, err := uuid.NewV7() - if err != nil { - log.Error("AdvanceTask: failed to generate next nonce", zap.Error(err)) - return err - } - +func (s *taskDb) UpdateOutputCollectionID(taskID uuid.UUID, outputCollectionID *string) error { now := time.Now() result := s.db.Exec(` UPDATE tasks - SET next_nonce = ?, - updated_at = GREATEST(updated_at, GREATEST(?, last_run)), - last_run = ?, - current_attempts = 0 + SET output_collection_id = ?, + updated_at = ? WHERE task_id = ? - AND next_nonce = ? 
AND is_deleted = false - `, nextNonce, now, now, taskID, taskRunNonce) + `, outputCollectionID, now, taskID) if result.Error != nil { - log.Error("AdvanceTask failed", zap.Error(result.Error), zap.String("task_id", taskID.String())) + log.Error("UpdateOutputCollectionID failed", zap.Error(result.Error), zap.String("task_id", taskID.String())) return result.Error } if result.RowsAffected == 0 { - log.Warn("AdvanceTask: no rows affected", zap.String("task_id", taskID.String()), zap.String("task_run_nonce", taskRunNonce.String())) + log.Error("UpdateOutputCollectionID: no rows affected", zap.String("task_id", taskID.String())) return common.ErrTaskNotFound } @@ -134,6 +125,110 @@ func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error { return nil } +// AdvanceTask updates task progress after register operator completes +// This bumps next_nonce and updates completion_offset/next_run +// Returns the authoritative values from the database +func (s *taskDb) AdvanceTask(taskID uuid.UUID, taskRunNonce uuid.UUID, completionOffset int64, nextRunDelaySecs uint64) (*dbmodel.AdvanceTask, error) { + nextNonce, err := uuid.NewV7() + if err != nil { + log.Error("AdvanceTask: failed to generate next nonce", zap.Error(err)) + return nil, err + } + now := time.Now() + // Bump next_nonce to mark a new epoch, but don't touch lowest_live_nonce yet + // lowest_live_nonce will be updated later by finish_task when verification completes + result := s.db.Model(&dbmodel.Task{}).Where("task_id = ?", taskID).Where("is_deleted = false").Where("next_nonce = ?", taskRunNonce).UpdateColumns(map[string]interface{}{ + "completion_offset": completionOffset, + "next_run": now.Add(time.Duration(nextRunDelaySecs) * time.Second), + "last_run": now, + "next_nonce": nextNonce, + "updated_at": gorm.Expr("GREATEST(updated_at, GREATEST(last_run, ?))", now), + }) + + if result.Error != nil { + log.Error("AdvanceTask failed", zap.Error(result.Error), zap.String("task_id", taskID.String())) + return nil, result.Error + } + + if result.RowsAffected == 0 { + log.Error("AdvanceTask: no rows affected", zap.String("task_id", taskID.String())) + return nil, common.ErrTaskNotFound + } + + // get first result row + var task dbmodel.Task + if err := result.Scan(&task).Error; err != nil { + log.Error("AdvanceTask: failed to scan result", zap.Error(err), zap.String("task_id", taskID.String())) + return nil, err + } + + // Return the authoritative values that were written to the database + return &dbmodel.AdvanceTask{ + NextNonce: nextNonce, + NextRun: task.NextRun, + CompletionOffset: task.CompletionOffset, + }, nil +} + +// UpdateCompletionOffset updates ONLY the completion_offset for a task +// This is called during flush_compaction_and_task after work is done +// NOTE: Does NOT update next_nonce (that was done earlier by AdvanceTask in PrepareTask) +func (s *taskDb) UpdateCompletionOffset(taskID uuid.UUID, taskRunNonce uuid.UUID, completionOffset int64) error { + now := time.Now() + // Update only completion_offset and last_run + // Validate that we're updating the correct epoch by checking lowest_live_nonce = taskRunNonce + // This ensures we're updating the completion offset for the exact epoch we're working on + result := s.db.Model(&dbmodel.Task{}). + Where("task_id = ?", taskID). + Where("is_deleted = false"). + Where("lowest_live_nonce = ?", taskRunNonce). 
// Ensure we're updating the correct epoch + UpdateColumns(map[string]interface{}{ + "completion_offset": completionOffset, + "last_run": now, + "updated_at": now, + }) + + if result.Error != nil { + log.Error("UpdateCompletionOffset failed", zap.Error(result.Error), zap.String("task_id", taskID.String())) + return result.Error + } + + if result.RowsAffected == 0 { + log.Error("UpdateCompletionOffset: no rows affected - task not found or wrong epoch", zap.String("task_id", taskID.String()), zap.String("task_run_nonce", taskRunNonce.String())) + return common.ErrTaskNotFound + } + + return nil +} + +// FinishTask updates lowest_live_nonce to mark the current epoch as verified +// This is called by the finish_task operator after scout_logs recheck completes +func (s *taskDb) FinishTask(taskID uuid.UUID) error { + now := time.Now() + // Set lowest_live_nonce = next_nonce to indicate this epoch is fully verified + // If this fails, lowest_live_nonce < next_nonce will signal that we should skip + // execution next time and only run the recheck phase + result := s.db.Exec(` + UPDATE tasks + SET lowest_live_nonce = next_nonce, + updated_at = ? + WHERE task_id = ? + AND is_deleted = false + `, now, taskID) + + if result.Error != nil { + log.Error("FinishTask failed", zap.Error(result.Error), zap.String("task_id", taskID.String())) + return result.Error + } + + if result.RowsAffected == 0 { + log.Error("FinishTask: no rows affected", zap.String("task_id", taskID.String())) + return common.ErrTaskNotFound + } + + return nil +} + func (s *taskDb) PeekScheduleByCollectionId(collectionIDs []string) ([]*dbmodel.Task, error) { var tasks []*dbmodel.Task err := s.db. @@ -147,3 +242,26 @@ func (s *taskDb) PeekScheduleByCollectionId(collectionIDs []string) ([]*dbmodel. } return tasks, nil } + +// GetMinCompletionOffsetForCollection returns the minimum completion_offset for all non-deleted tasks +// with the given input_collection_id. Returns nil if no tasks exist for the collection. +func (s *taskDb) GetMinCompletionOffsetForCollection(inputCollectionID string) (*int64, error) { + var result struct { + MinOffset *int64 + } + + err := s.db.Model(&dbmodel.Task{}). + Select("MIN(completion_offset) as min_offset"). + Where("input_collection_id = ?", inputCollectionID). + Where("is_deleted = ?", false). 
+ Scan(&result).Error + + if err != nil { + log.Error("GetMinCompletionOffsetForCollection failed", + zap.Error(err), + zap.String("input_collection_id", inputCollectionID)) + return nil, err + } + + return result.MinOffset, nil +} diff --git a/go/pkg/sysdb/metastore/db/dao/task_test.go b/go/pkg/sysdb/metastore/db/dao/task_test.go index 3497c0df316..62e2fbcf185 100644 --- a/go/pkg/sysdb/metastore/db/dao/task_test.go +++ b/go/pkg/sysdb/metastore/db/dao/task_test.go @@ -445,11 +445,80 @@ func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask_InvalidNonce() { } func (suite *TaskDbTestSuite) TestTaskDb_AdvanceTask_NotFound() { - err := suite.Db.AdvanceTask(uuid.New(), uuid.Must(uuid.NewV7())) + err := suite.Db.AdvanceTask(uuid.New(), uuid.Must(uuid.NewV7()), 0, 0) suite.Require().Error(err) suite.Require().Equal(common.ErrTaskNotFound, err) } +func (suite *TaskDbTestSuite) TestTaskDb_UpdateCompletionOffset() { + taskID := uuid.New() + operatorID := dbmodel.OperatorRecordCounter + originalNonce, _ := uuid.NewV7() + + task := &dbmodel.Task{ + ID: taskID, + Name: "test_update_completion_task", + OperatorID: operatorID, + InputCollectionID: "input_collection_1", + OutputCollectionID: nil, + OutputCollectionName: "output_collection_1", + TenantID: "tenant_1", + DatabaseID: "database_1", + CompletionOffset: 100, + MinRecordsForTask: 10, + NextNonce: originalNonce, + LowestLiveNonce: &originalNonce, + } + + err := suite.Db.Insert(task) + suite.Require().NoError(err) + + // Update completion offset to 200 + err = suite.Db.UpdateCompletionOffset(taskID, originalNonce, 200) + suite.Require().NoError(err) + + // Verify the update + retrieved, err := suite.Db.GetByID(taskID) + suite.Require().NoError(err) + suite.Require().Equal(int64(200), retrieved.CompletionOffset) + // next_nonce should remain unchanged + suite.Require().Equal(originalNonce, retrieved.NextNonce) + + suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID) +} + +func (suite *TaskDbTestSuite) TestTaskDb_UpdateCompletionOffset_InvalidNonce() { + taskID := uuid.New() + operatorID := dbmodel.OperatorRecordCounter + correctNonce, _ := uuid.NewV7() + wrongNonce, _ := uuid.NewV7() + + task := &dbmodel.Task{ + ID: taskID, + Name: "test_update_wrong_nonce", + OperatorID: operatorID, + InputCollectionID: "input_collection_1", + OutputCollectionID: nil, + OutputCollectionName: "output_collection_1", + TenantID: "tenant_1", + DatabaseID: "database_1", + CompletionOffset: 100, + MinRecordsForTask: 10, + NextNonce: correctNonce, + LowestLiveNonce: &correctNonce, + } + + err := suite.Db.Insert(task) + suite.Require().NoError(err) + + // Try to update with wrong nonce + err = suite.Db.UpdateCompletionOffset(taskID, wrongNonce, 200) + suite.Require().Error(err) + suite.Require().Equal(common.ErrTaskNotFound, err) + + suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID) +} + // TestOperatorConstantsMatchSeededDatabase verifies that operator constants in // dbmodel/constants.go match what we seed in the test database (which should match migrations). // This catches drift between constants and migrations at test time. 
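For reference, the nonce lifecycle that the DAO changes and tests above exercise is: PrepareTask calls AdvanceTask to open a new epoch (bumping next_nonce and scheduling next_run while leaving lowest_live_nonce untouched), the flush path calls UpdateCompletionOffset against the epoch identified by lowest_live_nonce, and the finish_task operator calls FinishTask to set lowest_live_nonce = next_nonce once verification completes. The sketch below walks that sequence against the ITaskDb interface updated later in this patch; the runTaskEpoch wrapper, the db handle, and the offset/delay values are illustrative assumptions, not part of the change.

// Illustrative sketch only: the wiring and constants here are assumptions for exposition.
package tasksketch

import (
	"log"

	"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
	"github.com/google/uuid"
)

// runTaskEpoch shows the order in which the new ITaskDb methods are intended to be
// called for one task epoch, in the steady state where the previous epoch finished
// (so lowest_live_nonce == runNonce when we start).
func runTaskEpoch(db dbmodel.ITaskDb, taskID, runNonce uuid.UUID, newOffset int64) error {
	// PrepareTask: open a new epoch. This bumps next_nonce and records next_run,
	// but deliberately does not touch lowest_live_nonce.
	advanced, err := db.AdvanceTask(taskID, runNonce, newOffset, 3600 /* next_run_delay_secs, illustrative */)
	if err != nil {
		return err
	}
	log.Printf("advanced epoch: next_nonce=%s next_run=%s", advanced.NextNonce, advanced.NextRun)

	// Flush path (FlushCollectionCompactionAndTask): record how far this epoch processed.
	// The lowest_live_nonce = runNonce predicate rejects writes from stale runs.
	if err := db.UpdateCompletionOffset(taskID, runNonce, newOffset); err != nil {
		return err
	}

	// finish_task operator: mark the epoch verified by setting lowest_live_nonce = next_nonce.
	return db.FinishTask(taskID)
}

Keeping lowest_live_nonce behind next_nonce until FinishTask succeeds is what lets a later run detect an unverified epoch (lowest_live_nonce < next_nonce) and rerun only the recheck phase, as noted in the FinishTask comment above.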
diff --git a/go/pkg/sysdb/metastore/db/dbmodel/collection.go b/go/pkg/sysdb/metastore/db/dbmodel/collection.go index 54789d660e7..50e43f66cff 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/collection.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/collection.go @@ -65,6 +65,8 @@ type ICollectionDb interface { UpdateLogPositionVersionTotalRecordsAndLogicalSize(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64, sizeBytesPostCompaction uint64, lastCompactionTimeSecs uint64, tenant string, schemaStr *string) (int32, error) UpdateLogPositionAndVersionInfo(collectionID string, logPosition int64, currentCollectionVersion int32, currentVersionFilePath string, newCollectionVersion int32, newVersionFilePath string, totalRecordsPostCompaction uint64, sizeBytesPostCompaction uint64, lastCompactionTimeSecs uint64, numVersions uint64, schemaStr *string) (int64, error) + UpdateVersionInfo(collectionID string, currentCollectionVersion int32, currentVersionFilePath string, newCollectionVersion int32, newVersionFilePath string, totalRecordsPostCompaction uint64, + sizeBytesPostCompaction uint64, lastCompactionTimeSecs uint64, numVersions uint64) (int64, error) GetCollectionWithoutMetadata(collectionID *string, databaseName *string, softDeletedFlag *bool) (*Collection, error) GetCollectionSize(collectionID string) (uint64, error) ListCollectionsToGc(cutoffTimeSecs *uint64, limit *uint64, tenantID *string, minVersionsIfAlive *uint64) ([]*CollectionToGc, error) diff --git a/go/pkg/sysdb/metastore/db/dbmodel/database.go b/go/pkg/sysdb/metastore/db/dbmodel/database.go index f254ae42380..f800b362ab1 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/database.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/database.go @@ -23,6 +23,7 @@ func (v Database) TableName() string { //go:generate mockery --name=IDatabaseDb type IDatabaseDb interface { GetDatabases(tenantID string, databaseName string) ([]*Database, error) + GetByID(databaseID string) (*Database, error) ListDatabases(limit *int32, offset *int32, tenantID string) ([]*Database, error) Insert(in *Database) error DeleteAll() error diff --git a/go/pkg/sysdb/metastore/db/dbmodel/task.go b/go/pkg/sysdb/metastore/db/dbmodel/task.go index 32935b67ecd..26bb251d6bf 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/task.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/task.go @@ -7,40 +7,52 @@ import ( ) type Task struct { - ID uuid.UUID `gorm:"column:task_id;primaryKey"` - Name string `gorm:"column:task_name;type:text;not null;uniqueIndex:unique_task_per_collection,priority:2"` - TenantID string `gorm:"column:tenant_id;type:text;not null"` - DatabaseID string `gorm:"column:database_id;type:text;not null"` - InputCollectionID string `gorm:"column:input_collection_id;type:text;not null;uniqueIndex:unique_task_per_collection,priority:1"` + ID uuid.UUID `gorm:"column:task_id;primaryKey"` + Name string `gorm:"column:task_name;type:text;not null;uniqueIndex:unique_task_per_collection,priority:2"` + TenantID string `gorm:"column:tenant_id;type:text;not null"` + DatabaseID string `gorm:"column:database_id;type:text;not null"` + InputCollectionID string `gorm:"column:input_collection_id;type:text;not null;uniqueIndex:unique_task_per_collection,priority:1"` OutputCollectionName string `gorm:"column:output_collection_name;type:text;not null"` OutputCollectionID *string `gorm:"column:output_collection_id;type:text;default:null"` - OperatorID uuid.UUID `gorm:"column:operator_id;type:uuid;not null"` - OperatorParams string 
`gorm:"column:operator_params;type:jsonb;not null"` - CompletionOffset int64 `gorm:"column:completion_offset;type:bigint;not null;default:0"` - LastRun *time.Time `gorm:"column:last_run;type:timestamp"` - NextRun *time.Time `gorm:"column:next_run;type:timestamp"` - MinRecordsForTask int64 `gorm:"column:min_records_for_task;type:bigint;not null;default:100"` - CurrentAttempts int32 `gorm:"column:current_attempts;type:integer;not null;default:0"` - IsAlive bool `gorm:"column:is_alive;type:boolean;not null;default:true"` - IsDeleted bool `gorm:"column:is_deleted;type:boolean;not null;default:false"` - CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP"` - UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP"` - TaskTemplateParent *uuid.UUID `gorm:"column:task_template_parent;type:uuid;default:null"` - NextNonce uuid.UUID `gorm:"column:next_nonce;type:uuid;not null"` - OldestWrittenNonce *uuid.UUID `gorm:"column:oldest_written_nonce;type:uuid;default:null"` + OperatorID uuid.UUID `gorm:"column:operator_id;type:uuid;not null"` + OperatorParams string `gorm:"column:operator_params;type:jsonb;not null"` + CompletionOffset int64 `gorm:"column:completion_offset;type:bigint;not null;default:0"` + LastRun *time.Time `gorm:"column:last_run;type:timestamp"` + NextRun time.Time `gorm:"column:next_run;type:timestamp;not null"` + MinRecordsForTask int64 `gorm:"column:min_records_for_task;type:bigint;not null;default:100"` + CurrentAttempts int32 `gorm:"column:current_attempts;type:integer;not null;default:0"` + IsAlive bool `gorm:"column:is_alive;type:boolean;not null;default:true"` + IsDeleted bool `gorm:"column:is_deleted;type:boolean;not null;default:false"` + CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP"` + UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP"` + TaskTemplateParent *uuid.UUID `gorm:"column:task_template_parent;type:uuid;default:null"` + NextNonce uuid.UUID `gorm:"column:next_nonce;type:uuid;not null"` + LowestLiveNonce *uuid.UUID `gorm:"column:lowest_live_nonce;type:uuid;default:null"` + OldestWrittenNonce *uuid.UUID `gorm:"column:oldest_written_nonce;type:uuid;default:null"` } func (v Task) TableName() string { return "tasks" } +// AdvanceTask contains the authoritative task data after AdvanceTask +type AdvanceTask struct { + NextNonce uuid.UUID + NextRun time.Time + CompletionOffset int64 +} + //go:generate mockery --name=ITaskDb type ITaskDb interface { Insert(task *Task) error GetByName(inputCollectionID string, taskName string) (*Task, error) GetByID(taskID uuid.UUID) (*Task, error) - AdvanceTask(taskID uuid.UUID, taskRunNonce uuid.UUID) error + AdvanceTask(taskID uuid.UUID, nextRunNonce uuid.UUID, completionOffset int64, nextRunDelaySecs uint64) (*AdvanceTask, error) + UpdateCompletionOffset(taskID uuid.UUID, taskRunNonce uuid.UUID, completionOffset int64) error + FinishTask(taskID uuid.UUID) error + UpdateOutputCollectionID(taskID uuid.UUID, outputCollectionID *string) error SoftDelete(inputCollectionID string, taskName string) error DeleteAll() error PeekScheduleByCollectionId(collectionIDs []string) ([]*Task, error) + GetMinCompletionOffsetForCollection(inputCollectionID string) (*int64, error) } diff --git a/go/pkg/sysdb/metastore/db/migrations/20251013000000.sql b/go/pkg/sysdb/metastore/db/migrations/20251013000000.sql new file mode 100644 index 00000000000..2dd970ff582 --- /dev/null +++ 
b/go/pkg/sysdb/metastore/db/migrations/20251013000000.sql @@ -0,0 +1,7 @@ +-- Make next_run NOT NULL +ALTER TABLE "public"."tasks" +ALTER COLUMN "next_run" SET NOT NULL; + +-- Add lowest_live_nonce column, initialized to next_nonce +ALTER TABLE "public"."tasks" +ADD COLUMN "lowest_live_nonce" UUID DEFAULT NULL; diff --git a/go/pkg/sysdb/metastore/db/migrations/atlas.sum b/go/pkg/sysdb/metastore/db/migrations/atlas.sum index 10a745ddb7e..1f9a57c486f 100644 --- a/go/pkg/sysdb/metastore/db/migrations/atlas.sum +++ b/go/pkg/sysdb/metastore/db/migrations/atlas.sum @@ -1,4 +1,4 @@ -h1:Jk3VaF1qoRNVAB7cCxgSDFiiH9Y6r1zSIIj1SxhCklc= +h1:R2eYzl9Eu7q0vyoeuC2DxF+nxtdniO+v9mqTHffSn0A= 20240313233558.sql h1:Gv0TiSYsqGoOZ2T2IWvX4BOasauxool8PrBOIjmmIdg= 20240321194713.sql h1:kVkNpqSFhrXGVGFFvL7JdK3Bw31twFcEhI6A0oCFCkg= 20240327075032.sql h1:nlr2J74XRU8erzHnKJgMr/tKqJxw9+R6RiiEBuvuzgo= @@ -20,3 +20,4 @@ h1:Jk3VaF1qoRNVAB7cCxgSDFiiH9Y6r1zSIIj1SxhCklc= 20250806213245.sql h1:OgEOd3bL+rKdQ2x/Hcm3f0/yyrWirJkPm14V5N4sgKE= 20250930122132.sql h1:ch67SU2K5X4gV5E1knOEk/yprnn9FrbZsJCkmUnAbqo= 20251001073000.sql h1:pdl+M9f46vz7rbXZtJjOWTXlbSBpL2a0nVHl5VUOOsg= +20251013000000.sql h1:oy/34GiJZ9/nUrDViOMOFZsGbuAACPnZKPB9PET3BVE= diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index c8ee3653d47..ffd8af6254f 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -321,6 +321,30 @@ message FlushCollectionCompactionResponse { int64 last_compaction_time = 3; } +// Task update information for transactional flush operations +message TaskUpdateInfo { + string task_id = 1; + string task_run_nonce = 2; + int64 completion_offset = 3; + uint64 next_run_delay_secs = 4; +} + +// Combined request to flush collection compaction and update task atomically in a single transaction +message FlushCollectionCompactionAndTaskRequest { + FlushCollectionCompactionRequest flush_compaction = 1; + TaskUpdateInfo task_update = 2; +} + +message FlushCollectionCompactionAndTaskResponse { + string collection_id = 1; + int32 collection_version = 2; + int64 last_compaction_time = 3; + // Updated task fields from database (authoritative) + string next_nonce = 4; + google.protobuf.Timestamp next_run = 5; + int64 completion_offset = 6; +} + // Used for serializing contents in collection version history file. 
message CollectionVersionFile { CollectionInfoImmutable collection_info_immutable = 1; @@ -541,23 +565,49 @@ message CreateTaskResponse { string task_id = 1; } +message CreateOutputCollectionForTaskRequest { + string task_id = 1; + string collection_name = 2; + string tenant_id = 3; + string database_id = 4; +} + +message CreateOutputCollectionForTaskResponse { + string collection_id = 1; +} + message GetTaskByNameRequest { string input_collection_id = 1; string task_name = 2; } -message GetTaskByNameResponse { - optional string task_id = 1; - optional string name = 2; - optional string operator_name = 3; - optional string input_collection_id = 4; - optional string output_collection_name = 5; +message Task { + string task_id = 1; + string name = 2; + string operator_name = 3; + string input_collection_id = 4; + string output_collection_name = 5; optional string output_collection_id = 6; optional google.protobuf.Struct params = 7; - optional int64 completion_offset = 8; - optional uint64 min_records_for_task = 9; - optional string tenant_id = 10; - optional string database_id = 11; + int64 completion_offset = 8; + uint64 min_records_for_task = 9; + string tenant_id = 10; + string database_id = 11; + uint64 next_run_at = 12; + string lowest_live_nonce = 13; + string next_nonce = 14; +} + +message GetTaskByNameResponse { + Task task = 1; +} + +message GetTaskByUuidRequest { + string task_id = 1; +} + +message GetTaskByUuidResponse { + Task task = 1; } message DeleteTaskRequest { @@ -574,9 +624,21 @@ message AdvanceTaskRequest { optional string collection_id = 1; optional string task_id = 2; optional string task_run_nonce = 3; + optional int64 completion_offset = 4; + optional uint64 next_run_delay_secs = 5; +} + +message AdvanceTaskResponse { + string next_run_nonce = 1; + uint64 next_run_at = 2; + int64 completion_offset = 3; +} + +message FinishTaskRequest { + string task_id = 1; } -message AdvanceTaskResponse {} +message FinishTaskResponse {} message Operator { string id = 1; @@ -635,6 +697,7 @@ service SysDB { rpc GetLastCompactionTimeForTenant(GetLastCompactionTimeForTenantRequest) returns (GetLastCompactionTimeForTenantResponse) {} rpc SetLastCompactionTimeForTenant(SetLastCompactionTimeForTenantRequest) returns (google.protobuf.Empty) {} rpc FlushCollectionCompaction(FlushCollectionCompactionRequest) returns (FlushCollectionCompactionResponse) {} + rpc FlushCollectionCompactionAndTask(FlushCollectionCompactionAndTaskRequest) returns (FlushCollectionCompactionAndTaskResponse) {} rpc RestoreCollection(RestoreCollectionRequest) returns (RestoreCollectionResponse) {} rpc ListCollectionVersions(ListCollectionVersionsRequest) returns (ListCollectionVersionsResponse) {} rpc GetCollectionSize(GetCollectionSizeRequest) returns (GetCollectionSizeResponse) {} @@ -644,9 +707,12 @@ service SysDB { rpc BatchGetCollectionVersionFilePaths(BatchGetCollectionVersionFilePathsRequest) returns (BatchGetCollectionVersionFilePathsResponse) {} rpc BatchGetCollectionSoftDeleteStatus(BatchGetCollectionSoftDeleteStatusRequest) returns (BatchGetCollectionSoftDeleteStatusResponse) {} rpc CreateTask(CreateTaskRequest) returns (CreateTaskResponse) {} + rpc CreateOutputCollectionForTask(CreateOutputCollectionForTaskRequest) returns (CreateOutputCollectionForTaskResponse) {} rpc GetTaskByName(GetTaskByNameRequest) returns (GetTaskByNameResponse) {} + rpc GetTaskByUuid(GetTaskByUuidRequest) returns (GetTaskByUuidResponse) {} rpc DeleteTask(DeleteTaskRequest) returns (DeleteTaskResponse) {} rpc 
AdvanceTask(AdvanceTaskRequest) returns (AdvanceTaskResponse) {} + rpc FinishTask(FinishTaskRequest) returns (FinishTaskResponse) {} rpc GetOperators(GetOperatorsRequest) returns (GetOperatorsResponse) {} rpc PeekScheduleByCollectionId(PeekScheduleByCollectionIdRequest) returns (PeekScheduleByCollectionIdResponse) {} } diff --git a/idl/chromadb/proto/heapservice.proto b/idl/chromadb/proto/heapservice.proto index e5237a840b5..fd32e7be413 100644 --- a/idl/chromadb/proto/heapservice.proto +++ b/idl/chromadb/proto/heapservice.proto @@ -7,6 +7,18 @@ import "chromadb/proto/chroma.proto"; message HeapSummaryRequest {} message HeapSummaryResponse {} +message ScheduleTaskRequest { + string collection_id = 1; + string task_id = 2; + uint64 when_to_run = 3; + string task_run_nonce = 4; +} + +message ScheduleTaskResponse { + bool success = 1; +} + service HeapTenderService { rpc Summary(HeapSummaryRequest) returns (HeapSummaryResponse) {} + rpc ScheduleTask(ScheduleTaskRequest) returns (ScheduleTaskResponse) {} } diff --git a/rust/log-service/src/lib.rs b/rust/log-service/src/lib.rs index d4036c70939..bd1b5793edb 100644 --- a/rust/log-service/src/lib.rs +++ b/rust/log-service/src/lib.rs @@ -1388,7 +1388,10 @@ impl LogServer { if records.len() != pull_logs.batch_size as usize || (!records.is_empty() && records[0].log_offset != pull_logs.start_from_offset) { - return Err(Status::not_found("Some entries have been purged")); + return Err(Status::not_found(format!( + "Some entries have been purged {} versus {}", + records[0].log_offset, pull_logs.start_from_offset + ))); } Ok(Response::new(PullLogsResponse { records })) } diff --git a/rust/log/src/in_memory_log.rs b/rust/log/src/in_memory_log.rs index d640029048b..c0e6c93fbe1 100644 --- a/rust/log/src/in_memory_log.rs +++ b/rust/log/src/in_memory_log.rs @@ -53,7 +53,7 @@ impl InMemoryLog { next_offset, log.log_offset ); } - logs.push(log); + logs.push(log.clone()); } } diff --git a/rust/s3heap-service/Cargo.toml b/rust/s3heap-service/Cargo.toml index d67f59bcf29..408e3b6783e 100644 --- a/rust/s3heap-service/Cargo.toml +++ b/rust/s3heap-service/Cargo.toml @@ -13,6 +13,7 @@ serde_json = { workspace = true } tokio = { workspace = true } tonic = { workspace = true } tonic-health = { workspace = true } +tower = { workspace = true } tracing = { workspace = true } chroma-config = { workspace = true } diff --git a/rust/s3heap-service/src/lib.rs b/rust/s3heap-service/src/lib.rs index 5ebaba74187..ea211cac8f8 100644 --- a/rust/s3heap-service/src/lib.rs +++ b/rust/s3heap-service/src/lib.rs @@ -2,10 +2,14 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; +use chrono::{DateTime, Utc}; +use uuid::Uuid; + use figment::providers::{Env, Format, Yaml}; use futures::stream::StreamExt; use tokio::signal::unix::{signal, SignalKind}; use tonic::{transport::Server, Request, Response, Status}; +use tower::ServiceBuilder; use chroma_config::helpers::{deserialize_duration_from_seconds, serialize_duration_to_seconds}; use chroma_config::Configurable; @@ -15,10 +19,13 @@ use chroma_storage::Storage; use chroma_sysdb::{SysDb, SysDbConfig}; use chroma_tracing::OtelFilter; use chroma_tracing::OtelFilterLevel; +use chroma_types::chroma_proto::heap_tender_service_client::HeapTenderServiceClient; use chroma_types::chroma_proto::heap_tender_service_server::{ HeapTenderService, HeapTenderServiceServer, }; -use chroma_types::chroma_proto::{HeapSummaryRequest, HeapSummaryResponse}; +use chroma_types::chroma_proto::{ + HeapSummaryRequest, HeapSummaryResponse, 
ScheduleTaskRequest, ScheduleTaskResponse, +}; use chroma_types::{dirty_log_path_from_hostname, CollectionUuid, DirtyMarker, ScheduleEntry}; use s3heap::{heap_path_from_hostname, Configuration, HeapWriter, Schedule, Triggerable}; use wal3::{ @@ -138,7 +145,7 @@ impl HeapTender { let schedule = Schedule { triggerable, next_scheduled, - nonce: s.task_run_nonce, + nonce: s.task_run_nonce.0, }; Ok(Some(schedule)) } else { @@ -374,6 +381,178 @@ impl HeapTenderService for HeapTenderServer { ) -> Result, Status> { todo!(); } + + async fn schedule_task( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + + // Parse collection_id + let collection_id = match Uuid::parse_str(&req.collection_id) { + Ok(uuid) => CollectionUuid(uuid), + Err(e) => { + return Err(Status::invalid_argument(format!( + "Invalid collection_id: {}", + e + ))); + } + }; + + // Parse task_id + let task_id = match Uuid::parse_str(&req.task_id) { + Ok(uuid) => uuid, + Err(e) => { + return Err(Status::invalid_argument(format!("Invalid task_id: {}", e))); + } + }; + + // Parse task_run_nonce + let task_run_nonce = match Uuid::parse_str(&req.task_run_nonce) { + Ok(uuid) => uuid, + Err(e) => { + return Err(Status::invalid_argument(format!( + "Invalid task_run_nonce: {}", + e + ))); + } + }; + + // Convert timestamp from microseconds to DateTime + let when_to_run = DateTime::::from_timestamp_micros(req.when_to_run as i64) + .ok_or_else(|| Status::invalid_argument("Invalid when_to_run timestamp"))?; + + // Create the triggerable and schedule + let triggerable = Triggerable { + partitioning: s3heap::UnitOfPartitioningUuid::new(collection_id.0), + scheduling: s3heap::UnitOfSchedulingUuid::new(task_id), + }; + + let schedule = Schedule { + triggerable, + next_scheduled: when_to_run, + nonce: task_run_nonce, + }; + + // Push to the heap writer + match self.tender.writer.push(&[schedule]).await { + Ok(_) => { + tracing::info!( + "Manually scheduled task {} for collection {} at timestamp {}", + req.task_id, + req.collection_id, + req.when_to_run + ); + Ok(Response::new(ScheduleTaskResponse { success: true })) + } + Err(e) => { + tracing::error!("Failed to schedule task: {:?}", e); + Err(Status::internal(format!( + "Failed to schedule task: {:?}", + e + ))) + } + } + } +} + +///////////////////////////////////////// HeapTenderClient ///////////////////////////////////////// + +/// Configuration for connecting to a HeapTender gRPC service. +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +pub struct HeapTenderClientConfig { + /// The hostname or IP address of the HeapTender service. + #[serde(default = "HeapTenderClientConfig::default_host")] + pub host: String, + /// The port of the HeapTender service. + #[serde(default = "HeapTenderClientConfig::default_port")] + pub port: u16, +} + +impl HeapTenderClientConfig { + fn default_host() -> String { + "heap-tender-service".to_string() + } + + fn default_port() -> u16 { + 50052 + } +} + +impl Default for HeapTenderClientConfig { + fn default() -> Self { + Self { + host: Self::default_host(), + port: Self::default_port(), + } + } +} + +/// Client for connecting to and calling the HeapTender gRPC service. +#[derive(Clone, Debug)] +pub struct HeapTenderClient { + client: + HeapTenderServiceClient>, +} + +impl HeapTenderClient { + /// Schedule a task with the heap tender service. 
+ /// + /// # Arguments + /// * `collection_id` - The UUID of the collection as a string + /// * `task_id` - The UUID of the task as a string + /// * `when_to_run` - Timestamp in microseconds when the task should run + /// * `task_run_nonce` - The UUID nonce for this task run as a string + pub async fn schedule_task( + &mut self, + collection_id: &str, + task_id: &str, + when_to_run: u64, + task_run_nonce: &str, + ) -> Result<(), Error> { + let request = tonic::Request::new(ScheduleTaskRequest { + collection_id: collection_id.to_string(), + task_id: task_id.to_string(), + when_to_run, + task_run_nonce: task_run_nonce.to_string(), + }); + + self.client + .schedule_task(request) + .await + .map_err(|e| Error::Internal(format!("Failed to schedule task: {}", e)))?; + + Ok(()) + } +} + +#[async_trait::async_trait] +impl Configurable for HeapTenderClient { + async fn try_from_config( + config: &HeapTenderClientConfig, + _registry: &chroma_config::registry::Registry, + ) -> Result> { + let uri = format!("http://{}:{}", config.host, config.port); + let channel = tonic::transport::Channel::from_shared(uri) + .map_err(|e| -> Box { + Box::new(s3heap::Error::Internal(format!("Invalid URI: {}", e))) + })? + .connect() + .await + .map_err(|e| -> Box { + Box::new(s3heap::Error::Internal(format!( + "Failed to connect to heap tender service: {}", + e + ))) + })?; + + let channel = ServiceBuilder::new() + .layer(chroma_tracing::GrpcClientTraceLayer) + .service(channel); + + let client = HeapTenderServiceClient::new(channel); + Ok(HeapTenderClient { client }) + } } //////////////////////////////////////////// RootConfig //////////////////////////////////////////// diff --git a/rust/s3heap-service/src/scheduler.rs b/rust/s3heap-service/src/scheduler.rs index 044282fd3de..a390893fecb 100644 --- a/rust/s3heap-service/src/scheduler.rs +++ b/rust/s3heap-service/src/scheduler.rs @@ -45,7 +45,7 @@ impl HeapScheduler for SysDbScheduler { results.push(true); continue; }; - results.push(schedule.task_run_nonce != *nonce); + results.push(schedule.task_run_nonce.0 != *nonce); } Ok(results) } @@ -67,7 +67,7 @@ impl HeapScheduler for SysDbScheduler { partitioning: schedule.collection_id.0.into(), scheduling: schedule.task_id.into(), }, - nonce: schedule.task_run_nonce, + nonce: schedule.task_run_nonce.0, next_scheduled: when_to_run, }); } diff --git a/rust/sysdb/src/bin/chroma-task-manager.rs b/rust/sysdb/src/bin/chroma-task-manager.rs index 0e813aa980e..82b63752785 100644 --- a/rust/sysdb/src/bin/chroma-task-manager.rs +++ b/rust/sysdb/src/bin/chroma-task-manager.rs @@ -67,6 +67,10 @@ enum Command { task_id: String, #[arg(long, help = "Nonce identifying the specific task run")] task_run_nonce: String, + #[arg(long, help = "Completion offset")] + completion_offset: u64, + #[arg(long, help = "Next run delay in seconds")] + next_run_delay_secs: u64, }, #[command(about = "Get all operators")] GetOperators, @@ -158,7 +162,7 @@ async fn main() -> Result<(), Box> { }; let response = client.get_task_by_name(request).await?; - let task = response.into_inner(); + let task = response.into_inner().task.unwrap(); println!("Task ID: {:?}", task.task_id); println!("Name: {:?}", task.name); @@ -188,11 +192,15 @@ async fn main() -> Result<(), Box> { collection_id, task_id, task_run_nonce, + completion_offset, + next_run_delay_secs, } => { let request = chroma_proto::AdvanceTaskRequest { collection_id: Some(collection_id), task_id: Some(task_id), task_run_nonce: Some(task_run_nonce), + completion_offset: Some(completion_offset as i64), + 
next_run_delay_secs: Some(next_run_delay_secs), }; client.advance_task(request).await?; diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index 527d8db3617..96e90138ec8 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -6,6 +6,8 @@ use chroma_config::registry::Registry; use chroma_config::Configurable; use chroma_error::{ChromaError, ErrorCodes, TonicError, TonicMissingFieldError}; use chroma_types::chroma_proto::sys_db_client::SysDbClient; +use chroma_types::chroma_proto::AdvanceTaskRequest; +use chroma_types::chroma_proto::FinishTaskRequest; use chroma_types::chroma_proto::VersionListForCollection; use chroma_types::{ chroma_proto, chroma_proto::CollectionVersionInfo, CollectionAndSegments, @@ -21,11 +23,12 @@ use chroma_types::{ UpdateTenantResponse, VectorIndexConfiguration, }; use chroma_types::{ - BatchGetCollectionSoftDeleteStatusError, BatchGetCollectionVersionFilePathsError, Collection, - CollectionConversionError, CollectionUuid, CountForksError, DatabaseUuid, - FinishDatabaseDeletionError, FlushCompactionResponse, FlushCompactionResponseConversionError, - ForkCollectionError, InternalSchema, SchemaError, Segment, SegmentConversionError, - SegmentScope, Tenant, + AdvanceTaskError, AdvanceTaskResponse, BatchGetCollectionSoftDeleteStatusError, + BatchGetCollectionVersionFilePathsError, Collection, CollectionConversionError, CollectionUuid, + CountForksError, DatabaseUuid, FinishDatabaseDeletionError, FinishTaskError, + FlushCompactionAndTaskResponse, FlushCompactionResponse, + FlushCompactionResponseConversionError, ForkCollectionError, InternalSchema, SchemaError, + Segment, SegmentConversionError, SegmentScope, TaskUpdateInfo, TaskUuid, Tenant, }; use prost_types; use std::collections::HashMap; @@ -626,6 +629,39 @@ impl SysDb { } } + #[allow(clippy::too_many_arguments)] + pub async fn flush_compaction_and_task( + &mut self, + tenant_id: String, + collection_id: CollectionUuid, + log_position: i64, + collection_version: i32, + segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, + size_bytes_post_compaction: u64, + schema: Option, + task_update: TaskUpdateInfo, + ) -> Result { + match self { + SysDb::Grpc(grpc) => { + grpc.flush_compaction_and_task( + tenant_id, + collection_id, + log_position, + collection_version, + segment_flush_info, + total_records_post_compaction, + size_bytes_post_compaction, + schema, + task_update, + ) + .await + } + SysDb::Sqlite(_) => todo!(), + SysDb::Test(_) => todo!(), + } + } + pub async fn list_collection_versions( &mut self, collection_id: CollectionUuid, @@ -695,6 +731,36 @@ impl SysDb { SysDb::Test(test) => test.peek_schedule_by_collection_id(collection_ids).await, } } + + pub async fn finish_task(&mut self, task_id: TaskUuid) -> Result<(), FinishTaskError> { + match self { + SysDb::Grpc(grpc) => grpc.finish_task(task_id).await, + SysDb::Sqlite(_) => unimplemented!(), + SysDb::Test(test) => test.finish_task(task_id).await, + } + } + + pub async fn advance_task( + &mut self, + task_id: TaskUuid, + task_run_nonce: uuid::Uuid, + completion_offset: i64, + next_run_delay_secs: u64, + ) -> Result { + match self { + SysDb::Grpc(grpc) => { + grpc.advance_task( + task_id, + task_run_nonce, + completion_offset, + next_run_delay_secs, + ) + .await + } + SysDb::Sqlite(_) => unimplemented!(), + SysDb::Test(_) => unimplemented!(), + } + } } #[derive(Clone, Debug)] @@ -1609,6 +1675,88 @@ impl GrpcSysDb { } } + #[allow(clippy::too_many_arguments)] + async fn flush_compaction_and_task( + &mut self, + 
tenant_id: String, + collection_id: CollectionUuid, + log_position: i64, + collection_version: i32, + segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, + size_bytes_post_compaction: u64, + schema: Option, + task_update: TaskUpdateInfo, + ) -> Result { + let segment_compaction_info = + segment_flush_info + .iter() + .map(|segment_flush_info| segment_flush_info.try_into()) + .collect::, + SegmentFlushInfoConversionError, + >>(); + + let segment_compaction_info = match segment_compaction_info { + Ok(segment_compaction_info) => segment_compaction_info, + Err(e) => { + return Err(FlushCompactionError::SegmentFlushInfoConversionError(e)); + } + }; + + let schema_str = schema.and_then(|s| { + serde_json::to_string(&s).ok().or_else(|| { + tracing::error!("Failed to serialize schema for flush_compaction_and_task"); + None + }) + }); + + let flush_compaction = Some(chroma_proto::FlushCollectionCompactionRequest { + tenant_id, + collection_id: collection_id.0.to_string(), + log_position, + collection_version, + segment_compaction_info, + total_records_post_compaction, + size_bytes_post_compaction, + schema_str, + }); + + let task_update_proto = Some(chroma_proto::TaskUpdateInfo { + task_id: task_update.task_id.0.to_string(), + task_run_nonce: task_update.task_run_nonce.to_string(), + completion_offset: task_update.completion_offset, + next_run_delay_secs: task_update.next_run_delay_secs, + }); + + let req = chroma_proto::FlushCollectionCompactionAndTaskRequest { + flush_compaction, + task_update: task_update_proto, + }; + + let res = self.client.flush_collection_compaction_and_task(req).await; + match res { + Ok(res) => { + let res = res.into_inner(); + let res = match res.try_into() { + Ok(res) => res, + Err(e) => { + return Err( + FlushCompactionError::FlushCompactionResponseConversionError(e), + ); + } + }; + Ok(res) + } + Err(e) => { + if e.code() == Code::FailedPrecondition { + return Err(FlushCompactionError::FailedToFlushCompaction(e)); + } + Err(FlushCompactionError::FailedToFlushCompaction(e)) + } + } + } + async fn mark_version_for_deletion( &mut self, epoch_id: i64, @@ -1655,6 +1803,68 @@ impl GrpcSysDb { Ok(ResetResponse {}) } + async fn finish_task(&mut self, task_id: TaskUuid) -> Result<(), FinishTaskError> { + let req = FinishTaskRequest { + task_id: task_id.0.to_string(), + }; + self.client.finish_task(req).await.map_err(|e| { + if e.code() == Code::NotFound { + FinishTaskError::TaskNotFound + } else { + FinishTaskError::FailedToFinishTask(e) + } + })?; + Ok(()) + } + + async fn advance_task( + &mut self, + task_id: TaskUuid, + task_run_nonce: uuid::Uuid, + completion_offset: i64, + next_run_delay_secs: u64, + ) -> Result { + let req = AdvanceTaskRequest { + collection_id: None, // Not used by coordinator + task_id: Some(task_id.0.to_string()), + task_run_nonce: Some(task_run_nonce.to_string()), + completion_offset: Some(completion_offset), + next_run_delay_secs: Some(next_run_delay_secs), + }; + + let response = self.client.advance_task(req).await.map_err(|e| { + if e.code() == Code::NotFound { + AdvanceTaskError::TaskNotFound + } else { + AdvanceTaskError::FailedToAdvanceTask(e) + } + })?; + + let response = response.into_inner(); + + // Parse next_nonce + let next_nonce = uuid::Uuid::parse_str(&response.next_run_nonce).map_err(|e| { + tracing::error!( + next_nonce = %response.next_run_nonce, + error = %e, + "Server returned invalid next_nonce UUID" + ); + AdvanceTaskError::FailedToAdvanceTask(tonic::Status::internal( + "Invalid next_nonce in 
response", + )) + })?; + + // Parse next_run timestamp + let next_run = + std::time::UNIX_EPOCH + std::time::Duration::from_millis(response.next_run_at); + + Ok(AdvanceTaskResponse { + next_nonce, + next_run, + completion_offset: response.completion_offset as u64, + }) + } + #[allow(clippy::too_many_arguments)] pub async fn create_task( &mut self, @@ -1677,7 +1887,6 @@ impl GrpcSysDb { }), _ => None, // Non-object params omitted from proto }; - let req = chroma_proto::CreateTaskRequest { name: name.clone(), operator_name: operator_name.clone(), @@ -1688,9 +1897,7 @@ impl GrpcSysDb { database: database_name.clone(), min_records_for_task, }; - let response = self.client.create_task(req).await?.into_inner(); - // Parse the returned task_id - this should always succeed since the server generated it // If this fails, it indicates a serious server bug or protocol corruption let task_id = chroma_types::TaskUuid( @@ -1703,7 +1910,6 @@ impl GrpcSysDb { CreateTaskError::ServerReturnedInvalidData })?, ); - Ok(task_id) } @@ -1728,35 +1934,142 @@ impl GrpcSysDb { }; let response = response.into_inner(); - // If response has no task_id, task was not found - if response.task_id.is_none() { - return Err(GetTaskError::NotFound); - } + // Extract the nested task from response + let task = response.task.ok_or_else(|| { + GetTaskError::FailedToGetTask(tonic::Status::internal("Missing task in response")) + })?; - // Parse the response and construct Task - let task_id_str = response.task_id.unwrap(); - let task_id = chroma_types::TaskUuid(uuid::Uuid::parse_str(&task_id_str).map_err(|e| { - tracing::error!( - task_id = %task_id_str, - error = %e, - "Server returned invalid task_id UUID" - ); - GetTaskError::ServerReturnedInvalidData - })?); + // Parse task_id + let task_id = + chroma_types::TaskUuid(uuid::Uuid::parse_str(&task.task_id).map_err(|e| { + tracing::error!( + task_id = %task.task_id, + error = %e, + "Server returned invalid task_id UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?); - let operator_id = response.operator_name.ok_or_else(|| { - GetTaskError::FailedToGetTask(tonic::Status::internal( - "Missing operator_name in response", - )) + // Parse input_collection_id + let parsed_input_collection_id = chroma_types::CollectionUuid( + uuid::Uuid::parse_str(&task.input_collection_id).map_err(|e| { + tracing::error!( + input_collection_id = %task.input_collection_id, + error = %e, + "Server returned invalid input_collection_id UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?, + ); + + // Parse next_run timestamp from microseconds + let next_run = + std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_micros(task.next_run_at); + + // Parse nonces + let lowest_live_nonce = if task.lowest_live_nonce.is_empty() { + None + } else { + Some( + uuid::Uuid::parse_str(&task.lowest_live_nonce) + .map(chroma_types::NonceUuid) + .map_err(|e| { + tracing::error!( + lowest_live_nonce = %task.lowest_live_nonce, + error = %e, + "Server returned invalid lowest_live_nonce UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?, + ) + }; + + let next_nonce = uuid::Uuid::parse_str(&task.next_nonce) + .map(chroma_types::NonceUuid) + .map_err(|e| { + tracing::error!( + next_nonce = %task.next_nonce, + error = %e, + "Server returned invalid next_nonce UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?; + + // Convert params from Struct to JSON string + let params_str = task.params.map(|s| { + let json_value = prost_struct_to_json(s); + 
serde_json::to_string(&json_value).unwrap_or_else(|_| "{}".to_string()) + }); + + let parsed_output_collection_id = task.output_collection_id.map(|id| { + uuid::Uuid::parse_str(&id) + .map(chroma_types::CollectionUuid) + .ok() + .unwrap() + }); + + Ok(chroma_types::Task { + id: task_id, + name: task.name, + operator_id: task.operator_name, + input_collection_id: parsed_input_collection_id, + output_collection_name: task.output_collection_name, + output_collection_id: parsed_output_collection_id, + params: params_str, + tenant_id: task.tenant_id, + database_id: task.database_id, + last_run: None, + next_run, + lowest_live_nonce, + next_nonce, + completion_offset: task.completion_offset as u64, + min_records_for_task: task.min_records_for_task, + is_deleted: false, + created_at: std::time::SystemTime::now(), + updated_at: std::time::SystemTime::now(), + }) + } + + pub async fn get_task_by_uuid( + &mut self, + task_uuid: chroma_types::TaskUuid, + ) -> Result { + let req = chroma_proto::GetTaskByUuidRequest { + task_id: task_uuid.0.to_string(), + }; + + let response = match self.client.get_task_by_uuid(req).await { + Ok(resp) => resp, + Err(status) => { + if status.code() == tonic::Code::NotFound { + return Err(GetTaskError::NotFound); + } + return Err(GetTaskError::FailedToGetTask(status)); + } + }; + let response = response.into_inner(); + + // Extract the nested task from response + let task = response.task.ok_or_else(|| { + GetTaskError::FailedToGetTask(tonic::Status::internal("Missing task in response")) })?; - let input_collection_id_str = response - .input_collection_id - .unwrap_or_else(|| input_collection_id.to_string()); + // Parse task_id + let task_id = + chroma_types::TaskUuid(uuid::Uuid::parse_str(&task.task_id).map_err(|e| { + tracing::error!( + task_id = %task.task_id, + error = %e, + "Server returned invalid task_id UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?); + + // Parse input_collection_id let parsed_input_collection_id = chroma_types::CollectionUuid( - uuid::Uuid::parse_str(&input_collection_id_str).map_err(|e| { + uuid::Uuid::parse_str(&task.input_collection_id).map_err(|e| { tracing::error!( - input_collection_id = %input_collection_id_str, + input_collection_id = %task.input_collection_id, error = %e, "Server returned invalid input_collection_id UUID" ); @@ -1764,32 +2077,121 @@ impl GrpcSysDb { })?, ); + // Parse next_run timestamp from microseconds + let next_run = + std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_micros(task.next_run_at); + + // Parse nonces + let lowest_live_nonce = if task.lowest_live_nonce.is_empty() { + None + } else { + Some( + uuid::Uuid::parse_str(&task.lowest_live_nonce) + .map(chroma_types::NonceUuid) + .map_err(|e| { + tracing::error!( + lowest_live_nonce = %task.lowest_live_nonce, + error = %e, + "Server returned invalid lowest_live_nonce UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?, + ) + }; + + let next_nonce = uuid::Uuid::parse_str(&task.next_nonce) + .map(chroma_types::NonceUuid) + .map_err(|e| { + tracing::error!( + next_nonce = %task.next_nonce, + error = %e, + "Server returned invalid next_nonce UUID" + ); + GetTaskError::ServerReturnedInvalidData + })?; + // Convert params from Struct to JSON string - let params_str = response.params.map(|s| { + let params_str = task.params.map(|s| { let json_value = prost_struct_to_json(s); serde_json::to_string(&json_value).unwrap_or_else(|_| "{}".to_string()) }); + // Parse output_collection_id if present + let parsed_output_collection_id = 
task.output_collection_id.as_ref().and_then(|id_str| { + if id_str.is_empty() { + None + } else { + uuid::Uuid::parse_str(id_str) + .map(chroma_types::CollectionUuid) + .ok() + } + }); + Ok(chroma_types::Task { id: task_id, - name: response.name.unwrap_or(task_name), - operator_id, + name: task.name, + operator_id: task.operator_name, input_collection_id: parsed_input_collection_id, - output_collection_name: response.output_collection_name.unwrap_or_default(), - output_collection_id: Some(response.output_collection_id.unwrap_or_default()), + output_collection_name: task.output_collection_name, + output_collection_id: parsed_output_collection_id, params: params_str, - tenant_id: response.tenant_id.unwrap_or_default(), - database_id: response.database_id.unwrap_or_default(), + tenant_id: task.tenant_id, + database_id: task.database_id, last_run: None, - next_run: None, - completion_offset: response.completion_offset.unwrap_or(0) as u64, - min_records_for_task: response.min_records_for_task.unwrap_or(100), + next_run, + lowest_live_nonce, + next_nonce, + completion_offset: task.completion_offset as u64, + min_records_for_task: task.min_records_for_task, is_deleted: false, created_at: std::time::SystemTime::now(), updated_at: std::time::SystemTime::now(), }) } + pub async fn create_output_collection_for_task( + &mut self, + task_id: chroma_types::TaskUuid, + collection_name: String, + tenant_id: String, + database_id: String, + ) -> Result { + let req = chroma_proto::CreateOutputCollectionForTaskRequest { + task_id: task_id.0.to_string(), + collection_name, + tenant_id, + database_id, + }; + + let response = self + .client + .create_output_collection_for_task(req) + .await + .map_err(|e| { + if e.code() == tonic::Code::NotFound { + return CreateOutputCollectionForTaskError::TaskNotFound; + } + if e.code() == tonic::Code::AlreadyExists { + return CreateOutputCollectionForTaskError::OutputCollectionAlreadyExists; + } + CreateOutputCollectionForTaskError::FailedToCreateOutputCollectionForTask(e) + })?; + + let response = response.into_inner(); + + // Parse the returned collection_id + let collection_id = uuid::Uuid::parse_str(&response.collection_id).map_err(|e| { + tracing::error!( + collection_id = %response.collection_id, + error = %e, + "Server returned invalid collection_id UUID" + ); + CreateOutputCollectionForTaskError::ServerReturnedInvalidData + })?; + + Ok(CollectionUuid(collection_id)) + } + pub async fn soft_delete_task( &mut self, _task_id: chroma_types::TaskUuid, @@ -2018,6 +2420,45 @@ impl SysDb { } } + pub async fn get_task_by_uuid( + &mut self, + task_uuid: chroma_types::TaskUuid, + ) -> Result { + match self { + SysDb::Grpc(grpc) => grpc.get_task_by_uuid(task_uuid).await, + SysDb::Sqlite(_) => { + // TODO: Implement for Sqlite + Err(GetTaskError::NotFound) + } + SysDb::Test(_) => { + // TODO: Implement for TestSysDb + Err(GetTaskError::NotFound) + } + } + } + + pub async fn create_output_collection_for_task( + &mut self, + task_id: chroma_types::TaskUuid, + collection_name: String, + tenant_id: String, + database_id: String, + ) -> Result { + match self { + SysDb::Grpc(grpc) => { + grpc.create_output_collection_for_task( + task_id, + collection_name, + tenant_id, + database_id, + ) + .await + } + SysDb::Sqlite(_) => todo!(), + SysDb::Test(_) => todo!(), + } + } + pub async fn soft_delete_task( &mut self, task_id: chroma_types::TaskUuid, @@ -2094,6 +2535,33 @@ impl ChromaError for GetTaskError { } } +#[derive(Error, Debug)] +pub enum CreateOutputCollectionForTaskError { + 
#[error("Task not found")] + TaskNotFound, + #[error("Output collection already exists")] + OutputCollectionAlreadyExists, + #[error("Failed to create output collection for task: {0}")] + FailedToCreateOutputCollectionForTask(#[from] tonic::Status), + #[error("Server returned invalid data")] + ServerReturnedInvalidData, +} + +impl ChromaError for CreateOutputCollectionForTaskError { + fn code(&self) -> ErrorCodes { + match self { + CreateOutputCollectionForTaskError::TaskNotFound => ErrorCodes::NotFound, + CreateOutputCollectionForTaskError::OutputCollectionAlreadyExists => { + ErrorCodes::AlreadyExists + } + CreateOutputCollectionForTaskError::FailedToCreateOutputCollectionForTask(e) => { + e.code().into() + } + CreateOutputCollectionForTaskError::ServerReturnedInvalidData => ErrorCodes::Internal, + } + } +} + #[derive(Error, Debug)] pub enum DeleteTaskError { #[error("Task not found")] diff --git a/rust/sysdb/src/test_sysdb.rs b/rust/sysdb/src/test_sysdb.rs index abf3dae6d2e..88b5c9d8a9a 100644 --- a/rust/sysdb/src/test_sysdb.rs +++ b/rust/sysdb/src/test_sysdb.rs @@ -42,6 +42,7 @@ struct Inner { tenant_resource_names: HashMap, collection_to_version_file: HashMap, soft_deleted_collections: HashSet, + tasks: HashMap, #[derivative(Debug = "ignore")] storage: Option, mock_time: u64, @@ -58,6 +59,7 @@ impl TestSysDb { tenant_resource_names: HashMap::new(), collection_to_version_file: HashMap::new(), soft_deleted_collections: HashSet::new(), + tasks: HashMap::new(), storage: None, mock_time: 0, })), @@ -669,4 +671,13 @@ impl TestSysDb { ) -> Result, crate::sysdb::PeekScheduleError> { Ok(vec![]) } + + pub(crate) async fn finish_task( + &mut self, + _task_id: chroma_types::TaskUuid, + ) -> Result<(), chroma_types::FinishTaskError> { + // For testing, always succeed + // In a real implementation, this would update lowest_live_nonce = next_nonce + Ok(()) + } } diff --git a/rust/types/src/execution/operator.rs b/rust/types/src/execution/operator.rs index f0454adc733..deab08d1c3d 100644 --- a/rust/types/src/execution/operator.rs +++ b/rust/types/src/execution/operator.rs @@ -136,7 +136,7 @@ impl Serialize for Filter { { // For the search API, serialize directly as the where clause (or empty object if None) // If query_ids are present, they should be combined with the where_clause as Key::ID.is_in([...]) - + match (&self.query_ids, &self.where_clause) { (None, None) => { // No filter at all - serialize empty object diff --git a/rust/types/src/flush.rs b/rust/types/src/flush.rs index a34aaf67319..5193f3e3f79 100644 --- a/rust/types/src/flush.rs +++ b/rust/types/src/flush.rs @@ -1,6 +1,9 @@ -use super::{CollectionUuid, ConversionError}; +use super::{CollectionUuid, ConversionError, TaskUuid}; use crate::{ - chroma_proto::{FilePaths, FlushCollectionCompactionResponse, FlushSegmentCompactionInfo}, + chroma_proto::{ + FilePaths, FlushCollectionCompactionAndTaskResponse, FlushCollectionCompactionResponse, + FlushSegmentCompactionInfo, + }, SegmentUuid, }; use chroma_error::{ChromaError, ErrorCodes}; @@ -14,6 +17,67 @@ pub struct SegmentFlushInfo { pub file_paths: HashMap>, } +#[derive(Debug, Clone)] +pub struct TaskUpdateInfo { + pub task_id: TaskUuid, + pub task_run_nonce: uuid::Uuid, + pub completion_offset: i64, + pub next_run_delay_secs: u64, +} + +#[derive(Error, Debug)] +pub enum FinishTaskError { + #[error("Failed to finish task: {0}")] + FailedToFinishTask(#[from] tonic::Status), + #[error("Task not found")] + TaskNotFound, +} + +impl ChromaError for FinishTaskError { + fn code(&self) -> ErrorCodes 
{ + match self { + FinishTaskError::FailedToFinishTask(_) => ErrorCodes::Internal, + FinishTaskError::TaskNotFound => ErrorCodes::NotFound, + } + } +} + +#[derive(Error, Debug)] +pub enum GetMinCompletionOffsetError { + #[error("Failed to get min completion offset: {0}")] + FailedToGetMinCompletionOffset(#[from] tonic::Status), +} + +impl ChromaError for GetMinCompletionOffsetError { + fn code(&self) -> ErrorCodes { + ErrorCodes::Internal + } +} + +#[derive(Error, Debug)] +pub enum AdvanceTaskError { + #[error("Failed to advance task: {0}")] + FailedToAdvanceTask(#[from] tonic::Status), + #[error("Task not found - nonce mismatch or task doesn't exist")] + TaskNotFound, +} + +impl ChromaError for AdvanceTaskError { + fn code(&self) -> ErrorCodes { + match self { + AdvanceTaskError::FailedToAdvanceTask(_) => ErrorCodes::Internal, + AdvanceTaskError::TaskNotFound => ErrorCodes::NotFound, + } + } +} + +#[derive(Debug, Clone)] +pub struct AdvanceTaskResponse { + pub next_nonce: uuid::Uuid, + pub next_run: std::time::SystemTime, + pub completion_offset: u64, +} + impl TryInto for &SegmentFlushInfo { type Error = SegmentFlushInfoConversionError; @@ -45,6 +109,17 @@ pub struct FlushCompactionResponse { pub last_compaction_time: i64, } +#[derive(Debug)] +pub struct FlushCompactionAndTaskResponse { + pub collection_id: CollectionUuid, + pub collection_version: i32, + pub last_compaction_time: i64, + // Completion offset updated during register + pub completion_offset: u64, + // NOTE: next_nonce and next_run are no longer returned + // They were already set by PrepareTask via advance_task() +} + impl FlushCompactionResponse { pub fn new( collection_id: CollectionUuid, @@ -73,18 +148,47 @@ impl TryFrom for FlushCompactionResponse { } } +impl TryFrom for FlushCompactionAndTaskResponse { + type Error = FlushCompactionResponseConversionError; + + fn try_from(value: FlushCollectionCompactionAndTaskResponse) -> Result { + let id = Uuid::parse_str(&value.collection_id) + .map_err(|_| FlushCompactionResponseConversionError::InvalidUuid)?; + + // Note: next_nonce and next_run are no longer populated by the server + // They were already set by PrepareTask via advance_task() + // We only use completion_offset from the response + + Ok(FlushCompactionAndTaskResponse { + collection_id: CollectionUuid(id), + collection_version: value.collection_version, + last_compaction_time: value.last_compaction_time, + completion_offset: value.completion_offset as u64, + }) + } +} + #[derive(Error, Debug)] pub enum FlushCompactionResponseConversionError { #[error(transparent)] DecodeError(#[from] ConversionError), #[error("Invalid collection id, valid UUID required")] InvalidUuid, + #[error("Invalid task nonce, valid UUID required")] + InvalidTaskNonce, + #[error("Missing next_run timestamp")] + MissingNextRun, + #[error("Invalid timestamp format")] + InvalidTimestamp, } impl ChromaError for FlushCompactionResponseConversionError { fn code(&self) -> ErrorCodes { match self { FlushCompactionResponseConversionError::InvalidUuid => ErrorCodes::InvalidArgument, + FlushCompactionResponseConversionError::InvalidTaskNonce => ErrorCodes::InvalidArgument, + FlushCompactionResponseConversionError::MissingNextRun => ErrorCodes::InvalidArgument, + FlushCompactionResponseConversionError::InvalidTimestamp => ErrorCodes::InvalidArgument, FlushCompactionResponseConversionError::DecodeError(e) => e.code(), } } diff --git a/rust/types/src/task.rs b/rust/types/src/task.rs index 77614922514..eb1f93b7d23 100644 --- a/rust/types/src/task.rs +++ 
b/rust/types/src/task.rs @@ -6,6 +6,48 @@ use uuid::Uuid; use crate::CollectionUuid; +/// JobUuid is a wrapper around Uuid to provide a unified type for job identifiers. +/// Jobs can be either collection compaction jobs or task execution jobs. +#[derive( + Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, +)] +pub struct JobUuid(pub Uuid); + +impl JobUuid { + pub fn new() -> Self { + JobUuid(Uuid::new_v4()) + } +} + +impl From for JobUuid { + fn from(collection_uuid: CollectionUuid) -> Self { + JobUuid(collection_uuid.0) + } +} + +impl From for JobUuid { + fn from(task_uuid: TaskUuid) -> Self { + JobUuid(task_uuid.0) + } +} + +impl std::str::FromStr for JobUuid { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result { + match Uuid::parse_str(s) { + Ok(uuid) => Ok(JobUuid(uuid)), + Err(err) => Err(err), + } + } +} + +impl std::fmt::Display for JobUuid { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + /// TaskUuid is a wrapper around Uuid to provide a type for task identifiers. #[derive( Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, @@ -36,8 +78,69 @@ impl std::fmt::Display for TaskUuid { } } +#[derive( + Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, +)] +pub struct TaskRunUuid(pub Uuid); + +impl TaskRunUuid { + pub fn new() -> Self { + TaskRunUuid(Uuid::now_v7()) + } +} + +impl std::str::FromStr for TaskRunUuid { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result { + match Uuid::parse_str(s) { + Ok(uuid) => Ok(TaskRunUuid(uuid)), + Err(err) => Err(err), + } + } +} + +impl std::fmt::Display for TaskRunUuid { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// NonceUuid is a wrapper around Uuid to provide a type for task execution nonces. +#[derive( + Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, +)] +pub struct NonceUuid(pub Uuid); + +impl NonceUuid { + pub fn new() -> Self { + NonceUuid(Uuid::now_v7()) + } +} + +impl std::str::FromStr for NonceUuid { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result { + match Uuid::parse_str(s) { + Ok(uuid) => Ok(NonceUuid(uuid)), + Err(err) => Err(err), + } + } +} + +impl std::fmt::Display for NonceUuid { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + /// Task represents an asynchronous task that is triggered by collection writes /// to map records from a source collection to a target collection. 
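// Hypothetical usage sketch (the function below is illustrative and not part of
// this patch): compaction jobs keyed by CollectionUuid and task runs keyed by
// TaskUuid both funnel into the shared JobUuid key via the From impls above,
// and NonceUuid::new() yields a time-ordered UUIDv7 nonce for a task run.
fn job_keys_sketch(collection_id: CollectionUuid, task_id: TaskUuid) -> (JobUuid, JobUuid) {
    // Scheduler bookkeeping maps (in-progress, failing, dead jobs) can key
    // either kind of work through the same JobUuid type.
    let compaction_job = JobUuid::from(collection_id);
    let task_job = JobUuid::from(task_id);

    // UUIDv7 embeds a millisecond timestamp, so nonces from later runs
    // generally sort after earlier ones.
    let _nonce = NonceUuid::new();

    (compaction_job, task_job)
}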
+fn default_systemtime() -> SystemTime { + SystemTime::UNIX_EPOCH +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Task { /// Unique identifier for the task @@ -51,7 +154,7 @@ pub struct Task { /// Name of target collection where task output is stored pub output_collection_name: String, /// ID of the output collection (lazily filled in after creation) - pub output_collection_id: Option, + pub output_collection_id: Option, /// Optional JSON parameters for the operator pub params: Option, /// Tenant name this task belongs to (despite field name, this is a name not a UUID) @@ -61,9 +164,9 @@ pub struct Task { /// Timestamp of the last successful task run #[serde(skip, default)] pub last_run: Option, - /// Timestamp when the task should next run (None if not yet scheduled) - #[serde(skip, default)] - pub next_run: Option, + /// Timestamp when the task should next run + #[serde(skip, default = "default_systemtime")] + pub next_run: SystemTime, /// Completion offset: the WAL position up to which the task has processed records pub completion_offset: u64, /// Minimum number of new records required before the task runs again @@ -72,9 +175,18 @@ pub struct Task { #[serde(skip, default)] pub is_deleted: bool, /// Timestamp when the task was created + #[serde(default = "default_systemtime")] pub created_at: SystemTime, /// Timestamp when the task was last updated + #[serde(default = "default_systemtime")] pub updated_at: SystemTime, + /// Next nonce (UUIDv7) for task execution tracking + pub next_nonce: NonceUuid, + /// Lowest live nonce (UUIDv7) - marks the earliest epoch that still needs verification + /// When lowest_live_nonce is Some and < next_nonce, it indicates finish_task failed and we should + /// skip execution and only run the scout_logs recheck phase + /// None indicates the task has never been scheduled (brand new task) + pub lowest_live_nonce: Option, } /// ScheduleEntry represents a scheduled task run for a collection. 
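// Hypothetical helper (not defined in this patch) spelling out the nonce
// invariant documented on Task above: a lowest_live_nonce that trails
// next_nonce means a prior run's finish_task never landed, so the runner
// should skip execution and only perform the recheck phase.
fn should_skip_execution(task: &Task) -> bool {
    match task.lowest_live_nonce {
        // Never scheduled yet: nothing to verify, run normally.
        None => false,
        // NonceUuid derives Ord, so UUIDv7 nonces compare in creation order.
        Some(lowest_live) => lowest_live < task.next_nonce,
    }
}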
@@ -82,7 +194,7 @@ pub struct Task { pub struct ScheduleEntry { pub collection_id: CollectionUuid, pub task_id: Uuid, - pub task_run_nonce: Uuid, + pub task_run_nonce: NonceUuid, pub when_to_run: Option>, } @@ -117,7 +229,7 @@ impl TryFrom for ScheduleEntry { "task_run_nonce".to_string(), )) .and_then(|nonce| { - Uuid::parse_str(&nonce).map_err(|_| { + Uuid::parse_str(&nonce).map(NonceUuid).map_err(|_| { ScheduleEntryConversionError::InvalidUuid("task_run_nonce".to_string()) }) })?; diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml index a578d12de30..0046b856e98 100644 --- a/rust/worker/Cargo.toml +++ b/rust/worker/Cargo.toml @@ -79,8 +79,10 @@ shuttle = { workspace = true } rand = { workspace = true } rand_xorshift = { workspace = true } tempfile = { workspace = true } +reqwest = { workspace = true, features = ["json"] } chroma-benchmark = { workspace = true } +chroma-frontend = { workspace = true } [[bench]] name = "filter" diff --git a/rust/worker/src/compactor/compaction_manager.rs b/rust/worker/src/compactor/compaction_manager.rs index 7fa236e2dea..4b05bc94a24 100644 --- a/rust/worker/src/compactor/compaction_manager.rs +++ b/rust/worker/src/compactor/compaction_manager.rs @@ -2,6 +2,7 @@ use super::scheduler::Scheduler; use super::scheduler_policy::LasCompactionTimeSchedulerPolicy; use super::OneOffCompactMessage; use super::RebuildMessage; +use crate::compactor::tasks::SchedulableTask; use crate::compactor::types::{ListDeadJobsMessage, ScheduledCompactMessage}; use crate::config::CompactionServiceConfig; use crate::execution::operators::purge_dirty_log::PurgeDirtyLog; @@ -31,7 +32,7 @@ use chroma_system::Dispatcher; use chroma_system::Orchestrator; use chroma_system::TaskResult; use chroma_system::{Component, ComponentContext, ComponentHandle, Handler, System}; -use chroma_types::CollectionUuid; +use chroma_types::{CollectionUuid, JobUuid}; use futures::stream::FuturesUnordered; use futures::FutureExt; use futures::StreamExt; @@ -58,13 +59,19 @@ use uuid::Uuid; type CompactionOutput = Result>; type BoxedFuture = Pin + Send>>; +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum JobMode { + Compaction, + Task, +} + struct CompactionTask { - collection_id: CollectionUuid, + job_uuid: JobUuid, future: BoxedFuture, } struct CompactionTaskCompletion { - collection_id: CollectionUuid, + job_uuid: JobUuid, result: CompactionOutput, } @@ -94,6 +101,7 @@ pub(crate) struct CompactionManagerContext { } pub(crate) struct CompactionManager { + mode: JobMode, scheduler: Scheduler, context: CompactionManagerContext, compact_awaiter_channel: mpsc::Sender, @@ -119,6 +127,7 @@ impl ChromaError for CompactionError { impl CompactionManager { #[allow(clippy::too_many_arguments)] pub(crate) fn new( + mode: JobMode, system: System, scheduler: Scheduler, log: Log, @@ -147,6 +156,7 @@ impl CompactionManager { compact_awaiter_loop(compact_awaiter_rx, completion_tx).await; }); CompactionManager { + mode, scheduler, context: CompactionManagerContext { system, @@ -178,28 +188,60 @@ impl CompactionManager { self.process_completions(); let compact_awaiter_channel = &self.compact_awaiter_channel; self.scheduler.schedule().await; - let jobs_iter = self.scheduler.get_jobs(); - for job in jobs_iter { - let instrumented_span = span!( - parent: None, - tracing::Level::INFO, - "Compacting job", - collection_id = ?job.collection_id - ); - Span::current().add_link(instrumented_span.context().span().span_context().clone()); - - let future = self - .context - .clone() - .compact(job.collection_id, 
false) - .instrument(instrumented_span); - compact_awaiter_channel - .send(CompactionTask { - collection_id: job.collection_id, - future: Box::pin(future), - }) - .await - .unwrap(); + + match self.mode { + JobMode::Compaction => { + let jobs_iter = self.scheduler.get_jobs(); + for job in jobs_iter { + let instrumented_span = span!( + parent: None, + tracing::Level::INFO, + "Compacting job", + collection_id = ?job.collection_id + ); + Span::current() + .add_link(instrumented_span.context().span().span_context().clone()); + + let future = self + .context + .clone() + .compact(job.collection_id, false) + .instrument(instrumented_span); + compact_awaiter_channel + .send(CompactionTask { + job_uuid: job.collection_id.into(), + future: Box::pin(future), + }) + .await + .unwrap(); + } + } + JobMode::Task => { + let tasks_iter = self.scheduler.get_tasks_scheduled_for_execution(); + for task in tasks_iter { + let instrumented_span = span!( + parent: None, + tracing::Level::INFO, + "Compacting task", + collection_id = ?task.collection_id + ); + Span::current() + .add_link(instrumented_span.context().span().span_context().clone()); + + let future = self + .context + .clone() + .execute_task(task.clone()) + .instrument(instrumented_span); + compact_awaiter_channel + .send(CompactionTask { + job_uuid: task.task_id.into(), + future: Box::pin(future), + }) + .await + .unwrap(); + } + } } } @@ -215,6 +257,9 @@ impl CompactionManager { #[instrument(name = "CompactionManager::purge_dirty_log", skip(ctx))] pub(crate) async fn purge_dirty_log(&mut self, ctx: &ComponentContext) { + if !matches!(self.mode, JobMode::Compaction) { + return; // Tasks don't purge logs + } let deleted_collection_uuids = self.scheduler.drain_deleted_collections(); if deleted_collection_uuids.is_empty() { tracing::info!("Skipping purge dirty log because there is no deleted collections"); @@ -251,6 +296,9 @@ impl CompactionManager { #[instrument(name = "CompactionManager::repair_log_offsets", skip(ctx))] pub(crate) async fn repair_log_offsets(&mut self, ctx: &ComponentContext) { + if !matches!(self.mode, JobMode::Compaction) { + return; // Tasks don't repair offsets + } let log_offsets_to_repair = self.scheduler.drain_collections_requiring_repair(); if log_offsets_to_repair.is_empty() { tracing::info!("No offsets to repair"); @@ -292,28 +340,30 @@ impl CompactionManager { while let Ok(resp) = compact_awaiter_completion_channel.try_recv() { match resp.result { Ok(ref compaction_response) => match compaction_response { - CompactionResponse::Success { collection_id } => { - if *collection_id != resp.collection_id { - tracing::event!(Level::ERROR, name = "mismatched collection ids in result", lhs =? *collection_id, rhs =? resp.collection_id); + CompactionResponse::Success { job_id } => { + if job_id != &resp.job_uuid.0 { + tracing::event!(Level::ERROR, name = "mismatched collection ids in result", lhs =? *job_id, rhs =? resp.job_uuid); } - self.scheduler.succeed_collection(resp.collection_id); + self.scheduler.succeed_job(resp.job_uuid); } CompactionResponse::RequireCompactionOffsetRepair { collection_id, witnessed_offset_in_sysdb, } => { - if *collection_id != resp.collection_id { - tracing::event!(Level::ERROR, name = "mismatched collection ids in result", lhs =? *collection_id, rhs =? resp.collection_id); - self.scheduler.succeed_collection(resp.collection_id); + if collection_id.0 != resp.job_uuid.0 { + tracing::event!(Level::ERROR, name = "mismatched collection ids in result", lhs =? *collection_id, rhs =? 
resp.job_uuid); + self.scheduler.fail_job(resp.job_uuid); } else { - self.scheduler - .require_repair(resp.collection_id, *witnessed_offset_in_sysdb); - self.scheduler.succeed_collection(resp.collection_id); + self.scheduler.require_repair( + chroma_types::CollectionUuid(resp.job_uuid.0), + *witnessed_offset_in_sysdb, + ); + self.scheduler.succeed_job(resp.job_uuid); } } }, Err(_) => { - self.scheduler.fail_collection(resp.collection_id); + self.scheduler.fail_job(resp.job_uuid); } } completed_collections.push(resp); @@ -345,7 +395,7 @@ impl CompactionManagerContext { }; let orchestrator = CompactOrchestrator::new( - collection_id, + collection_id, // input_collection_id rebuild, self.fetch_log_batch_size, self.max_compaction_size, @@ -372,6 +422,46 @@ impl CompactionManagerContext { } } } + + async fn execute_task(self, task: SchedulableTask) -> CompactionOutput { + tracing::info!("Executing task {}", task.task_id); + let dispatcher = match self.dispatcher { + Some(ref dispatcher) => dispatcher.clone(), + None => { + tracing::error!("No dispatcher found"); + return Err(Box::new(CompactionError::FailedToCompact)); + } + }; + + let orchestrator = CompactOrchestrator::new_for_task( + task.collection_id, + false, + self.fetch_log_batch_size, + self.max_compaction_size, + self.max_partition_size, + self.log.clone(), + self.sysdb.clone(), + self.blockfile_provider.clone(), + self.hnsw_index_provider.clone(), + self.spann_provider.clone(), + dispatcher, + None, + task.task_id, + task.nonce, + ); + match orchestrator.run(self.system.clone()).await { + Ok(result) => { + tracing::info!("Task {} completed: {:?}", task.task_id, result); + Ok(result) + } + Err(e) => { + if e.should_trace_error() { + tracing::error!("Task {} failed: {:?}", task.task_id, e); + } + Err(Box::new(e)) + } + } + } } #[async_trait] @@ -429,6 +519,7 @@ impl Configurable<(CompactionServiceConfig, System)> for CompactionManager { let job_expiry_seconds = config.compactor.job_expiry_seconds; let max_failure_count = config.compactor.max_failure_count; let scheduler = Scheduler::new( + JobMode::Compaction, // Default to Compaction mode my_ip, log.clone(), sysdb.clone(), @@ -465,6 +556,7 @@ impl Configurable<(CompactionServiceConfig, System)> for CompactionManager { .await?; Ok(CompactionManager::new( + JobMode::Compaction, // Default to Compaction mode system.clone(), scheduler, log, @@ -484,6 +576,84 @@ impl Configurable<(CompactionServiceConfig, System)> for CompactionManager { )) } } +pub(crate) async fn create_taskrunner_manager( + config: &CompactionServiceConfig, + task_config: &crate::compactor::config::TaskRunnerConfig, + system: System, + _dispatcher: ComponentHandle, + registry: &Registry, +) -> Result> { + let log_config = &config.log; + let log = Log::try_from_config(&(log_config.clone(), system.clone()), registry).await?; + + let sysdb_config = &config.sysdb; + let sysdb = SysDb::try_from_config(sysdb_config, registry).await?; + + let storage = Storage::try_from_config(&config.storage, registry).await?; + + let my_ip = config.my_member_id.clone(); + let policy = Box::new(LasCompactionTimeSchedulerPolicy {}); + let assignment_policy_config = &config.assignment_policy; + let assignment_policy = + Box::::try_from_config(assignment_policy_config, registry).await?; + + let scheduler = Scheduler::new( + JobMode::Task, // Taskrunner mode + my_ip, + log.clone(), + sysdb.clone(), + storage.clone(), + policy, + task_config.max_concurrent_jobs, + 0, // min_compaction_size not used for tasks + assignment_policy, + HashSet::new(), 
// disabled_collections not used for tasks + task_config.job_expiry_seconds, + task_config.max_failure_count, + ); + + let blockfile_provider = BlockfileProvider::try_from_config( + &(config.blockfile_provider.clone(), storage.clone()), + registry, + ) + .await?; + + let hnsw_index_provider = HnswIndexProvider::try_from_config( + &(config.hnsw_provider.clone(), storage.clone()), + registry, + ) + .await?; + + let spann_provider = SpannProvider::try_from_config( + &( + hnsw_index_provider.clone(), + blockfile_provider.clone(), + config.spann_provider.clone(), + ), + registry, + ) + .await?; + + Ok(CompactionManager::new( + JobMode::Task, // Taskrunner mode + system.clone(), + scheduler, + log, + sysdb, + storage.clone(), + blockfile_provider, + hnsw_index_provider, + spann_provider, + task_config.compaction_manager_queue_size, + Duration::from_secs(task_config.compaction_interval_sec), + 0, // min_compaction_size not used for tasks + task_config.max_compaction_size, + task_config.max_partition_size, + task_config.fetch_log_batch_size, + 0, // purge_dirty_log_timeout_seconds not used for tasks + 0, // repair_log_offsets_timeout_seconds not used for tasks + )) +} async fn compact_awaiter_loop( mut job_rx: mpsc::Receiver, @@ -497,21 +667,21 @@ async fn compact_awaiter_loop( let result = AssertUnwindSafe(job.future).catch_unwind().await; match result { Ok(response) => CompactionTaskCompletion { - collection_id: job.collection_id, + job_uuid: job.job_uuid, result: response, }, Err(_) => CompactionTaskCompletion { - collection_id: job.collection_id, + job_uuid: job.job_uuid, result: Err(Box::new(CompactionError::FailedToCompact)), }, } }); } Some(completed_job) = futures.next() => { - let collection_id = completed_job.collection_id; + let job_uuid = completed_job.job_uuid; match completion_tx.send(completed_job) { Ok(_) => {}, - Err(_) => tracing::error!("Failed to record compaction result for collection {}", collection_id), + Err(_) => tracing::error!("Failed to record compaction result for job {}", job_uuid), } } else => { @@ -696,7 +866,7 @@ impl Handler for CompactionManager { _ctx: &ComponentContext, ) { let dead_jobs = self.scheduler.get_dead_jobs(); - if let Err(e) = message.response_tx.send(dead_jobs) { + if let Err(e) = message.response_tx.send(dead_jobs.into_iter().collect()) { tracing::error!("Failed to send dead jobs response: {:?}", e); } } @@ -900,6 +1070,7 @@ mod tests { assignment_policy.set_members(vec![my_member.member_id.clone()]); let mut scheduler = Scheduler::new( + JobMode::Compaction, my_member.member_id.clone(), log.clone(), sysdb.clone(), @@ -951,6 +1122,7 @@ mod tests { }; let system = System::new(); let mut manager = CompactionManager::new( + JobMode::Compaction, system.clone(), scheduler, log, @@ -982,7 +1154,8 @@ mod tests { let start = std::time::Instant::now(); let timeout = std::time::Duration::from_secs(10); - let expected_compactions = HashSet::from([collection_uuid_1, collection_uuid_2]); + let expected_compactions = + HashSet::from([collection_uuid_1.into(), collection_uuid_2.into()]); let mut completed_compactions = HashSet::new(); @@ -993,8 +1166,8 @@ mod tests { completed .iter() .filter(|c| c.result.is_ok()) - .map(|c| c.collection_id) - .collect::>(), + .map(|c| c.job_uuid) + .collect::>(), ); } @@ -1032,6 +1205,7 @@ mod tests { assignment_policy.set_members(vec!["test-member".to_string()]); let mut scheduler = Scheduler::new( + JobMode::Compaction, "test-member".to_string(), Log::InMemory(InMemoryLog::new()), SysDb::Test(TestSysDb::new()), @@ -1047,11 
+1221,11 @@ mod tests { // Simulate a dead job by marking a collection as killed (moved to dead_jobs) let test_collection_id = CollectionUuid::new(); - scheduler.kill_collection(test_collection_id); + scheduler.kill_job(test_collection_id.into()); // Verify it's in dead jobs let dead_jobs = scheduler.get_dead_jobs(); assert_eq!(dead_jobs.len(), 1); - assert!(dead_jobs.contains(&test_collection_id)); + assert!(dead_jobs.contains(&test_collection_id.into())); } } diff --git a/rust/worker/src/compactor/config.rs b/rust/worker/src/compactor/config.rs index b13c6ecc3db..26b981aa795 100644 --- a/rust/worker/src/compactor/config.rs +++ b/rust/worker/src/compactor/config.rs @@ -1,5 +1,82 @@ use serde::{Deserialize, Serialize}; +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TaskRunnerConfig { + #[serde(default = "TaskRunnerConfig::default_enabled")] + pub enabled: bool, + #[serde(default = "TaskRunnerConfig::default_compaction_manager_queue_size")] + pub compaction_manager_queue_size: usize, + #[serde(default = "TaskRunnerConfig::default_job_expiry_seconds")] + pub job_expiry_seconds: u64, + #[serde(default = "TaskRunnerConfig::default_max_concurrent_jobs")] + pub max_concurrent_jobs: usize, + #[serde(default = "TaskRunnerConfig::default_compaction_interval_sec")] + pub compaction_interval_sec: u64, + #[serde(default = "TaskRunnerConfig::default_max_compaction_size")] + pub max_compaction_size: usize, + #[serde(default = "TaskRunnerConfig::default_max_partition_size")] + pub max_partition_size: usize, + #[serde(default = "TaskRunnerConfig::default_fetch_log_batch_size")] + pub fetch_log_batch_size: u32, + #[serde(default = "TaskRunnerConfig::default_max_failure_count")] + pub max_failure_count: u8, +} + +impl TaskRunnerConfig { + fn default_enabled() -> bool { + false // Disabled by default for safety + } + + fn default_compaction_manager_queue_size() -> usize { + 1000 + } + + fn default_max_concurrent_jobs() -> usize { + 50 // Lower default for tasks + } + + fn default_compaction_interval_sec() -> u64 { + 30 // More frequent for tasks + } + + fn default_job_expiry_seconds() -> u64 { + 3600 + } + + fn default_max_compaction_size() -> usize { + 10_000 + } + + fn default_max_partition_size() -> usize { + 5_000 + } + + fn default_fetch_log_batch_size() -> u32 { + 100 + } + + fn default_max_failure_count() -> u8 { + 5 + } +} + +impl Default for TaskRunnerConfig { + fn default() -> Self { + TaskRunnerConfig { + enabled: TaskRunnerConfig::default_enabled(), + compaction_manager_queue_size: TaskRunnerConfig::default_compaction_manager_queue_size( + ), + job_expiry_seconds: TaskRunnerConfig::default_job_expiry_seconds(), + max_concurrent_jobs: TaskRunnerConfig::default_max_concurrent_jobs(), + compaction_interval_sec: TaskRunnerConfig::default_compaction_interval_sec(), + max_compaction_size: TaskRunnerConfig::default_max_compaction_size(), + max_partition_size: TaskRunnerConfig::default_max_partition_size(), + fetch_log_batch_size: TaskRunnerConfig::default_fetch_log_batch_size(), + max_failure_count: TaskRunnerConfig::default_max_failure_count(), + } + } +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct CompactorConfig { #[serde(default = "CompactorConfig::default_compaction_manager_queue_size")] diff --git a/rust/worker/src/compactor/scheduler.rs b/rust/worker/src/compactor/scheduler.rs index 31c2b3c42e8..8972724f817 100644 --- a/rust/worker/src/compactor/scheduler.rs +++ b/rust/worker/src/compactor/scheduler.rs @@ -8,15 +8,16 @@ use chroma_log::{CollectionInfo, 
CollectionRecord, Log}; use chroma_memberlist::memberlist_provider::Memberlist; use chroma_storage::Storage; use chroma_sysdb::{GetCollectionsOptions, SysDb}; -use chroma_types::CollectionUuid; +use chroma_types::{CollectionUuid, JobUuid}; use figment::providers::Env; use figment::Figment; use s3heap_service::SysDbScheduler; use serde::Deserialize; use uuid::Uuid; +use crate::compactor::compaction_manager::JobMode; use crate::compactor::scheduler_policy::SchedulerPolicy; -use crate::compactor::tasks::TaskHeapReader; +use crate::compactor::tasks::{SchedulableTask, TaskHeapReader}; use crate::compactor::types::CompactionJob; #[derive(Debug, Clone)] @@ -82,6 +83,7 @@ impl InProgressJob { } pub(crate) struct Scheduler { + mode: JobMode, my_member_id: String, log: Log, sysdb: SysDb, @@ -95,13 +97,14 @@ pub(crate) struct Scheduler { disabled_collections: HashSet, deleted_collections: HashSet, collections_needing_repair: HashMap, - in_progress_jobs: HashMap, + in_progress_jobs: HashMap, job_expiry_seconds: u64, - failing_jobs: HashMap, - dead_jobs: HashSet, + failing_jobs: HashMap, + dead_jobs: HashSet, max_failure_count: u8, metrics: SchedulerMetrics, tasks: TaskHeapReader, + task_queue: Vec, } #[derive(Deserialize, Debug)] @@ -112,6 +115,7 @@ struct RunTimeConfig { impl Scheduler { #[allow(clippy::too_many_arguments)] pub(crate) fn new( + mode: JobMode, my_ip: String, log: Log, sysdb: SysDb, @@ -129,6 +133,7 @@ impl Scheduler { let tasks = TaskHeapReader::new(storage, heap_scheduler); Scheduler { + mode, my_member_id: my_ip, log, sysdb, @@ -149,6 +154,7 @@ impl Scheduler { dead_jobs: HashSet::new(), metrics: SchedulerMetrics::default(), tasks, + task_queue: Vec::with_capacity(max_concurrent_jobs), } } @@ -173,7 +179,7 @@ impl Scheduler { .insert(collection_id, offset_in_sysdb); } - pub(crate) fn get_dead_jobs(&self) -> Vec { + pub(crate) fn get_dead_jobs(&self) -> Vec { self.dead_jobs.iter().cloned().collect() } @@ -203,7 +209,7 @@ impl Scheduler { for collection_info in collections { let failure_count = self .failing_jobs - .get(&collection_info.collection_id) + .get(&collection_info.collection_id.into()) .map(|job| job.failure_count()) .unwrap_or(0); @@ -213,13 +219,15 @@ impl Scheduler { collection_info.collection_id, self.max_failure_count ); - self.kill_collection(collection_info.collection_id); + self.kill_job(collection_info.collection_id.into()); continue; } if self .disabled_collections .contains(&collection_info.collection_id) - || self.dead_jobs.contains(&collection_info.collection_id) + || self + .dead_jobs + .contains(&collection_info.collection_id.into()) { tracing::info!( "Ignoring collection: {:?} because it disabled for compaction", @@ -389,13 +397,13 @@ impl Scheduler { } fn is_job_in_progress(&mut self, collection_id: &CollectionUuid) -> bool { - match self.in_progress_jobs.get(collection_id) { + match self.in_progress_jobs.get(&(*collection_id).into()) { Some(job) if job.is_expired() => { tracing::info!( "Compaction for {} is expired, removing from dedup set.", collection_id ); - self.fail_collection(*collection_id); + self.fail_job((*collection_id).into()); false } Some(_) => true, @@ -404,52 +412,51 @@ impl Scheduler { } fn add_in_progress(&mut self, collection_id: CollectionUuid) { - self.in_progress_jobs - .insert(collection_id, InProgressJob::new(self.job_expiry_seconds)); + self.in_progress_jobs.insert( + collection_id.into(), + InProgressJob::new(self.job_expiry_seconds), + ); } - pub(crate) fn succeed_collection(&mut self, collection_id: CollectionUuid) { - if 
self.in_progress_jobs.remove(&collection_id).is_none() { + pub(crate) fn succeed_job(&mut self, job_uuid: JobUuid) { + if self.in_progress_jobs.remove(&job_uuid).is_none() { tracing::warn!( "Expired compaction for {} just successfully finished.", - collection_id + job_uuid ); return; } - self.failing_jobs.remove(&collection_id); + self.failing_jobs.remove(&job_uuid); } - pub(crate) fn fail_collection(&mut self, collection_id: CollectionUuid) { - if self.in_progress_jobs.remove(&collection_id).is_none() { + pub(crate) fn fail_job(&mut self, job_uuid: JobUuid) { + if self.in_progress_jobs.remove(&job_uuid).is_none() { tracing::warn!( "Expired compaction for {} just unsuccessfully finished.", - collection_id + job_uuid ); return; } - match self.failing_jobs.get_mut(&collection_id) { + match self.failing_jobs.get_mut(&job_uuid) { Some(failed_job) => { failed_job.increment_failure(self.max_failure_count); tracing::warn!( "Job for collection {} failed {}/{} times", - collection_id, + job_uuid, failed_job.failure_count(), self.max_failure_count ); } None => { - self.failing_jobs.insert(collection_id, FailedJob::new()); - tracing::warn!( - "Job for collection {} failed for the first time", - collection_id - ); + self.failing_jobs.insert(job_uuid, FailedJob::new()); + tracing::warn!("Job for collection {} failed for the first time", job_uuid); } } } - pub(crate) fn kill_collection(&mut self, collection_id: CollectionUuid) { - self.failing_jobs.remove(&collection_id); - self.dead_jobs.insert(collection_id); + pub(crate) fn kill_job(&mut self, job_uuid: JobUuid) { + self.failing_jobs.remove(&job_uuid); + self.dead_jobs.insert(job_uuid); self.metrics.update_dead_jobs_count(self.dead_jobs.len()); } @@ -480,39 +487,83 @@ impl Scheduler { pub(crate) async fn schedule(&mut self) { // For now, we clear the job queue every time, assuming we will not have any pending jobs running self.job_queue.clear(); + self.task_queue.clear(); + if self.memberlist.is_none() || self.memberlist.as_ref().unwrap().is_empty() { tracing::error!("Memberlist is not set or empty. Cannot schedule compaction jobs."); return; } - // Recompute disabled list. - self.recompute_disabled_collections(); - let collections = self.get_collections_with_new_data().await; - let tasks = self - .tasks - .get_tasks_scheduled_for_execution( - s3heap::Limits::default().with_items(self.max_concurrent_jobs), - ) - .await; + + match self.mode { + JobMode::Compaction => { + // Recompute disabled list. 
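Aside (not part of the patch): the JobUuid-keyed bookkeeping above follows a simple lifecycle — failures accumulate per job until max_failure_count, after which the job is parked in dead_jobs and skipped by later scheduling passes, while a success wipes the failure history. The stand-in below is a minimal sketch of that lifecycle with invented types; for brevity it folds the schedule-loop threshold check into fail_job, whereas the real scheduler performs that check in schedule() and calls kill_job there.

// Illustrative stand-ins; the real maps are keyed by chroma_types::JobUuid.
use std::collections::{HashMap, HashSet};

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct JobId(u128);

#[derive(Default)]
struct FailureBookkeeping {
    failing: HashMap<JobId, u8>,
    dead: HashSet<JobId>,
}

impl FailureBookkeeping {
    fn fail_job(&mut self, job: JobId, max_failure_count: u8) {
        let count = self.failing.entry(job).or_insert(0);
        *count += 1;
        if *count >= max_failure_count {
            // Equivalent of kill_job: stop tracking failures and mark dead.
            self.failing.remove(&job);
            self.dead.insert(job);
        }
    }

    fn succeed_job(&mut self, job: JobId) {
        // A success clears any accumulated failures for that job.
        self.failing.remove(&job);
    }

    fn is_schedulable(&self, job: JobId) -> bool {
        !self.dead.contains(&job)
    }
}

fn main() {
    let mut fb = FailureBookkeeping::default();
    let job = JobId(1);
    for _ in 0..5 {
        fb.fail_job(job, 5);
    }
    assert!(!fb.is_schedulable(job)); // parked in dead_jobs after 5 failures
}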
+ self.recompute_disabled_collections(); + let collections = self.get_collections_with_new_data().await; + if collections.is_empty() { + return; + } + let collection_records = self.verify_and_enrich_collections(collections).await; + self.schedule_internal(collection_records).await; + } + JobMode::Task => { + let tasks = self + .tasks + .get_tasks_scheduled_for_execution( + s3heap::Limits::default().with_items(self.max_concurrent_jobs), + ) + .await; + self.schedule_tasks(tasks); + } + } + } + + pub(crate) fn schedule_tasks(&mut self, tasks: Vec) { + let members = self.memberlist.as_ref().unwrap(); + let members_as_string = members + .iter() + .map(|member| member.member_id.clone()) + .collect(); + self.assignment_policy.set_members(members_as_string); for task in tasks { + let result = self + .assignment_policy + .assign_one(task.collection_id.0.to_string().as_str()); + if result.is_err() { + tracing::error!( + "Failed to assign task {} for collection {} to member: {}", + task.task_id, + task.collection_id, + result.err().unwrap() + ); + continue; + } + + let member = result.unwrap(); + if member != self.my_member_id { + continue; + } + tracing::info!( "SCHEDULING TASKS FOR {:?} {:?} {:?} {:?}", task.bucket, task.collection_id, task.task_id, - task.nonce, + task.nonce ); + + // TODO: dedup with already-enqueued tasks + self.task_queue.push(task); } - if collections.is_empty() { - return; - } - let collection_records = self.verify_and_enrich_collections(collections).await; - self.schedule_internal(collection_records).await; } pub(crate) fn get_jobs(&self) -> impl Iterator { self.job_queue.iter() } + pub(crate) fn get_tasks_scheduled_for_execution(&self) -> &Vec { + &self.task_queue + } + pub(crate) fn set_memberlist(&mut self, memberlist: Memberlist) { self.memberlist = Some(memberlist); } @@ -535,6 +586,8 @@ mod tests { use chroma_sysdb::TestSysDb; use chroma_types::{Collection, LogRecord, Operation, OperationRecord}; + use crate::compactor::compaction_manager::JobMode; + #[tokio::test] async fn test_k8s_integration_scheduler() { let storage = s3_client_for_test_with_new_bucket().await; @@ -635,6 +688,7 @@ mod tests { assignment_policy.set_members(vec![my_member.member_id.clone()]); let mut scheduler = Scheduler::new( + JobMode::Compaction, my_member.member_id.clone(), log, sysdb.clone(), @@ -667,7 +721,7 @@ mod tests { // Scheduler ignores collection that failed to fetch last compaction time assert_eq!(jobs.len(), 1); assert_eq!(jobs[0].collection_id, collection_uuid_1,); - scheduler.succeed_collection(collection_uuid_1); + scheduler.succeed_job(collection_uuid_1.into()); // Add last compaction time for tenant_2 match sysdb { @@ -684,8 +738,8 @@ mod tests { assert_eq!(jobs.len(), 2); assert_eq!(jobs[0].collection_id, collection_uuid_2,); assert_eq!(jobs[1].collection_id, collection_uuid_1,); - scheduler.succeed_collection(collection_uuid_1); - scheduler.succeed_collection(collection_uuid_2); + scheduler.succeed_job(collection_uuid_1.into()); + scheduler.succeed_job(collection_uuid_2.into()); // Set disable list. 
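Aside (not part of the patch): schedule_tasks above reduces to "assign each task's collection id through the membership-aware policy and keep only the tasks that land on this member". The sketch below illustrates that filter; TaskStub and the hash-based assign function are placeholders, not the crate's assignment policy.

// Sketch only; `assign` stands in for the scheduler's real assignment policy.
#[derive(Clone, Debug)]
struct TaskStub {
    collection_id: String,
    task_id: String,
}

fn assign(members: &[String], key: &str) -> Option<String> {
    // Placeholder policy: pick the member with the highest (member, key) hash.
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    members
        .iter()
        .max_by_key(|m| {
            let mut h = DefaultHasher::new();
            (m.as_str(), key).hash(&mut h);
            h.finish()
        })
        .cloned()
}

fn tasks_for_me(my_id: &str, members: &[String], tasks: Vec<TaskStub>) -> Vec<TaskStub> {
    tasks
        .into_iter()
        .filter(|t| assign(members, &t.collection_id).as_deref() == Some(my_id))
        .collect()
}

fn main() {
    let members = vec!["compactor-0".to_string(), "compactor-1".to_string()];
    let tasks = vec![
        TaskStub { collection_id: "c1".into(), task_id: "t1".into() },
        TaskStub { collection_id: "c2".into(), task_id: "t2".into() },
    ];
    println!("tasks routed here: {:?}", tasks_for_me("compactor-0", &members, tasks));
}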
std::env::set_var( @@ -697,7 +751,7 @@ mod tests { let jobs = jobs.collect::>(); assert_eq!(jobs.len(), 1); assert_eq!(jobs[0].collection_id, collection_uuid_2,); - scheduler.succeed_collection(collection_uuid_2); + scheduler.succeed_job(collection_uuid_2.into()); std::env::set_var( "CHROMA_COMPACTION_SERVICE__COMPACTOR__DISABLED_COLLECTIONS", "[]", @@ -716,8 +770,8 @@ mod tests { let jobs = jobs.collect::>(); assert_eq!(jobs.len(), 1); assert_eq!(jobs[0].collection_id, collection_uuid_1,); - scheduler.succeed_collection(collection_uuid_1); - scheduler.succeed_collection(collection_uuid_2); + scheduler.succeed_job(collection_uuid_1.into()); + scheduler.succeed_job(collection_uuid_2.into()); std::env::set_var( "CHROMA_COMPACTION_SERVICE.COMPACTOR.DISABLED_COLLECTIONS", "[]", @@ -739,7 +793,7 @@ mod tests { scheduler.schedule().await; let jobs = scheduler.get_jobs(); assert_eq!(jobs.count(), 1); - scheduler.succeed_collection(collection_uuid_2); + scheduler.succeed_job(collection_uuid_2.into()); let members = vec![member_1.clone()]; scheduler.set_memberlist(members); @@ -754,15 +808,15 @@ mod tests { let jobs = scheduler.get_jobs(); let jobs = jobs.collect::>(); assert_eq!(jobs.len(), 2); - scheduler.fail_collection(collection_uuid_1); - scheduler.succeed_collection(collection_uuid_2); + scheduler.fail_job(collection_uuid_1.into()); + scheduler.succeed_job(collection_uuid_2.into()); } scheduler.schedule().await; let jobs = scheduler.get_jobs(); let jobs = jobs.collect::>(); assert_eq!(jobs.len(), 1); assert_eq!(jobs[0].collection_id, collection_uuid_2); - scheduler.succeed_collection(collection_uuid_2); + scheduler.succeed_job(collection_uuid_2.into()); } #[tokio::test] @@ -891,6 +945,7 @@ mod tests { assignment_policy.set_members(vec![my_member.member_id.clone()]); let mut scheduler = Scheduler::new( + JobMode::Compaction, my_member.member_id.clone(), log, sysdb.clone(), diff --git a/rust/worker/src/compactor/tasks.rs b/rust/worker/src/compactor/tasks.rs index ecf4f48beac..3e565128b55 100644 --- a/rust/worker/src/compactor/tasks.rs +++ b/rust/worker/src/compactor/tasks.rs @@ -3,15 +3,15 @@ use std::sync::Arc; use chrono::{DateTime, Utc}; use chroma_storage::Storage; -use chroma_types::CollectionUuid; +use chroma_types::{CollectionUuid, NonceUuid, TaskUuid}; use s3heap::{heap_path_from_hostname, Error, HeapReader, HeapScheduler, Limits}; /// A task that has been scheduled for execution. 
#[derive(Clone, Debug)] pub struct SchedulableTask { pub collection_id: CollectionUuid, - pub task_id: uuid::Uuid, - pub nonce: uuid::Uuid, + pub task_id: TaskUuid, + pub nonce: NonceUuid, pub bucket: DateTime, } @@ -70,10 +70,11 @@ impl TaskHeapReader { Ok(items) => { tracing::trace!("Found {} tasks in {}", items.len(), heap_prefix); for (bucket, item) in items { + let collection_id = CollectionUuid(*item.trigger.partitioning.as_uuid()); all_tasks.push(SchedulableTask { - collection_id: CollectionUuid(*item.trigger.partitioning.as_uuid()), - task_id: *item.trigger.scheduling.as_uuid(), - nonce: item.nonce, + collection_id, + task_id: TaskUuid(*item.trigger.scheduling.as_uuid()), + nonce: NonceUuid(item.nonce), bucket, }); } diff --git a/rust/worker/src/compactor/types.rs b/rust/worker/src/compactor/types.rs index c261c46bff9..f9957046e0b 100644 --- a/rust/worker/src/compactor/types.rs +++ b/rust/worker/src/compactor/types.rs @@ -1,4 +1,4 @@ -use chroma_types::CollectionUuid; +use chroma_types::{CollectionUuid, JobUuid}; use tokio::sync::oneshot; #[derive(Clone, Eq, PartialEq, Debug)] @@ -21,5 +21,5 @@ pub struct RebuildMessage { #[derive(Debug)] pub struct ListDeadJobsMessage { - pub response_tx: oneshot::Sender>, + pub response_tx: oneshot::Sender>, } diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs index 43b3f12a91e..0cf5163c091 100644 --- a/rust/worker/src/config.rs +++ b/rust/worker/src/config.rs @@ -228,6 +228,8 @@ pub struct CompactionServiceConfig { #[serde(default)] pub compactor: crate::compactor::config::CompactorConfig, #[serde(default)] + pub taskrunner: Option, + #[serde(default)] pub blockfile_provider: chroma_blockstore::config::BlockfileProviderConfig, #[serde(default)] pub hnsw_provider: chroma_index::config::HnswProviderConfig, diff --git a/rust/worker/src/execution/operators/execute_task.rs b/rust/worker/src/execution/operators/execute_task.rs new file mode 100644 index 00000000000..0e204ab31f5 --- /dev/null +++ b/rust/worker/src/execution/operators/execute_task.rs @@ -0,0 +1,192 @@ +use async_trait::async_trait; +use chroma_blockstore::provider::BlockfileProvider; +use chroma_error::ChromaError; +use chroma_log::Log; +use chroma_segment::blockfile_record::{RecordSegmentReader, RecordSegmentReaderCreationError}; +use chroma_system::{Operator, OperatorType}; +use chroma_types::{Chunk, LogRecord, Operation, OperationRecord, Segment, UpdateMetadataValue}; +use std::sync::Arc; +use thiserror::Error; + +/// Trait for task executors that process input records and produce output records. +/// Implementors can read from the output collection to maintain state across executions. +#[async_trait] +pub trait TaskExecutor: Send + Sync + std::fmt::Debug { + /// Execute the task logic on input records. + /// + /// # Arguments + /// * `input_records` - The log records to process + /// * `output_reader` - Optional reader for the output collection's compacted data + /// + /// # Returns + /// The output records to be written to the output collection + async fn execute( + &self, + input_records: Chunk, + output_reader: Option<&RecordSegmentReader<'_>>, + ) -> Result, Box>; +} + +/// A simple counting task that maintains a running total of records processed. +/// Stores the count in a metadata field called "total_count". 
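Aside (not part of the patch): the CountTask below is a placeholder that only counts the current batch. To illustrate what a genuinely stateful executor would look like, here is a simplified sketch with a stand-in trait; the real trait takes a Chunk<LogRecord> and an optional RecordSegmentReader, and `previous_total` here is a hypothetical value representing "total_count read back from the output collection".

// Simplified stand-ins, not the patch's TaskExecutor trait.
trait TaskExecutorSketch {
    fn execute(&self, batch_len: usize, previous_total: Option<i64>) -> i64;
}

struct RunningCountTask;

impl TaskExecutorSketch for RunningCountTask {
    fn execute(&self, batch_len: usize, previous_total: Option<i64>) -> i64 {
        // Fold the new batch into whatever total the last run persisted,
        // rather than overwriting it with just this batch's count.
        previous_total.unwrap_or(0) + batch_len as i64
    }
}

fn main() {
    let task = RunningCountTask;
    let first = task.execute(10, None); // first run: no prior state
    let second = task.execute(5, Some(first)); // later run folds in the prior total
    assert_eq!(second, 15);
}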
+#[derive(Debug)] +pub struct CountTask; + +#[async_trait] +impl TaskExecutor for CountTask { + async fn execute( + &self, + input_records: Chunk, + _output_reader: Option<&RecordSegmentReader<'_>>, + ) -> Result, Box> { + let records_count = input_records.len() as i64; + + let new_total_count = records_count; + + // Create output record with updated count + let mut metadata = std::collections::HashMap::new(); + metadata.insert( + "total_count".to_string(), + UpdateMetadataValue::Int(new_total_count), + ); + + let operation_record = OperationRecord { + id: "task_result".to_string(), + embedding: Some(vec![0.0]), + encoding: None, + metadata: Some(metadata), + document: None, + operation: Operation::Upsert, + }; + + let log_record = LogRecord { + log_offset: 0, // Will be set by caller + record: operation_record, + }; + + Ok(Chunk::new(Arc::new([log_record]))) + } +} + +/// The ExecuteTask operator executes task logic based on fetched logs. +/// Uses a TaskExecutor trait to allow different task implementations. +#[derive(Debug)] +pub struct ExecuteTaskOperator { + pub log_client: Log, + pub task_executor: Arc, +} + +/// Input for the ExecuteTask operator +#[derive(Debug)] +pub struct ExecuteTaskInput { + /// The fetched log records to process + pub log_records: Chunk, + /// The tenant ID + pub tenant_id: String, + /// The output collection ID where results are written + pub output_collection_id: String, + /// The current completion offset + pub completion_offset: u64, + /// The output collection's record segment to read existing data + pub output_record_segment: Segment, + /// Blockfile provider for reading segments + pub blockfile_provider: BlockfileProvider, +} + +/// Output from the ExecuteTask operator +#[derive(Debug)] +pub struct ExecuteTaskOutput { + /// The number of records processed in this execution + pub records_processed: u64, + /// The output log records to be partitioned and compacted + pub output_records: Chunk, +} + +#[derive(Debug, Error)] +pub enum ExecuteTaskError { + #[error("Failed to read from segment: {0}")] + SegmentRead(#[from] Box), + #[error("Failed to create record segment reader: {0}")] + RecordReader(#[from] RecordSegmentReaderCreationError), + #[error("Invalid collection UUID: {0}")] + InvalidUuid(String), +} + +impl ChromaError for ExecuteTaskError { + fn code(&self) -> chroma_error::ErrorCodes { + match self { + ExecuteTaskError::SegmentRead(e) => e.code(), + ExecuteTaskError::RecordReader(e) => e.code(), + ExecuteTaskError::InvalidUuid(_) => chroma_error::ErrorCodes::InvalidArgument, + } + } +} + +#[async_trait] +impl Operator for ExecuteTaskOperator { + type Error = ExecuteTaskError; + + fn get_type(&self) -> OperatorType { + OperatorType::IO + } + + async fn run(&self, input: &ExecuteTaskInput) -> Result { + tracing::info!( + "[ExecuteTask]: Processing {} records for output collection {}", + input.log_records.len(), + input.output_collection_id + ); + + let records_count = input.log_records.len() as u64; + + // Create record segment reader from the output collection's record segment + let record_segment_reader = match Box::pin(RecordSegmentReader::from_segment( + &input.output_record_segment, + &input.blockfile_provider, + )) + .await + { + Ok(reader) => Some(reader), + Err(e) if matches!(*e, RecordSegmentReaderCreationError::UninitializedSegment) => { + // Output collection has no data yet - this is the first run + tracing::info!("[ExecuteTask]: Output segment uninitialized - first task run"); + None + } + Err(e) => return Err((*e).into()), + }; + + // 
Execute the task using the provided executor + let output_records = self + .task_executor + .execute(input.log_records.clone(), record_segment_reader.as_ref()) + .await + .map_err(ExecuteTaskError::SegmentRead)?; + + // Update log offsets for output records + // completion_offset = -1 means "no records processed yet", treat as 0 + let base_offset = if input.completion_offset == u64::MAX { + 0i64 + } else { + input.completion_offset as i64 + }; + let output_records_with_offsets: Vec = output_records + .iter() + .enumerate() + .map(|(i, (log_record, _))| LogRecord { + log_offset: base_offset + i as i64, + record: log_record.record.clone(), + }) + .collect(); + + tracing::info!( + "[ExecuteTask]: Task executed successfully, produced {} output records", + output_records_with_offsets.len() + ); + + // Return the output records to be partitioned + Ok(ExecuteTaskOutput { + records_processed: records_count, + output_records: Chunk::new(Arc::from(output_records_with_offsets)), + }) + } +} diff --git a/rust/worker/src/execution/operators/finish_task.rs b/rust/worker/src/execution/operators/finish_task.rs new file mode 100644 index 00000000000..bcd518317fa --- /dev/null +++ b/rust/worker/src/execution/operators/finish_task.rs @@ -0,0 +1,138 @@ +use async_trait::async_trait; +use chroma_error::{ChromaError, ErrorCodes}; +use chroma_log::Log; +use chroma_sysdb::SysDb; +use chroma_system::Operator; +use chroma_types::{FinishTaskError as SysDbFinishTaskError, Task, TaskUuid}; +use thiserror::Error; + +/// The finish task operator is responsible for updating task state in SysDB +/// after a successful task execution run. +#[derive(Debug)] +pub struct FinishTaskOperator { + log_client: Log, + sysdb: SysDb, +} + +impl FinishTaskOperator { + /// Create a new finish task operator. + pub fn new(log_client: Log, sysdb: SysDb) -> Box { + Box::new(FinishTaskOperator { log_client, sysdb }) + } +} + +#[derive(Debug)] +/// The input for the finish task operator. +/// # Parameters +/// * `updated_task` - The updated task record from sysdb. +/// * `records_processed` - The number of records processed in this run. +/// * `sysdb` - The sysdb client. +pub struct FinishTaskInput { + // Updated Task record from sysdb + updated_task: Task, +} + +impl FinishTaskInput { + /// Create a new finish task input. + pub fn new(updated_task: Task) -> Self { + FinishTaskInput { updated_task } + } +} + +/// The output for the finish task operator. 
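Aside (not part of the patch): the offset handling above relies on a sentinel — a completion_offset of u64::MAX corresponds to the -1 stored on the sysdb side and means "nothing processed yet". Two tiny pure functions sketch the conversions used in this patch (the base offset assigned to emitted records here, and the log_position seeded on the input collection later in the orchestrator); names are illustrative.

// Sketch of the sentinel conversions; u64::MAX plays the role of sysdb's -1.
const UNSET: u64 = u64::MAX;

/// Base log offset assigned to records emitted by a task run.
fn output_base_offset(completion_offset: u64) -> i64 {
    if completion_offset == UNSET {
        0 // never ran before: start numbering output records at 0
    } else {
        completion_offset as i64
    }
}

/// log_position seeded on the input collection before fetching logs.
fn seeded_log_position(completion_offset: u64) -> i64 {
    if completion_offset == UNSET {
        -1 // per the patch comments, -1 makes FetchLog start from offset 0
    } else {
        completion_offset as i64
    }
}

fn main() {
    assert_eq!(output_base_offset(UNSET), 0);
    assert_eq!(seeded_log_position(UNSET), -1);
    assert_eq!(output_base_offset(42), 42);
    assert_eq!(seeded_log_position(42), 42);
}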
+#[derive(Debug)] +pub struct FinishTaskOutput { + pub _task_id: TaskUuid, + pub _new_completion_offset: u64, +} + +#[derive(Error, Debug)] +pub enum FinishTaskError { + #[error("Failed to scout logs: {0}")] + ScoutLogsError(String), + #[error("Failed to finish task in SysDB: {0}")] + SysDbError(#[from] SysDbFinishTaskError), +} + +impl ChromaError for FinishTaskError { + fn code(&self) -> ErrorCodes { + match self { + FinishTaskError::ScoutLogsError(_) => ErrorCodes::Internal, + FinishTaskError::SysDbError(e) => e.code(), + } + } +} + +#[async_trait] +impl Operator for FinishTaskOperator { + type Error = FinishTaskError; + + fn get_name(&self) -> &'static str { + "FinishTaskOperator" + } + + async fn run(&self, input: &FinishTaskInput) -> Result { + // Step 1: Scout the logs to see if there are any new records written since we started processing + // This recheck ensures we don't miss any records that were written during our task execution + tracing::info!( + "Rechecking logs for task {} with completion offset {}", + input.updated_task.id.0, + input.updated_task.completion_offset + ); + + // Scout the logs to check for new records written since we started processing + // This catches any records that were written during our task execution + // scout_logs returns the offset of the next record to be inserted + let mut log_client = self.log_client.clone(); + let next_log_offset = log_client + .scout_logs( + &input.updated_task.tenant_id, + input.updated_task.input_collection_id, + input.updated_task.completion_offset, + ) + .await + .map_err(|e| { + tracing::error!( + task_id = %input.updated_task.id.0, + error = %e, + "Failed to scout logs during finish_task recheck" + ); + FinishTaskError::ScoutLogsError(format!("Failed to scout logs: {}", e)) + })?; + + // Calculate how many new records were written since we started processing + let new_records_count = + next_log_offset.saturating_sub(input.updated_task.completion_offset); + let new_records_found = new_records_count >= input.updated_task.min_records_for_task; + + if new_records_found { + tracing::info!( + task_id = %input.updated_task.id.0, + new_records_count = new_records_count, + min_records_threshold = input.updated_task.min_records_for_task, + "Detected new records written during task execution that exceed threshold" + ); + + // TODO: Schedule a new task for next nonce. + } + + // Step 2: Update lowest_live_nonce to equal next_nonce + // This indicates that finish_task completed successfully and this epoch is verified + // If this fails, lowest_live_nonce < next_nonce will indicate + // that we should skip execution next time and only do the recheck phase + let mut sysdb = self.sysdb.clone(); + sysdb.finish_task(input.updated_task.id).await?; + + // TODO: delete old nonce from scheduler + + tracing::info!( + "Task {} finish_task completed. 
lowest_live_nonce updated", + input.updated_task.id.0, + ); + + Ok(FinishTaskOutput { + _task_id: input.updated_task.id, + _new_completion_offset: input.updated_task.completion_offset, + }) + } +} diff --git a/rust/worker/src/execution/operators/get_collection_and_segments.rs b/rust/worker/src/execution/operators/get_collection_and_segments.rs index 8a93a4adec6..642428d25dd 100644 --- a/rust/worker/src/execution/operators/get_collection_and_segments.rs +++ b/rust/worker/src/execution/operators/get_collection_and_segments.rs @@ -6,25 +6,32 @@ use chroma_types::{CollectionAndSegments, CollectionUuid, GetCollectionWithSegme use thiserror::Error; /// The `GetCollectionAndSegmentsOperator` fetches a consistent snapshot of collection and segment information +/// for both input and output collections (which may be the same for regular compaction). /// /// # Parameters /// - `sysdb`: The sysdb client -/// - `collection_id`: The id for the collection to be fetched +/// - `input_collection_id`: The id for the input collection to be fetched +/// - `output_collection_id`: The id for the output collection to be fetched /// /// # Inputs /// - No input is required /// /// # Outputs -/// - The collection and segments information. If not found, an error will be thrown +/// - The input and output collection and segments information. If not found, an error will be thrown #[derive(Clone, Debug)] pub struct GetCollectionAndSegmentsOperator { pub sysdb: SysDb, - pub collection_id: CollectionUuid, + pub input_collection_id: CollectionUuid, + pub output_collection_id: CollectionUuid, } type GetCollectionAndSegmentsInput = (); -pub type GetCollectionAndSegmentsOutput = CollectionAndSegments; +#[derive(Clone, Debug)] +pub struct GetCollectionAndSegmentsOutput { + pub input: CollectionAndSegments, + pub output: CollectionAndSegments, +} #[derive(Debug, Error)] pub enum GetCollectionAndSegmentsError { @@ -60,14 +67,29 @@ impl Operator _: &GetCollectionAndSegmentsInput, ) -> Result { tracing::trace!( - "[{}]: Collection ID {}", + "[{}]: Fetching input collection {} and output collection {}", self.get_name(), - self.collection_id.0 + self.input_collection_id.0, + self.output_collection_id.0 ); - Ok(self - .sysdb - .clone() - .get_collection_with_segments(self.collection_id) - .await?) + + let mut sysdb = self.sysdb.clone(); + + // Fetch input collection and segments + let input = sysdb + .get_collection_with_segments(self.input_collection_id) + .await?; + + // Fetch output collection and segments + // If input and output are the same collection, clone instead of fetching twice + let output = if self.input_collection_id == self.output_collection_id { + input.clone() + } else { + sysdb + .get_collection_with_segments(self.output_collection_id) + .await? 
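Aside (not part of the patch): the recheck in FinishTaskOperator boils down to one comparison — how many log records appeared past the completion offset while the task was running, and whether that count clears min_records_for_task. A minimal sketch of that decision follows; the real operator obtains next_log_offset from Log::scout_logs, and the function name is illustrative.

fn needs_follow_up_run(
    next_log_offset: u64,
    completion_offset: u64,
    min_records_for_task: u64,
) -> bool {
    // saturating_sub guards against the log service reporting an offset
    // behind what has already been processed.
    let new_records = next_log_offset.saturating_sub(completion_offset);
    new_records >= min_records_for_task
}

fn main() {
    // 7 new records arrived during execution, threshold is 5: schedule again.
    assert!(needs_follow_up_run(107, 100, 5));
    // Only 2 new records: let the task go back to waiting.
    assert!(!needs_follow_up_run(102, 100, 5));
}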
+ }; + + Ok(GetCollectionAndSegmentsOutput { input, output }) } } diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs index dd325345b4d..985efe970a0 100644 --- a/rust/worker/src/execution/operators/mod.rs +++ b/rust/worker/src/execution/operators/mod.rs @@ -1,14 +1,17 @@ pub mod apply_log_to_segment_writer; pub mod commit_segment_writer; -pub(super) mod count_records; +pub mod count_records; +pub mod execute_task; +pub mod fetch_log; +pub(super) mod finish_task; pub mod flush_segment_writer; pub mod materialize_logs; +pub(super) mod prepare_task; pub(super) mod register; pub mod spann_bf_pl; pub(super) mod spann_centers_search; pub(super) mod spann_fetch_pl; -pub mod fetch_log; pub mod filter; pub mod get_collection_and_segments; pub mod idf; diff --git a/rust/worker/src/execution/operators/prepare_task.rs b/rust/worker/src/execution/operators/prepare_task.rs new file mode 100644 index 00000000000..05052f5e667 --- /dev/null +++ b/rust/worker/src/execution/operators/prepare_task.rs @@ -0,0 +1,257 @@ +use async_trait::async_trait; +use chroma_error::ChromaError; +use chroma_log::Log; +use chroma_sysdb::{CreateOutputCollectionForTaskError, GetTaskError, SysDb}; +use chroma_system::{Operator, OperatorType}; +use chroma_types::{AdvanceTaskError, CollectionUuid, Task}; +use thiserror::Error; + +/// The `PrepareTaskOperator` prepares a task execution by: +/// 1. Fetching the latest task state from SysDB using task_uuid +/// 2. Asserting that the input nonce matches next_nonce or lowest_live_nonce +/// 3. Determining state transition (waiting→scheduled or already scheduled) +/// 4. If transitioning to scheduled, call advance_task and scout_logs to check if we should skip +/// 5. Creating the output collection if it doesn't exist +/// +/// # Parameters +/// - `sysdb`: The sysdb client +/// - `log`: The log client for scout_logs +/// - `task_uuid`: The UUID of the task +/// +/// # Inputs +/// - `nonce`: The invocation nonce from the scheduler +/// +/// # Outputs +/// - The task object with updated state, execution_nonce, and a flag indicating whether to skip execution +#[derive(Clone, Debug)] +pub struct PrepareTaskOperator { + pub sysdb: SysDb, + pub log: Log, + pub task_uuid: chroma_types::TaskUuid, +} + +#[derive(Clone, Debug)] +pub struct PrepareTaskInput { + pub nonce: chroma_types::NonceUuid, +} + +#[derive(Clone, Debug)] +pub struct PrepareTaskOutput { + /// The task object fetched from SysDB + pub task: Task, + /// The nonce to use for this task execution + pub execution_nonce: chroma_types::NonceUuid, + /// If true, skip execution and go directly to FinishTask + /// This happens when there aren't enough new records to process + pub should_skip_execution: bool, + /// The output collection ID (created if it didn't exist) + pub output_collection_id: CollectionUuid, +} + +#[derive(Debug, Error)] +pub enum PrepareTaskError { + #[error("Task not found in SysDB")] + TaskNotFound, + #[error("Failed to get task: {0}")] + GetTask(#[from] GetTaskError), + #[error("Failed to create output collection for task: {0}")] + CreateOutputCollectionForTask(#[from] CreateOutputCollectionForTaskError), + #[error("Invalid nonce: provided={provided}, expected next={expected_next} or lowest={expected_lowest}")] + InvalidNonce { + provided: chroma_types::NonceUuid, + expected_next: chroma_types::NonceUuid, + expected_lowest: chroma_types::NonceUuid, + }, + #[error("Failed to advance task: {0}")] + AdvanceTask(#[from] AdvanceTaskError), + #[error("Failed to scout logs: 
{0}")] + ScoutLogsError(String), +} + +impl ChromaError for PrepareTaskError { + fn code(&self) -> chroma_error::ErrorCodes { + match self { + PrepareTaskError::TaskNotFound => chroma_error::ErrorCodes::NotFound, + PrepareTaskError::GetTask(e) => e.code(), + PrepareTaskError::CreateOutputCollectionForTask(e) => e.code(), + PrepareTaskError::InvalidNonce { .. } => chroma_error::ErrorCodes::InvalidArgument, + PrepareTaskError::AdvanceTask(e) => e.code(), + PrepareTaskError::ScoutLogsError(_) => chroma_error::ErrorCodes::Internal, + } + } +} + +#[async_trait] +impl Operator for PrepareTaskOperator { + type Error = PrepareTaskError; + + fn get_type(&self) -> OperatorType { + OperatorType::IO + } + + async fn run(&self, input: &PrepareTaskInput) -> Result { + tracing::info!( + "[{}]: Preparing task {} with nonce {}", + self.get_name(), + self.task_uuid.0, + input.nonce + ); + + let mut sysdb = self.sysdb.clone(); + let mut log = self.log.clone(); + + // 1. Fetch the task from SysDB using UUID + let mut task = sysdb + .get_task_by_uuid(self.task_uuid) + .await + .map_err(|e| match e { + GetTaskError::NotFound => PrepareTaskError::TaskNotFound, + other => PrepareTaskError::GetTask(other), + })?; + + tracing::debug!( + "[{}]: Retrieved task {} - next_nonce={}, lowest_live_nonce={:?}", + self.get_name(), + task.name, + task.next_nonce, + task.lowest_live_nonce + ); + + // 2. ASSERT: nonce must match either next_nonce or lowest_live_nonce + let matches_lowest = task.lowest_live_nonce == Some(input.nonce); + if input.nonce != task.next_nonce && !matches_lowest { + tracing::error!( + "[{}]: Invalid nonce for task {} - provided={}, expected next={} or lowest={:?}", + self.get_name(), + task.name, + input.nonce, + task.next_nonce, + task.lowest_live_nonce + ); + return Err(PrepareTaskError::InvalidNonce { + provided: input.nonce, + expected_next: task.next_nonce, + expected_lowest: task.lowest_live_nonce.unwrap_or_default(), + }); + } + + // 3. 
Determine state transition and whether to skip execution + let execution_nonce = input.nonce; + let mut should_skip_execution = false; + + if task + .lowest_live_nonce + .is_some_and(|lln| task.next_nonce != lln) + { + // Incomplete nonce exists - we are already **scheduled** + tracing::info!( + "[{}]: Task {} already in scheduled state (incomplete nonce exists)", + self.get_name(), + task.name + ); + + // Scout logs to see if we should skip execution (task may have already executed) + let next_log_offset = log + .scout_logs( + &task.tenant_id, + task.input_collection_id, + task.completion_offset, + ) + .await + .map_err(|e| { + tracing::error!( + "[{}]: Failed to scout logs for task {}: {}", + self.get_name(), + task.name, + e + ); + PrepareTaskError::ScoutLogsError(format!("Failed to scout logs: {}", e)) + })?; + + let new_records_count = next_log_offset.saturating_sub(task.completion_offset); + should_skip_execution = new_records_count < task.min_records_for_task; + + if should_skip_execution { + tracing::info!( + "[{}]: Skipping execution for task {} - not enough new records (new={}, min={})", + self.get_name(), + task.name, + new_records_count, + task.min_records_for_task + ); + } else { + tracing::info!( + "[{}]: Task {} will proceed with execution ({} new records available)", + self.get_name(), + task.name, + new_records_count + ); + } + } else { + // Currently **waiting**, transition to **scheduled** + tracing::info!( + "[{}]: Task {} transitioning from waiting to scheduled", + self.get_name(), + task.name + ); + + // Call advance_task to increment next_nonce and set next_run (with nonce check for concurrency safety) + // Set next_run to some reasonable delay (e.g., 60 seconds) since we're starting work + const DEFAULT_THROTTLE_INTERVAL_SECS: u64 = 60; + let advance_response = sysdb + .advance_task( + task.id, + input.nonce.0, + task.completion_offset as i64, + DEFAULT_THROTTLE_INTERVAL_SECS, // Set next_run since we're advancing nonce + ) + .await?; + + tracing::debug!( + "[{}]: Advanced task {} - new next_nonce={}", + self.get_name(), + task.name, + advance_response.next_nonce + ); + + // Update task with the new nonce values + task.next_nonce = chroma_types::NonceUuid(advance_response.next_nonce); + task.next_run = advance_response.next_run; + } + + // 4. 
Create output collection if it doesn't exist + let output_collection_id = if let Some(output_id) = task.output_collection_id { + // Output collection already exists + output_id + } else { + // Create new output collection atomically with task update + tracing::info!( + "[{}]: Creating output collection '{}' for task {}", + self.get_name(), + task.output_collection_name, + task.name + ); + + let collection_id = sysdb + .create_output_collection_for_task( + task.id, + task.output_collection_name.clone(), + task.tenant_id.clone(), + task.database_id.clone(), + ) + .await?; + + // Update local task object with the new output collection ID + task.output_collection_id = Some(collection_id); + + collection_id + }; + + Ok(PrepareTaskOutput { + task: task.clone(), + execution_nonce, + should_skip_execution, + output_collection_id, + }) + } +} diff --git a/rust/worker/src/execution/operators/register.rs b/rust/worker/src/execution/operators/register.rs index 1fe41c37442..ee192da9476 100644 --- a/rust/worker/src/execution/operators/register.rs +++ b/rust/worker/src/execution/operators/register.rs @@ -9,6 +9,8 @@ use chroma_types::{CollectionUuid, FlushCompactionResponse, SegmentFlushInfo}; use std::sync::Arc; use thiserror::Error; +use crate::execution::orchestration::TaskContext; + /// The register operator is responsible for flushing compaction data to the sysdb /// as well as updating the log offset in the log service. #[derive(Debug)] @@ -48,6 +50,7 @@ pub struct RegisterInput { sysdb: SysDb, log: Log, schema: Option, + task_context: Option, } impl RegisterInput { @@ -64,6 +67,7 @@ impl RegisterInput { sysdb: SysDb, log: Log, schema: Option, + task_context: Option, ) -> Self { RegisterInput { tenant, @@ -76,6 +80,7 @@ impl RegisterInput { sysdb, log, schema, + task_context, } } } @@ -83,9 +88,11 @@ impl RegisterInput { /// The output for the flush sysdb operator. /// # Parameters /// * `result` - The result of the flush compaction operation. +/// * `updated_task` - The updated task if this was a task-based compaction. 
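Aside (not part of the patch): PrepareTask's nonce handling above reduces to a small decision table — the provided nonce must equal either next_nonce or lowest_live_nonce, and the task counts as already scheduled (log recheck, possibly skipping execution) when an incomplete epoch exists, i.e. lowest_live_nonce differs from next_nonce; otherwise the task is waiting and is advanced. A self-contained sketch with stand-in types:

// Illustrative stand-ins for the nonce bookkeeping used by PrepareTask.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Nonce(u128);

#[derive(Debug, PartialEq)]
enum PrepareDecision {
    /// Waiting -> scheduled: advance the nonce and run the task body.
    AdvanceAndExecute,
    /// An incomplete epoch exists: scout the logs and possibly skip
    /// straight to finish_task if too few new records arrived.
    AlreadyScheduled,
    /// The nonce matches neither live nonce: reject the invocation.
    InvalidNonce,
}

fn prepare(provided: Nonce, next_nonce: Nonce, lowest_live_nonce: Option<Nonce>) -> PrepareDecision {
    let matches_lowest = lowest_live_nonce == Some(provided);
    if provided != next_nonce && !matches_lowest {
        return PrepareDecision::InvalidNonce;
    }
    match lowest_live_nonce {
        Some(lowest) if lowest != next_nonce => PrepareDecision::AlreadyScheduled,
        _ => PrepareDecision::AdvanceAndExecute,
    }
}

fn main() {
    let next = Nonce(2);
    let lowest = Nonce(1);
    assert_eq!(prepare(Nonce(3), next, Some(lowest)), PrepareDecision::InvalidNonce);
    assert_eq!(prepare(lowest, next, Some(lowest)), PrepareDecision::AlreadyScheduled);
    assert_eq!(prepare(next, next, Some(next)), PrepareDecision::AdvanceAndExecute);
}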
#[derive(Debug)] pub struct RegisterOutput { _sysdb_registration_result: FlushCompactionResponse, + pub updated_task: Option, } #[derive(Error, Debug)] @@ -94,6 +101,8 @@ pub enum RegisterError { FlushCompactionError(#[from] FlushCompactionError), #[error("Update log offset error: {0}")] UpdateLogOffsetError(#[from] Box), + #[error("Generic error: {0}")] + Generic(String), } impl ChromaError for RegisterError { @@ -101,6 +110,7 @@ impl ChromaError for RegisterError { match self { RegisterError::FlushCompactionError(e) => e.code(), RegisterError::UpdateLogOffsetError(e) => e.code(), + RegisterError::Generic(_) => ErrorCodes::FailedPrecondition, } } @@ -108,6 +118,7 @@ impl ChromaError for RegisterError { match self { RegisterError::FlushCompactionError(e) => e.should_trace_error(), RegisterError::UpdateLogOffsetError(e) => e.should_trace_error(), + RegisterError::Generic(_) => true, } } } @@ -122,37 +133,96 @@ impl Operator for RegisterOperator { async fn run(&self, input: &RegisterInput) -> Result { let mut sysdb = input.sysdb.clone(); - let mut log = input.log.clone(); - let result = sysdb - .flush_compaction( - input.tenant.clone(), - input.collection_id, - input.log_position, - input.collection_version, - input.segment_flush_info.clone(), - input.total_records_post_compaction, - input.collection_logical_size_bytes, - input.schema.clone(), - ) - .await; - - // We must make sure that the log postion in sysdb is always greater than or equal to the log position - // in the log service. If the log position in sysdb is less than the log position in the log service, - // the we may lose data in compaction. - let sysdb_registration_result = match result { - Ok(response) => response, - Err(error) => return Err(RegisterError::FlushCompactionError(error)), - }; - let result = log - .update_collection_log_offset(&input.tenant, input.collection_id, input.log_position) - .await; - - match result { - Ok(_) => Ok(RegisterOutput { - _sysdb_registration_result: sysdb_registration_result, - }), - Err(error) => Err(RegisterError::UpdateLogOffsetError(error)), + // Handle task-based vs non-task compactions separately + match &input.task_context { + Some(task_context) => { + // Extract the task - it must be present by the time we reach RegisterOperator + let task = task_context.task.as_ref().ok_or_else(|| { + RegisterError::Generic( + "Task context present but task not populated - PrepareTask should have run first" + .to_string(), + ) + })?; + + const DEFAULT_THROTTLE_INTERVAL_SECS: u64 = 60; + // log_position is "up to which offset we've compacted" + // completion_offset is "last offset processed" + // In practice, log_position means "next offset to start compacting from" + // So to get "last offset processed", we subtract 1 + let last_offset_processed = if input.log_position > 0 { + input.log_position - 1 + } else { + input.log_position // Keep as-is if 0 or negative + }; + let task_update = chroma_types::TaskUpdateInfo { + task_id: task.id, + task_run_nonce: task_context.execution_nonce.0, // Use execution_nonce from context + completion_offset: last_offset_processed, + next_run_delay_secs: DEFAULT_THROTTLE_INTERVAL_SECS, + }; + // Task-based compaction + let task_response = sysdb + .flush_compaction_and_task( + input.tenant.clone(), + input.collection_id, + input.log_position, + input.collection_version, + input.segment_flush_info.clone(), + input.total_records_post_compaction, + input.collection_logical_size_bytes, + input.schema.clone(), + task_update, + ) + .await + 
.map_err(RegisterError::FlushCompactionError)?; + + // Create updated task with authoritative database values + let mut updated_task = task.clone(); + updated_task.completion_offset = task_response.completion_offset; + // Note: next_run and next_nonce were already set by PrepareTask via advance_task() + // flush_compaction_and_task only updates completion_offset + + Ok(RegisterOutput { + _sysdb_registration_result: chroma_types::FlushCompactionResponse { + collection_id: task_response.collection_id, + collection_version: task_response.collection_version, + last_compaction_time: task_response.last_compaction_time, + }, + updated_task: Some(updated_task), + }) + } + None => { + // Non-task compaction + let mut log = input.log.clone(); + let response = sysdb + .flush_compaction( + input.tenant.clone(), + input.collection_id, + input.log_position, + input.collection_version, + input.segment_flush_info.clone(), + input.total_records_post_compaction, + input.collection_logical_size_bytes, + input.schema.clone(), + ) + .await + .map_err(RegisterError::FlushCompactionError)?; + + // Update log offset + log.update_collection_log_offset( + &input.tenant, + input.collection_id, + input.log_position, + ) + .await + .map_err(RegisterError::UpdateLogOffsetError)?; + + Ok(RegisterOutput { + _sysdb_registration_result: response, + updated_task: None, + }) + } } } } @@ -269,7 +339,8 @@ mod tests { size_bytes_post_compaction, sysdb.clone(), log.clone(), - None, + None, // schema + None, // task_context ); let result = operator.run(&input).await; diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index d385ea3640b..c7a0e824121 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -29,14 +29,15 @@ use chroma_system::{ OrchestratorContext, PanicError, TaskError, TaskMessage, TaskResult, }; use chroma_types::{ - Chunk, Collection, CollectionUuid, InternalSchema, LogRecord, SchemaError, SegmentFlushInfo, - SegmentType, SegmentUuid, + Chunk, Collection, CollectionUuid, InternalSchema, LogRecord, NonceUuid, SchemaError, Segment, + SegmentFlushInfo, SegmentType, SegmentUuid, Task, TaskUuid, }; use opentelemetry::trace::TraceContextExt; use thiserror::Error; use tokio::sync::oneshot::{error::RecvError, Sender}; use tracing::Span; use tracing_opentelemetry::OpenTelemetrySpanExt; +use uuid::Uuid; use crate::execution::operators::{ apply_log_to_segment_writer::{ @@ -47,7 +48,11 @@ use crate::execution::operators::{ CommitSegmentWriterInput, CommitSegmentWriterOperator, CommitSegmentWriterOperatorError, CommitSegmentWriterOutput, }, + execute_task::{ + CountTask, ExecuteTaskError, ExecuteTaskInput, ExecuteTaskOperator, ExecuteTaskOutput, + }, fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput}, + finish_task::{FinishTaskError, FinishTaskInput, FinishTaskOperator, FinishTaskOutput}, flush_segment_writer::{ FlushSegmentWriterInput, FlushSegmentWriterOperator, FlushSegmentWriterOperatorError, FlushSegmentWriterOutput, @@ -64,6 +69,7 @@ use crate::execution::operators::{ prefetch_segment::{ PrefetchSegmentError, PrefetchSegmentInput, PrefetchSegmentOperator, PrefetchSegmentOutput, }, + prepare_task::{PrepareTaskError, PrepareTaskInput, PrepareTaskOperator, PrepareTaskOutput}, register::{RegisterError, RegisterInput, RegisterOperator, RegisterOutput}, source_record_segment::{ SourceRecordSegmentError, SourceRecordSegmentInput, SourceRecordSegmentOperator, @@ -115,6 +121,14 @@ enum 
ExecutionState { Partition, MaterializeApplyCommitFlush, Register, + FinishTask, +} + +#[derive(Clone, Debug)] +pub(crate) struct TaskContext { + pub(crate) task_id: TaskUuid, + pub(crate) task: Option, + pub(crate) execution_nonce: NonceUuid, } #[derive(Clone, Debug)] @@ -127,14 +141,14 @@ pub(crate) struct CompactWriters { #[derive(Debug)] pub struct CompactOrchestrator { - collection_id: CollectionUuid, + // === Compaction Configuration === hnsw_index_uuid: Option, rebuild: bool, fetch_log_batch_size: u32, max_compaction_size: usize, max_partition_size: usize, - // Dependencies + // === Shared Services & Providers === context: OrchestratorContext, blockfile_provider: BlockfileProvider, log: Log, @@ -142,15 +156,31 @@ pub struct CompactOrchestrator { hnsw_provider: HnswIndexProvider, spann_provider: SpannProvider, - collection: OnceCell, + // === Input Collection (read logs/segments from) === + /// Collection to read logs and segments from + /// For regular compaction: input_collection_id == output_collection_id + /// For task compaction: input_collection_id != output_collection_id + input_collection_id: CollectionUuid, + input_collection: OnceCell, + input_segments: OnceCell>, + /// How much to pull from fetch_logs for INPUT collection + pulled_log_offset: i64, + + // === Output Collection (write compacted data to) === + /// Collection to write compacted segments to + output_collection_id: OnceCell, + output_collection: OnceCell, + output_segments: OnceCell>, + + // === Writers & Results === writers: OnceCell, flush_results: Vec, result_channel: Option>>, + + // === State Tracking === num_uncompleted_materialization_tasks: usize, num_uncompleted_tasks_by_segment: HashMap, collection_logical_size_delta_bytes: i64, - // How much to pull from fetch_logs - pulled_log_offset: i64, state: ExecutionState, // Total number of records in the collection after the compaction @@ -166,6 +196,9 @@ pub struct CompactOrchestrator { // schema after applying deltas schema: Option, + // === Task Context (optional) === + /// Available if this orchestrator is for a task + task_context: Option, } #[derive(Error, Debug)] @@ -178,8 +211,12 @@ pub enum CompactionError { Channel(#[from] ChannelError), #[error("Error commiting segment writers: {0}")] Commit(#[from] CommitSegmentWriterOperatorError), + #[error("Error executing task: {0}")] + ExecuteTask(#[from] ExecuteTaskError), #[error("Error fetching logs: {0}")] FetchLog(#[from] FetchLogError), + #[error("Error finishing task: {0}")] + FinishTask(#[from] FinishTaskError), #[error("Error flushing segment writers: {0}")] Flush(#[from] FlushSegmentWriterOperatorError), #[error("Error getting collection and segments: {0}")] @@ -198,6 +235,8 @@ pub enum CompactionError { Partition(#[from] PartitionError), #[error("Error prefetching segment: {0}")] PrefetchSegment(#[from] PrefetchSegmentError), + #[error("Error preparing task: {0}")] + PrepareTask(#[from] PrepareTaskError), #[error("Error creating record segment reader: {0}")] RecordSegmentReader(#[from] RecordSegmentReaderCreationError), #[error("Error creating record segment writer: {0}")] @@ -244,7 +283,9 @@ impl ChromaError for CompactionError { Self::ApplyLog(e) => e.should_trace_error(), Self::Channel(e) => e.should_trace_error(), Self::Commit(e) => e.should_trace_error(), + Self::ExecuteTask(e) => e.should_trace_error(), Self::FetchLog(e) => e.should_trace_error(), + Self::FinishTask(e) => e.should_trace_error(), Self::Flush(e) => e.should_trace_error(), Self::GetCollectionAndSegments(e) => 
e.should_trace_error(), Self::HnswSegment(e) => e.should_trace_error(), @@ -254,6 +295,7 @@ impl ChromaError for CompactionError { Self::Panic(e) => e.should_trace_error(), Self::Partition(e) => e.should_trace_error(), Self::PrefetchSegment(e) => e.should_trace_error(), + Self::PrepareTask(e) => e.should_trace_error(), Self::RecordSegmentReader(e) => e.should_trace_error(), Self::RecordSegmentWriter(e) => e.should_trace_error(), Self::Register(e) => e.should_trace_error(), @@ -269,7 +311,7 @@ impl ChromaError for CompactionError { #[derive(Debug)] pub enum CompactionResponse { Success { - collection_id: CollectionUuid, + job_id: Uuid, }, RequireCompactionOffsetRepair { collection_id: CollectionUuid, @@ -280,7 +322,7 @@ pub enum CompactionResponse { impl CompactOrchestrator { #[allow(clippy::too_many_arguments)] pub fn new( - collection_id: CollectionUuid, + input_collection_id: CollectionUuid, rebuild: bool, fetch_log_batch_size: u32, max_compaction_size: usize, @@ -294,8 +336,9 @@ impl CompactOrchestrator { result_channel: Option>>, ) -> Self { let context = OrchestratorContext::new(dispatcher); + let output_collection_cell = OnceCell::new(); + output_collection_cell.set(input_collection_id).unwrap(); CompactOrchestrator { - collection_id, hnsw_index_uuid: None, rebuild, fetch_log_batch_size, @@ -307,29 +350,125 @@ impl CompactOrchestrator { sysdb, hnsw_provider, spann_provider, - collection: OnceCell::new(), + input_collection_id, + input_collection: OnceCell::new(), + input_segments: OnceCell::new(), + pulled_log_offset: 0, + output_collection_id: output_collection_cell, + output_collection: OnceCell::new(), + output_segments: OnceCell::new(), writers: OnceCell::new(), flush_results: Vec::new(), result_channel, num_uncompleted_materialization_tasks: 0, num_uncompleted_tasks_by_segment: HashMap::new(), collection_logical_size_delta_bytes: 0, - pulled_log_offset: 0, state: ExecutionState::Pending, total_records_post_compaction: 0, num_materialized_logs: 0, segment_spans: HashMap::new(), metrics: CompactOrchestratorMetrics::default(), schema: None, + task_context: None, } } + #[allow(clippy::too_many_arguments)] + pub fn new_for_task( + input_collection_id: CollectionUuid, + rebuild: bool, + fetch_log_batch_size: u32, + max_compaction_size: usize, + max_partition_size: usize, + log: Log, + sysdb: SysDb, + blockfile_provider: BlockfileProvider, + hnsw_provider: HnswIndexProvider, + spann_provider: SpannProvider, + dispatcher: ComponentHandle, + result_channel: Option>>, + task_uuid: TaskUuid, + execution_nonce: NonceUuid, + ) -> Self { + let mut orchestrator = CompactOrchestrator::new( + input_collection_id, + rebuild, + fetch_log_batch_size, + max_compaction_size, + max_partition_size, + log, + sysdb, + blockfile_provider, + hnsw_provider, + spann_provider, + dispatcher, + result_channel, + ); + orchestrator.task_context = Some(TaskContext { + task_id: task_uuid, + task: None, + execution_nonce, + }); + orchestrator + } + async fn try_purge_hnsw(path: &Path, hnsw_index_uuid: Option) { if let Some(hnsw_index_uuid) = hnsw_index_uuid { let _ = HnswIndexProvider::purge_one_id(path, hnsw_index_uuid).await; } } + async fn do_task( + &mut self, + log_records: Chunk, + ctx: &ComponentContext, + ) { + let task_context = self.task_context.as_ref().unwrap(); + let task = task_context + .task + .as_ref() + .expect("Task should be populated by PrepareTask"); + + let output_collection = self + .output_collection + .get() + .expect("Output collection should be set"); + let output_segments = self + 
.output_segments + .get() + .expect("Output segments should be set"); + let output_record_segment = output_segments + .iter() + .find(|s| s.r#type == SegmentType::BlockfileRecord) + .expect("Output record segment should exist"); + + // TODO: Get the actual task executor based on operator_id + // For now, hardcode CountTask as a placeholder + let task_executor = Arc::new(CountTask); + + let execute_task_op = ExecuteTaskOperator { + log_client: self.log.clone(), + task_executor, + }; + + let execute_task_input = ExecuteTaskInput { + log_records, + tenant_id: output_collection.tenant.clone(), + output_collection_id: self.output_collection_id.get().unwrap().to_string(), + completion_offset: task.completion_offset, + output_record_segment: output_record_segment.clone(), + blockfile_provider: self.blockfile_provider.clone(), + }; + + let task_msg = wrap( + Box::new(execute_task_op), + execute_task_input, + ctx.receiver(), + self.context.task_cancellation_token.clone(), + ); + self.send(task_msg, ctx, Some(Span::current())).await; + } + async fn partition( &mut self, records: Chunk, @@ -453,7 +592,7 @@ impl CompactOrchestrator { writer, materialized_logs.clone(), writers.record_reader.clone(), - self.collection.get().and_then(|c| c.schema.clone()), + self.output_collection.get().and_then(|c| c.schema.clone()), ); let task = wrap( operator, @@ -537,12 +676,13 @@ impl CompactOrchestrator { .add(self.num_materialized_logs, &[]); self.state = ExecutionState::Register; + // Register uses OUTPUT collection let collection_cell = - self.collection + self.output_collection .get() .cloned() .ok_or(CompactionError::InvariantViolation( - "Collection information should have been obtained", + "Output collection information should have been obtained", )); let collection = match self.ok_or_terminate(collection_cell, ctx).await { Some(collection) => collection, @@ -579,6 +719,7 @@ impl CompactOrchestrator { self.sysdb.clone(), self.log.clone(), self.schema.clone(), + self.task_context.clone(), ); let task = wrap( @@ -671,11 +812,33 @@ impl Orchestrator for CompactOrchestrator { &mut self, ctx: &ComponentContext, ) -> Vec<(TaskMessage, Option)> { + // For task-based compaction, start with PrepareTask to fetch the task + if self.task_context.is_some() { + let task_context = self.task_context.as_ref().unwrap(); + return vec![( + wrap( + Box::new(PrepareTaskOperator { + sysdb: self.sysdb.clone(), + log: self.log.clone(), + task_uuid: task_context.task_id, + }), + PrepareTaskInput { + nonce: task_context.execution_nonce, + }, + ctx.receiver(), + self.context.task_cancellation_token.clone(), + ), + Some(Span::current()), + )]; + } + + // For non-task compaction, start with GetCollectionAndSegments vec![( wrap( Box::new(GetCollectionAndSegmentsOperator { sysdb: self.sysdb.clone(), - collection_id: self.collection_id, + input_collection_id: self.input_collection_id, + output_collection_id: *self.output_collection_id.get().unwrap(), }), (), ctx.receiver(), @@ -705,6 +868,59 @@ impl Orchestrator for CompactOrchestrator { } // ============== Handlers ============== +#[async_trait] +impl Handler> for CompactOrchestrator { + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let output = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(output) => output, + None => return, + }; + + tracing::info!( + "[CompactOrchestrator] PrepareTask completed, task_id={}, execution_nonce={}", + output.task.id.0, + output.execution_nonce + ); + + // Store the task 
and execution_nonce in task_context + let task_context = self.task_context.as_mut().unwrap(); + task_context.task = Some(output.task.clone()); + task_context.execution_nonce = output.execution_nonce; + self.output_collection_id = output.output_collection_id.into(); + + if output.should_skip_execution { + // Proceed to FinishTask + let task = wrap( + FinishTaskOperator::new(self.log.clone(), self.sysdb.clone()), + FinishTaskInput::new(output.task), + ctx.receiver(), + self.context.task_cancellation_token.clone(), + ); + self.send(task, ctx, Some(Span::current())).await; + return; + } + + // Proceed to GetCollectionAndSegments + let task = wrap( + Box::new(GetCollectionAndSegmentsOperator { + sysdb: self.sysdb.clone(), + input_collection_id: self.input_collection_id, + output_collection_id: *self.output_collection_id.get().unwrap(), + }), + (), + ctx.receiver(), + self.context.task_cancellation_token.clone(), + ); + self.send(task, ctx, Some(Span::current())).await; + } +} + #[async_trait] impl Handler> for CompactOrchestrator @@ -721,26 +937,96 @@ impl Handler return, }; - let collection = output.collection.clone(); - if self.collection.set(collection.clone()).is_err() { + // Store input collection and segments + let mut input_collection = output.input.collection.clone(); + if self.input_collection.set(input_collection.clone()).is_err() { self.terminate_with_result( Err(CompactionError::InvariantViolation( - "Collection information should not have been initialized", + "Input collection information should not have been initialized", )), ctx, ) .await; return; - }; + } + self.schema = input_collection.schema.clone(); + // Create input segments vec from individual segment fields + let input_segments = vec![ + output.input.metadata_segment.clone(), + output.input.record_segment.clone(), + output.input.vector_segment.clone(), + ]; + if self.input_segments.set(input_segments).is_err() { + self.terminate_with_result( + Err(CompactionError::InvariantViolation( + "Input segments should not have been initialized", + )), + ctx, + ) + .await; + return; + } + + // Store output collection + let output_collection = output.output.collection.clone(); + if self + .output_collection + .set(output_collection.clone()) + .is_err() + { + self.terminate_with_result( + Err(CompactionError::InvariantViolation( + "Output collection information should not have been initialized", + )), + ctx, + ) + .await; + return; + } - self.schema = collection.schema.clone(); + // Create output segments vec from individual segment fields + let output_segments = vec![ + output.output.metadata_segment.clone(), + output.output.record_segment.clone(), + output.output.vector_segment.clone(), + ]; + if self.output_segments.set(output_segments).is_err() { + self.terminate_with_result( + Err(CompactionError::InvariantViolation( + "Output segments should not have been initialized", + )), + ctx, + ) + .await; + return; + } + + // TODO move this somewhere cleaner + if self.task_context.is_some() { + let task_completion_offset = as Clone>::clone( + &self.task_context.as_ref().unwrap().task, + ) + .unwrap() + .completion_offset; + // completion_offset = u64::MAX (-1 in Go) means "no records processed yet" + // Keep it as -1 so FetchLog fetches from offset 0 + // Otherwise, completion_offset is "last offset processed" + // which matches the semantics of log_position + input_collection.log_position = if task_completion_offset == u64::MAX { + -1 + } else { + task_completion_offset as i64 + }; + } - self.pulled_log_offset = collection.log_position; 
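Aside (not part of the patch): the input/output collection wiring above works in two steps — the constructor pre-seeds output_collection_id with the input collection id (so a plain compaction reads and writes the same collection), and the PrepareTask handler replaces the whole cell with the task's output collection once it is known, since a OnceCell cannot be set twice. The sketch below mirrors that pattern with invented types and std::cell::OnceCell.

use std::cell::OnceCell;

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct CollectionId(u32);

struct OrchestratorSketch {
    input_collection_id: CollectionId,
    output_collection_id: OnceCell<CollectionId>,
}

impl OrchestratorSketch {
    fn new(input: CollectionId) -> Self {
        let output_collection_id = OnceCell::new();
        // Pre-seeded with the input id, as in CompactOrchestrator::new.
        output_collection_id.set(input).unwrap();
        Self { input_collection_id: input, output_collection_id }
    }

    fn resolve_task_output(&mut self, output: CollectionId) {
        // Mirrors the PrepareTask handler: swap in a fresh cell holding the
        // task's output collection, because the seeded cell cannot be re-set.
        let cell = OnceCell::new();
        let _ = cell.set(output);
        self.output_collection_id = cell;
    }
}

fn main() {
    // Plain compaction: output is the same collection as the input.
    let regular = OrchestratorSketch::new(CollectionId(7));
    assert_eq!(regular.output_collection_id.get(), Some(&CollectionId(7)));

    // Task compaction: PrepareTask later swaps in the task's output collection.
    let mut task_run = OrchestratorSketch::new(CollectionId(7));
    task_run.resolve_task_output(CollectionId(42));
    assert_eq!(task_run.output_collection_id.get(), Some(&CollectionId(42)));
}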
+ // Set pulled_log_offset from INPUT collection's log position + self.pulled_log_offset = input_collection.log_position; - let record_reader = match self + // Create record reader from INPUT segments (for reading existing data) + let input_record_reader = match self .ok_or_terminate( match Box::pin(RecordSegmentReader::from_segment( - &output.record_segment, + &output.input.record_segment, &self.blockfile_provider, )) .await @@ -759,11 +1045,11 @@ impl Handler return, }; - let log_task = match self.rebuild { + let log_task = match self.rebuild || self.task_context.is_some() { true => wrap( Box::new(SourceRecordSegmentOperator {}), SourceRecordSegmentInput { - record_segment_reader: record_reader.clone(), + record_segment_reader: input_record_reader.clone(), }, ctx.receiver(), self.context.task_cancellation_token.clone(), @@ -772,12 +1058,12 @@ impl Handler dim as usize, None => { // Collection is not yet initialized, there is no need to initialize the writers @@ -795,9 +1082,10 @@ impl Handler match self .ok_or_terminate( self.spann_provider - .write(&collection, &vector_segment, dimension) + .write(&output_collection, &vector_segment, dimension) .await, ctx, ) @@ -858,7 +1146,7 @@ impl Handler match self .ok_or_terminate( DistributedHNSWSegmentWriter::from_segment( - &collection, + &output_collection, &vector_segment, dimension, self.hnsw_provider.clone(), @@ -878,8 +1166,33 @@ impl Handler Ok(Some(reader)), + Err(err) => match *err { + RecordSegmentReaderCreationError::UninitializedSegment => Ok(None), + _ => Err(*err), + }, + }, + ctx, + ) + .await + { + Some(reader) => reader, + None => return, + }; + } let writers = CompactWriters { - record_reader: record_reader.clone().filter(|_| !self.rebuild), + record_reader: output_record_reader.clone().filter(|_| !self.rebuild), metadata_writer, record_writer, vector_writer, @@ -898,13 +1211,14 @@ impl Handler vec![output.record_segment], + true => vec![output.output.record_segment], false => { - let mut segments = vec![output.metadata_segment, output.record_segment]; + let mut segments = + vec![output.output.metadata_segment, output.output.record_segment]; if is_vector_segment_spann { - segments.push(output.vector_segment); + segments.push(output.output.vector_segment); } segments } @@ -967,7 +1281,7 @@ impl Handler> for CompactOrchestrator } None => { tracing::warn!("No logs were pulled from the log service, this can happen when the log compaction offset is behing the sysdb."); - if let Some(collection) = self.collection.get() { + if let Some(collection) = self.input_collection.get() { self.terminate_with_result( Ok(CompactionResponse::RequireCompactionOffsetRepair { collection_id: collection.collection_id, @@ -979,7 +1293,7 @@ impl Handler> for CompactOrchestrator } else { self.terminate_with_result( Err(CompactionError::InvariantViolation( - "self.collection not set", + "self.input_collection not set", )), ctx, ) @@ -988,7 +1302,38 @@ impl Handler> for CompactOrchestrator return; } } - self.partition(output, ctx).await; + + // For task-based compaction, call ExecuteTask to run task logic + if self.task_context.is_some() { + self.do_task(output, ctx).await; + } else { + // For regular compaction, go directly to partition + self.partition(output, ctx).await; + } + } +} + +#[async_trait] +impl Handler> for CompactOrchestrator { + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let output = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(output) => output, + None 
=> return, + }; + + tracing::info!( + "[CompactOrchestrator] ExecuteTask completed. Processed {} records", + output.records_processed + ); + + // Proceed to partition the output records from the task + self.partition(output.output_records, ctx).await; } } @@ -1010,7 +1355,7 @@ impl Handler> tracing::info!("Sourced Records: {}", output.len()); // Each record should corresond to a log self.total_records_post_compaction = output.len() as u64; - if output.is_empty() { + if output.is_empty() && self.task_context.is_none() { let writers = match self.ok_or_terminate(self.get_segment_writers(), ctx).await { Some(writer) => writer, None => return, @@ -1025,6 +1370,9 @@ impl Handler> ctx, ) .await; + } else if self.task_context.is_some() { + self.pulled_log_offset = self.input_collection.get().unwrap().log_position; + self.do_task(output, ctx).await; } else { self.partition(output, ctx).await; } @@ -1217,41 +1565,96 @@ impl Handler> for CompactOrchestrator { +impl Handler> for CompactOrchestrator { type Result = (); async fn handle( &mut self, - message: TaskResult, + message: TaskResult, ctx: &ComponentContext, ) { + self.state = ExecutionState::FinishTask; + let _finish_output = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(output) => output, + None => return, + }; + + tracing::info!( + "Task finish_task completed for output collection {}", + *self.output_collection_id.get().unwrap() + ); + + // Task verification complete, terminate with success self.terminate_with_result( - message - .into_inner() - .map_err(|e| e.into()) - .map(|_| CompactionResponse::Success { - collection_id: self.collection_id, - }), + Ok(CompactionResponse::Success { + job_id: self.task_context.as_ref().unwrap().task_id.0, + }), ctx, ) .await; } } +#[async_trait] +impl Handler> for CompactOrchestrator { + type Result = (); + + async fn handle( + &mut self, + message: TaskResult, + ctx: &ComponentContext, + ) { + let register_output = match self.ok_or_terminate(message.into_inner(), ctx).await { + Some(output) => output, + None => return, + }; + + // If this was a task-based compaction, invoke finish_task operator + if let Some(updated_task) = register_output.updated_task { + tracing::info!( + "Invoking finish_task operator for task {}", + updated_task.id.0 + ); + + let finish_task_op = FinishTaskOperator::new(self.log.clone(), self.sysdb.clone()); + let finish_task_input = FinishTaskInput::new(updated_task); + + let task = wrap( + finish_task_op, + finish_task_input, + ctx.receiver(), + self.context.task_cancellation_token.clone(), + ); + self.send(task, ctx, Some(Span::current())).await; + } else { + // No task, terminate immediately with success + self.terminate_with_result( + Ok(CompactionResponse::Success { + job_id: self.output_collection_id.get().unwrap().0, + }), + ctx, + ) + .await; + } + } +} + #[cfg(test)] mod tests { + use chroma_blockstore::provider::BlockfileProvider; use chroma_config::{registry::Registry, Configurable}; use chroma_log::{ in_memory_log::{InMemoryLog, InternalLogRecord}, test::{add_delete_generator, LogGenerator}, Log, }; - use chroma_segment::test::TestDistributedSegment; + use chroma_segment::{blockfile_record::RecordSegmentReader, test::TestDistributedSegment}; use chroma_sysdb::{SysDb, TestSysDb}; use chroma_system::{Dispatcher, Orchestrator, System}; use chroma_types::{ operator::{Filter, Limit, Projection}, - DocumentExpression, DocumentOperator, MetadataExpression, PrimitiveOperator, Where, + CollectionUuid, DocumentExpression, DocumentOperator, MetadataExpression, 
+        PrimitiveOperator, Where,
     };
     use regex::Regex;
@@ -1449,4 +1852,334 @@ mod tests {
         assert_eq!(new_vals, old_vals);
     }
+
+    // Helper to read total_count from task result metadata
+    async fn get_total_count_output(
+        sysdb: &mut SysDb,
+        collection_id: CollectionUuid,
+        blockfile_provider: &BlockfileProvider,
+    ) -> i64 {
+        let output_info = sysdb
+            .get_collection_with_segments(collection_id)
+            .await
+            .expect("Should get output collection");
+        let reader =
+            RecordSegmentReader::from_segment(&output_info.record_segment, blockfile_provider)
+                .await
+                .expect("Should create reader");
+        let offset_id = reader
+            .get_offset_id_for_user_id("task_result")
+            .await
+            .expect("Should get offset")
+            .expect("task_result should exist");
+        let data_record = reader
+            .get_data_for_offset_id(offset_id)
+            .await
+            .expect("Should get data")
+            .expect("Data should exist");
+        let metadata = data_record.metadata.expect("Metadata should exist");
+        match metadata.get("total_count") {
+            Some(chroma_types::MetadataValue::Int(c)) => *c,
+            _ => panic!("total_count should be an Int"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_k8s_integration_task_execution() {
+        // Setup test environment
+        let config = RootConfig::default();
+        let system = System::default();
+        let registry = Registry::new();
+        let dispatcher = Dispatcher::try_from_config(&config.query_service.dispatcher, &registry)
+            .await
+            .expect("Should be able to initialize dispatcher");
+        let dispatcher_handle = system.start_component(dispatcher);
+
+        // Connect to Grpc SysDb (requires Tilt running)
+        let grpc_sysdb = chroma_sysdb::GrpcSysDb::try_from_config(
+            &chroma_sysdb::GrpcSysDbConfig {
+                host: "localhost".to_string(),
+                port: 50051,
+                connect_timeout_ms: 5000,
+                request_timeout_ms: 10000,
+                num_channels: 4,
+            },
+            &registry,
+        )
+        .await
+        .expect("Should connect to grpc sysdb");
+        let mut sysdb = SysDb::Grpc(grpc_sysdb);
+
+        let test_segments = TestDistributedSegment::new().await;
+        let mut in_memory_log = InMemoryLog::new();
+
+        // Step 1: Create the input collection via sysdb (HTTP API variant kept commented out for reference)
+        let collection_name = format!("test_task_collection_{}", uuid::Uuid::new_v4());
+        // let create_response = http_client
+        //     .post("http://localhost:8000/api/v2/tenants/default_tenant/databases/default_database/collections")
+        //     .json(&chroma_frontend::server::CreateCollectionPayload {
+        //         name: collection_name.clone(),
+        //         configuration: None,
+        //         schema: None,
+        //         metadata: None,
+        //         get_or_create: false,
+        //     })
+        //     .send()
+        //     .await
+        //     .expect("Should be able to create collection");
+        // assert_eq!(create_response.status(), 200);
+        let collection_id = CollectionUuid::new();
+        sysdb
+            .create_collection(
+                test_segments.collection.tenant,
+                test_segments.collection.database,
+                collection_id,
+                collection_name,
+                vec![
+                    test_segments.record_segment.clone(),
+                    test_segments.metadata_segment.clone(),
+                    test_segments.vector_segment.clone(),
+                ],
+                None,
+                None,
+                None,
+                test_segments.collection.dimension,
+                false,
+            )
+            .await
+            .expect("Collection create should be successful");
+        let input_collection_id = collection_id;
+        let tenant = "default_tenant".to_string();
+        let db = "default_database".to_string();
+
+        // Update input collection's log_position to -1 (no logs compacted yet)
+        sysdb
+            .flush_compaction(
+                tenant.clone(),
+                input_collection_id,
+                -1, // log_position = -1 means no logs compacted yet
+                0,  // collection_version
+                std::sync::Arc::new([]), // no segment flushes
+                0,  // total_records
+                0,  // size_bytes
+                None, // schema
+            )
+            .await
.expect("Should be able to update log_position"); + + // Step 2: Add 50 log records + add_delete_generator + .generate_vec(1..=50) + .into_iter() + .for_each(|log| { + in_memory_log.add_log( + input_collection_id, + InternalLogRecord { + collection_id: input_collection_id, + log_offset: log.log_offset - 1, + log_ts: log.log_offset, + record: log, + }, + ) + }); + let log = Log::InMemory(in_memory_log.clone()); + let task_name = "test_count_task"; + + // Step 3: Create a task via sysdb + let task_id = sysdb + .create_task( + task_name.to_string(), + "record_counter".to_string(), + input_collection_id, + format!("test_output_collection_{}", uuid::Uuid::new_v4()), + serde_json::Value::Null, + tenant.clone(), + db.clone(), + 10, + ) + .await + .expect("Task creation should succeed"); + + // compact everything + let compact_orchestrator = CompactOrchestrator::new( + input_collection_id, + false, + 50, + 1000, + 50, + log.clone(), + sysdb.clone(), + test_segments.blockfile_provider.clone(), + test_segments.hnsw_provider.clone(), + test_segments.spann_provider.clone(), + dispatcher_handle.clone(), + None, + ); + + let result = compact_orchestrator.run(system.clone()).await; + assert!( + result.is_ok(), + "First compaction should succeed: {:?}", + result.err() + ); + + // Fetch the task to get the current nonce + let task_before_run = sysdb + .get_task_by_name(input_collection_id, task_name.to_string()) + .await + .expect("Task should be found"); + let execution_nonce = task_before_run.next_nonce; + + // Run first compaction (PrepareTask will fetch and populate the task) + let compact_orchestrator = CompactOrchestrator::new_for_task( + input_collection_id, + false, + 50, + 1000, + 50, + log.clone(), + sysdb.clone(), + test_segments.blockfile_provider.clone(), + test_segments.hnsw_provider.clone(), + test_segments.spann_provider.clone(), + dispatcher_handle.clone(), + None, + task_id, + execution_nonce, + ); + let result = compact_orchestrator.run(system.clone()).await; + assert!( + result.is_ok(), + "First task run should succeed: {:?}", + result.err() + ); + // Verify task was updated with output collection ID + let updated_task = sysdb + .get_task_by_name(input_collection_id, task_name.to_string()) + .await + .expect("Task should be found"); + assert_eq!( + updated_task.completion_offset, 49, + "Processed logs 0-49, so completion_offset should be 49 (last offset processed)" + ); + + let output_collection_id = updated_task.output_collection_id.unwrap(); + + // Verify first run: Read total_count from task result metadata + let total_count = get_total_count_output( + &mut sysdb, + output_collection_id, + &test_segments.blockfile_provider, + ) + .await; + assert_eq!( + total_count, 34, + "CountTask should have counted 34 records in input collection" + ); + + tracing::info!( + "First task run completed. 
+            "First task run completed. CountTask result: total_count={}",
+            total_count
+        );
+
+        // Step 5: Add 50 more records and run again
+        add_delete_generator
+            .generate_vec(51..=100)
+            .into_iter()
+            .for_each(|log| {
+                in_memory_log.add_log(
+                    input_collection_id,
+                    InternalLogRecord {
+                        collection_id: input_collection_id,
+                        log_offset: log.log_offset - 1,
+                        log_ts: log.log_offset,
+                        record: log,
+                    },
+                )
+            });
+
+        let log_2 = Log::InMemory(in_memory_log.clone());
+
+        // compact everything
+        let compact_orchestrator = CompactOrchestrator::new(
+            input_collection_id,
+            false,
+            50,
+            1000,
+            50,
+            log_2.clone(),
+            sysdb.clone(),
+            test_segments.blockfile_provider.clone(),
+            test_segments.hnsw_provider.clone(),
+            test_segments.spann_provider.clone(),
+            dispatcher_handle.clone(),
+            None,
+        );
+
+        let result = compact_orchestrator.run(system.clone()).await;
+        assert!(
+            result.is_ok(),
+            "Second compaction should succeed: {:?}",
+            result.err()
+        );
+
+        let output_collection_id = updated_task.output_collection_id.unwrap();
+
+        // Fetch the task to get the updated nonce for second run
+        let task_before_run_2 = sysdb
+            .get_task_by_name(input_collection_id, "test_count_task".to_string())
+            .await
+            .expect("Task should be found");
+        let execution_nonce_2 = task_before_run_2.next_nonce;
+
+        // Run second task (PrepareTask will fetch updated task state)
+        let compact_orchestrator_2 = CompactOrchestrator::new_for_task(
+            input_collection_id,
+            false,
+            100,
+            1000,
+            50,
+            log_2.clone(),
+            sysdb.clone(),
+            test_segments.blockfile_provider.clone(),
+            test_segments.hnsw_provider.clone(),
+            test_segments.spann_provider.clone(),
+            dispatcher_handle.clone(),
+            None,
+            task_id,
+            execution_nonce_2,
+        );
+        let result = compact_orchestrator_2.run(system.clone()).await;
+        assert!(
+            result.is_ok(),
+            "Second task run should succeed: {:?}",
+            result.err()
+        );
+
+        let updated_task_2 = sysdb
+            .get_task_by_name(input_collection_id, "test_count_task".to_string())
+            .await
+            .expect("Task should be found");
+        assert_eq!(
+            updated_task_2.completion_offset, 99,
+            "Processed logs 0-99, so completion_offset should be 99 (last offset processed)"
+        );
+
+        // Verify second run: Read updated total_count from task result metadata
+        let total_count_2 = get_total_count_output(
+            &mut sysdb,
+            output_collection_id,
+            &test_segments.blockfile_provider,
+        )
+        .await;
+        assert_eq!(
+            total_count_2, 67,
+            "CountTask should have counted 67 total records in input collection"
+        );
+
+        tracing::info!(
+            "Task execution test completed. First run: total_count=34, Second run: total_count={}",
+            total_count_2
+        );
+    }
 }
diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs
index d9c9ff52602..43c6f206ebe 100644
--- a/rust/worker/src/lib.rs
+++ b/rust/worker/src/lib.rs
@@ -137,6 +137,36 @@ pub async fn compaction_service_entrypoint() {
     let mut compaction_manager_handle = system.start_component(compaction_manager);
     memberlist.subscribe(compaction_manager_handle.receiver());
 
+    // Create taskrunner manager if config is present and enabled (runtime config)
+    let taskrunner_manager_handle = if let Some(task_config) = &config.taskrunner {
+        if !task_config.enabled {
+            None
+        } else {
+            match crate::compactor::create_taskrunner_manager(
+                &config,
+                task_config,
+                system.clone(),
+                dispatcher_handle.clone(),
+                &registry,
+            )
+            .await
+            {
+                Ok(mut task_manager) => {
+                    task_manager.set_dispatcher(dispatcher_handle.clone());
+                    let task_handle = system.start_component(task_manager);
+                    memberlist.subscribe(task_handle.receiver());
+                    Some(task_handle)
+                }
+                Err(err) => {
+                    println!("Failed to create taskrunner manager: {:?}", err);
+                    None
+                }
+            }
+        }
+    } else {
+        None
+    };
+
     let mut memberlist_handle = system.start_component(memberlist);
 
     let compaction_server = CompactionServer {
@@ -167,6 +197,10 @@ pub async fn compaction_service_entrypoint() {
     let _ = dispatcher_handle.join().await;
     compaction_manager_handle.stop();
     let _ = compaction_manager_handle.join().await;
+    if let Some(mut handle) = taskrunner_manager_handle {
+        handle.stop();
+        let _ = handle.join().await;
+    }
     system.stop().await;
     system.join().await;
     let _ = server_join_handle.await;
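
Editorial note (not part of the patch): the lib.rs hunk above gates the new taskrunner on a runtime config section and stops its handle, when present, after the compaction manager. The sketch below restates that optional-component lifecycle with stand-in types; TaskRunnerConfig and ComponentHandle here are illustrative placeholders, not the crate's real chroma_system types.

// Minimal, self-contained sketch of the optional-component lifecycle, assuming
// only that the component is started when an `enabled` config section exists.
struct TaskRunnerConfig {
    enabled: bool,
}

struct ComponentHandle {
    name: &'static str,
}

impl ComponentHandle {
    fn stop(&mut self) {
        println!("stopping {}", self.name);
    }
    fn join(self) {
        println!("joined {}", self.name);
    }
}

// Start the taskrunner only when the config section is present and enabled,
// mirroring the `if let Some(task_config) = &config.taskrunner` check above.
fn start_taskrunner(config: Option<&TaskRunnerConfig>) -> Option<ComponentHandle> {
    match config {
        Some(c) if c.enabled => Some(ComponentHandle { name: "taskrunner" }),
        _ => None,
    }
}

fn main() {
    let config = Some(TaskRunnerConfig { enabled: true });
    let taskrunner_handle = start_taskrunner(config.as_ref());

    // ... service runs ...

    // Shutdown: stop and join the handle only if the component was started.
    if let Some(mut handle) = taskrunner_handle {
        handle.stop();
        handle.join();
    }
}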