Skip to content

Commit

Permalink
Fixes 4927: add flag to prevent requeuing canceled tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
xbhouse committed Nov 6, 2024
1 parent da478a2 commit 3b1de79
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 23 deletions.
2 changes: 1 addition & 1 deletion db/migrations.latest
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
20241018154315
20241104115955

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Down migration: removes the cancel_attempted flag from the tasks table.
BEGIN;

-- IF EXISTS keeps the statement from failing when the column is already absent.
ALTER TABLE tasks
    DROP COLUMN IF EXISTS cancel_attempted;

COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Up migration: adds the cancel_attempted flag to the tasks table.
BEGIN;

-- NOT NULL: the queue scans this column into a non-pointer Go bool, which
-- cannot represent NULL. DEFAULT FALSE fills existing rows, so the constraint
-- holds as soon as the column is added.
ALTER TABLE tasks ADD COLUMN IF NOT EXISTS cancel_attempted BOOLEAN NOT NULL DEFAULT FALSE;

-- Backfill: tasks that are already in the canceled state are flagged as well.
UPDATE tasks SET cancel_attempted = true WHERE status = 'canceled';

COMMIT;
39 changes: 20 additions & 19 deletions pkg/models/task_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,26 @@ import (
// TaskInfo carries the fields of a single task.
// Shared by DAO and queue packages
// GORM only used in DAO to read from table
type TaskInfo struct {
Id uuid.UUID `gorm:"primary_key;column:id"`
Typename string `gorm:"column:type"` // "introspect" or "snapshot"
Payload json.RawMessage `gorm:"type:jsonb"`
OrgId string
AccountId string
ObjectUUID uuid.UUID
ObjectType *string
Dependencies pq.StringArray `gorm:"->;column:t_dependencies;type:text[]"`
Dependents pq.StringArray `gorm:"->;column:t_dependents;type:text[]"`
Token uuid.UUID
Queued *time.Time `gorm:"column:queued_at"`
Started *time.Time `gorm:"column:started_at"`
Finished *time.Time `gorm:"column:finished_at"`
Error *string
Status string
RequestID string
Retries int
NextRetryTime *time.Time
Priority int
Id uuid.UUID `gorm:"primary_key;column:id"`
Typename string `gorm:"column:type"` // "introspect" or "snapshot"
Payload json.RawMessage `gorm:"type:jsonb"`
OrgId string
AccountId string
ObjectUUID uuid.UUID
ObjectType *string
Dependencies pq.StringArray `gorm:"->;column:t_dependencies;type:text[]"`
Dependents pq.StringArray `gorm:"->;column:t_dependents;type:text[]"`
Token uuid.UUID
Queued *time.Time `gorm:"column:queued_at"`
Started *time.Time `gorm:"column:started_at"`
Finished *time.Time `gorm:"column:finished_at"`
Error *string
Status string
RequestID string
Retries int
NextRetryTime *time.Time
Priority int
CancelAttempted bool // scanned from the cancel_attempted column by the queue; when true, Requeue returns ErrTaskCanceled
}

type TaskInfoRepositoryConfiguration struct {
Expand Down
22 changes: 19 additions & 3 deletions pkg/tasks/queue/pgqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"github.com/rs/zerolog/log"
)

const taskInfoReturning = ` id, type, payload, queued_at, started_at, finished_at, status, error, org_id, object_uuid, object_type, token, request_id, retries, next_retry_time, priority ` // fields to return when returning taskInfo
const taskInfoReturning = ` id, type, payload, queued_at, started_at, finished_at, status, error, org_id, object_uuid, object_type, token, request_id, retries, next_retry_time, priority, cancel_attempted ` // fields to return when returning taskInfo

const (
sqlNotify = `NOTIFY tasks`
Expand Down Expand Up @@ -120,6 +120,10 @@ const (
WHERE id = $1`
sqlDeleteAllTasks = `
TRUNCATE task_heartbeats, task_dependencies; DELETE FROM TASKS;`
sqlSetCancelAttempted = `
UPDATE tasks
SET cancel_attempted = true
WHERE id = $1`
)

// These interfaces represent all the interactions with pgxpool that are needed for the pgqueue
Expand Down Expand Up @@ -409,7 +413,7 @@ func (p *PgQueue) dequeueMaybe(ctx context.Context, token uuid.UUID, taskTypes [
err = tx.QueryRow(ctx, sqlDequeue, token, taskTypes).Scan(
&info.Id, &info.Typename, &info.Payload, &info.Queued, &info.Started, &info.Finished, &info.Status,
&info.Error, &info.OrgId, &info.ObjectUUID, &info.ObjectType, &info.Token, &info.RequestID,
&info.Retries, &info.NextRetryTime, &info.Priority,
&info.Retries, &info.NextRetryTime, &info.Priority, &info.CancelAttempted,
)
if err != nil {
return nil, fmt.Errorf("error during dequeue query: %w", err)
Expand Down Expand Up @@ -473,7 +477,7 @@ func (p *PgQueue) Status(taskId uuid.UUID) (*models.TaskInfo, error) {
err = conn.QueryRow(context.Background(), sqlQueryTaskStatus, taskId).Scan(
&info.Id, &info.Typename, &info.Payload, &info.Queued, &info.Started, &info.Finished, &info.Status,
&info.Error, &info.OrgId, &info.ObjectUUID, &info.ObjectType, &info.Token, &info.RequestID,
&info.Retries, &info.NextRetryTime, &info.Priority,
&info.Retries, &info.NextRetryTime, &info.Priority, &info.CancelAttempted,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -669,6 +673,9 @@ func (p *PgQueue) Requeue(taskId uuid.UUID) error {
if err == pgx.ErrNoRows {
return ErrNotExist
}
if info.CancelAttempted {
return ErrTaskCanceled
}
if info.Started == nil || info.Finished != nil {
return ErrNotRunning
}
Expand Down Expand Up @@ -877,6 +884,10 @@ func (p *PgQueue) ListenForCancel(ctx context.Context, taskID uuid.UUID, cancelF

// Cancel context only if context has not already been canceled. If the context has already been canceled, the task has finished.
if !errors.Is(ErrNotRunning, context.Cause(ctx)) {
if err := p.setCancelAttempted(taskID); err != nil {
logger.Error().Err(err).Msg("ListenForCancel: error setting cancel_attempted")
return
}
logger.Debug().Msg("[Canceled Task]")
cancelFunc(ErrTaskCanceled)
}
Expand All @@ -889,3 +900,8 @@ func isContextCancelled(ctx context.Context) bool {
// getCancelChannelName returns "task_" followed by the task ID's string form
// with every hyphen removed.
func getCancelChannelName(taskID uuid.UUID) string {
	name := "task_" + taskID.String()
	return strings.ReplaceAll(name, "-", "")
}

// setCancelAttempted runs sqlSetCancelAttempted for the given task ID, which
// sets cancel_attempted to true on the matching tasks row. It returns the
// error from the Exec call unchanged, or nil on success.
func (p *PgQueue) setCancelAttempted(taskID uuid.UUID) error {
	if _, err := p.Pool.Exec(context.Background(), sqlSetCancelAttempted, taskID); err != nil {
		return err
	}
	return nil
}
20 changes: 20 additions & 0 deletions pkg/tasks/queue/pgqueue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,26 @@ func (s *QueueSuite) TestRequeueFailedTasks() {
assert.True(s.T(), info.Queued.After(*originalQueueTime))
}

// TestCannotRequeueCanceledTasks checks that Requeue returns ErrTaskCanceled
// for a task that was dequeued, canceled, and flagged with cancel_attempted.
func (s *QueueSuite) TestCannotRequeueCanceledTasks() {
	t := s.T()
	ctx := context.Background()

	taskID, err := s.queue.Enqueue(&testTask)
	require.NoError(t, err)
	assert.NotEqual(t, uuid.Nil, taskID)

	_, err = s.queue.Status(taskID)
	require.NoError(t, err)

	_, err = s.queue.Dequeue(ctx, []string{testTaskType})
	require.NoError(t, err)

	// Cancel the task, then set the flag directly on the row.
	require.NoError(t, s.queue.Cancel(ctx, taskID))
	require.NoError(t, s.queue.setCancelAttempted(taskID))

	assert.ErrorIs(t, s.queue.Requeue(taskID), ErrTaskCanceled)
}

func (s *QueueSuite) TestRequeueFailedTasksExceedRetries() {
config.Get().Tasking.RetryWaitUpperBound = 0

Expand Down

0 comments on commit 3b1de79

Please sign in to comment.