Skip to content

Commit

Permalink
fix: hangup workflow listener when workflow run finishes (hatchet-dev…
Browse files Browse the repository at this point in the history
…#161)

* wip: class based listener pattern

* fix: workflow run listener hangups

* fix: hang up workflow listener on finished

* fix: case for current workflow run

* address review comments

* bump version

---------

Co-authored-by: g <[email protected]>
  • Loading branch information
abelanger5 and grutt authored Feb 9, 2024
1 parent d5f991d commit f8e9c8b
Show file tree
Hide file tree
Showing 20 changed files with 603 additions and 421 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dump.rdb
*.pfx
*.cert
.next
.venv

node_modules

Expand Down
4 changes: 4 additions & 0 deletions api-contracts/dispatcher/dispatcher.proto
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,8 @@ message WorkflowEvent {

// the event payload
string eventPayload = 6;

// whether this is the last event for the workflow run - the server
// will hang up the connection, but clients might want to handle this case themselves
bool hangup = 7;
}
2 changes: 1 addition & 1 deletion api/v1/server/handlers/step-runs/rerun.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (t *StepRunService) StepRunUpdateRerun(ctx echo.Context, request gen.StepRu
}

// update step run
_, err = t.config.Repository.StepRun().UpdateStepRun(tenant.ID, stepRun.ID, &repository.UpdateStepRunOpts{
_, _, err = t.config.Repository.StepRun().UpdateStepRun(tenant.ID, stepRun.ID, &repository.UpdateStepRunOpts{
Input: inputBytes,
Status: repository.StepRunStatusPtr(db.StepRunStatusPending),
IsRerun: true,
Expand Down
52 changes: 37 additions & 15 deletions internal/repository/prisma/step_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,19 @@ var retrier = func(l *zerolog.Logger, f func() error) error {
return nil
}

func (s *stepRunRepository) UpdateStepRun(tenantId, stepRunId string, opts *repository.UpdateStepRunOpts) (*db.StepRunModel, error) {
func (s *stepRunRepository) UpdateStepRun(tenantId, stepRunId string, opts *repository.UpdateStepRunOpts) (*db.StepRunModel, *repository.StepRunUpdateInfo, error) {
if err := s.v.Validate(opts); err != nil {
return nil, err
return nil, nil, err
}

updateParams, updateJobRunLookupDataParams, resolveJobRunParams, resolveLaterStepRunsParams, err := getUpdateParams(tenantId, stepRunId, opts)

if err != nil {
return nil, err
return nil, nil, err
}

var updateInfo *repository.StepRunUpdateInfo

err = retrier(s.l, func() error {
tx, err := s.pool.Begin(context.Background())

Expand All @@ -185,7 +187,7 @@ func (s *stepRunRepository) UpdateStepRun(tenantId, stepRunId string, opts *repo

defer deferRollback(context.Background(), s.l, tx.Rollback)

err = s.updateStepRun(tx, tenantId, updateParams, updateJobRunLookupDataParams, resolveJobRunParams, resolveLaterStepRunsParams)
updateInfo, err = s.updateStepRun(tx, tenantId, updateParams, updateJobRunLookupDataParams, resolveJobRunParams, resolveLaterStepRunsParams)

if err != nil {
return err
Expand All @@ -197,10 +199,10 @@ func (s *stepRunRepository) UpdateStepRun(tenantId, stepRunId string, opts *repo
})

if err != nil {
return nil, err
return nil, nil, err
}

return s.client.StepRun.FindUnique(
stepRun, err := s.client.StepRun.FindUnique(
db.StepRun.ID.Equals(stepRunId),
).With(
db.StepRun.Children.Fetch(),
Expand All @@ -215,6 +217,12 @@ func (s *stepRunRepository) UpdateStepRun(tenantId, stepRunId string, opts *repo
),
db.StepRun.Ticker.Fetch(),
).Exec(context.Background())

if err != nil {
return nil, nil, err
}

return stepRun, updateInfo, nil
}

func (s *stepRunRepository) QueueStepRun(tenantId, stepRunId string, opts *repository.UpdateStepRunOpts) (*db.StepRunModel, error) {
Expand Down Expand Up @@ -250,7 +258,7 @@ func (s *stepRunRepository) QueueStepRun(tenantId, stepRunId string, opts *repos
return nil, repository.ErrStepRunIsNotPending
}

err = s.updateStepRun(tx, tenantId, updateParams, updateJobRunLookupDataParams, resolveJobRunParams, resolveLaterStepRunsParams)
_, err = s.updateStepRun(tx, tenantId, updateParams, updateJobRunLookupDataParams, resolveJobRunParams, resolveLaterStepRunsParams)

if err != nil {
return nil, err
Expand Down Expand Up @@ -376,46 +384,59 @@ func (s *stepRunRepository) updateStepRun(
updateJobRunLookupDataParams *dbsqlc.UpdateJobRunLookupDataWithStepRunParams,
resolveJobRunParams dbsqlc.ResolveJobRunStatusParams,
resolveLaterStepRunsParams dbsqlc.ResolveLaterStepRunsParams,
) error {
) (*repository.StepRunUpdateInfo, error) {
_, err := s.queries.UpdateStepRun(context.Background(), tx, updateParams)

if err != nil {
return fmt.Errorf("could not update step run: %w", err)
return nil, fmt.Errorf("could not update step run: %w", err)
}

_, err = s.queries.ResolveLaterStepRuns(context.Background(), tx, resolveLaterStepRunsParams)

if err != nil {
return fmt.Errorf("could not resolve later step runs: %w", err)
return nil, fmt.Errorf("could not resolve later step runs: %w", err)
}

jobRun, err := s.queries.ResolveJobRunStatus(context.Background(), tx, resolveJobRunParams)

if err != nil {
return fmt.Errorf("could not resolve job run status: %w", err)
return nil, fmt.Errorf("could not resolve job run status: %w", err)
}

resolveWorkflowRunParams := dbsqlc.ResolveWorkflowRunStatusParams{
Jobrunid: jobRun.ID,
Tenantid: sqlchelpers.UUIDFromStr(tenantId),
}

_, err = s.queries.ResolveWorkflowRunStatus(context.Background(), tx, resolveWorkflowRunParams)
workflowRun, err := s.queries.ResolveWorkflowRunStatus(context.Background(), tx, resolveWorkflowRunParams)

if err != nil {
return fmt.Errorf("could not resolve workflow run status: %w", err)
return nil, fmt.Errorf("could not resolve workflow run status: %w", err)
}

// update the job run lookup data if not nil
if updateJobRunLookupDataParams != nil {
err = s.queries.UpdateJobRunLookupDataWithStepRun(context.Background(), tx, *updateJobRunLookupDataParams)

if err != nil {
return fmt.Errorf("could not update job run lookup data: %w", err)
return nil, fmt.Errorf("could not update job run lookup data: %w", err)
}
}

return nil
return &repository.StepRunUpdateInfo{
JobRunFinalState: isFinalJobRunStatus(jobRun.Status),
WorkflowRunFinalState: isFinalWorkflowRunStatus(workflowRun.Status),
WorkflowRunId: sqlchelpers.UUIDToStr(workflowRun.ID),
WorkflowRunStatus: string(workflowRun.Status),
}, nil
}

// isFinalJobRunStatus reports whether the given job run status is terminal,
// meaning it is anything other than PENDING or RUNNING.
func isFinalJobRunStatus(status dbsqlc.JobRunStatus) bool {
	switch status {
	case dbsqlc.JobRunStatusPENDING, dbsqlc.JobRunStatusRUNNING:
		// still in flight
		return false
	default:
		return true
	}
}

// isFinalWorkflowRunStatus reports whether the given workflow run status is
// terminal, meaning it is anything other than PENDING, RUNNING or QUEUED.
func isFinalWorkflowRunStatus(status dbsqlc.WorkflowRunStatus) bool {
	switch status {
	case dbsqlc.WorkflowRunStatusPENDING,
		dbsqlc.WorkflowRunStatusRUNNING,
		dbsqlc.WorkflowRunStatusQUEUED:
		// still in flight
		return false
	default:
		return true
	}
}

func (s *stepRunRepository) GetStepRunById(tenantId, stepRunId string) (*db.StepRunModel, error) {
Expand All @@ -432,6 +453,7 @@ func (s *stepRunRepository) GetStepRunById(tenantId, stepRunId string) (*db.Step
),
db.StepRun.JobRun.Fetch().With(
db.JobRun.LookupData.Fetch(),
db.JobRun.WorkflowRun.Fetch(),
),
db.StepRun.Ticker.Fetch(),
).Exec(context.Background())
Expand Down
9 changes: 8 additions & 1 deletion internal/repository/step_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,21 @@ func StepRunStatusPtr(status db.StepRunStatus) *db.StepRunStatus {

var ErrStepRunIsNotPending = fmt.Errorf("step run is not pending")

// StepRunUpdateInfo describes the state of the parent job run and workflow run
// as resolved while a step run was being updated. It is returned alongside the
// updated step run by UpdateStepRun.
type StepRunUpdateInfo struct {
// JobRunFinalState is true when the resolved job run status is neither
// PENDING nor RUNNING.
JobRunFinalState bool
// WorkflowRunFinalState is true when the resolved workflow run status is
// none of PENDING, RUNNING or QUEUED.
WorkflowRunFinalState bool
// WorkflowRunId is the string form of the ID of the workflow run whose
// status was resolved.
WorkflowRunId string
// WorkflowRunStatus is the string form of the resolved workflow run status.
WorkflowRunStatus string
}

type StepRunRepository interface {
// ListAllStepRuns returns a list of all step runs which match the given options.
ListAllStepRuns(opts *ListAllStepRunsOpts) ([]db.StepRunModel, error)

// ListStepRuns returns a list of step runs for a tenant which match the given options.
ListStepRuns(tenantId string, opts *ListStepRunsOpts) ([]db.StepRunModel, error)

UpdateStepRun(tenantId, stepRunId string, opts *UpdateStepRunOpts) (*db.StepRunModel, error)
UpdateStepRun(tenantId, stepRunId string, opts *UpdateStepRunOpts) (*db.StepRunModel, *StepRunUpdateInfo, error)

GetStepRunById(tenantId, stepRunId string) (*db.StepRunModel, error)

Expand Down
40 changes: 33 additions & 7 deletions internal/services/controllers/jobs/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,14 @@ func (ec *JobsControllerImpl) handleJobRunTimedOut(ctx context.Context, task *ta
now := time.Now().UTC()

// cancel current step run
stepRun, err := ec.repo.StepRun().UpdateStepRun(metadata.TenantId, currStepRun.ID, &repository.UpdateStepRunOpts{
stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(metadata.TenantId, currStepRun.ID, &repository.UpdateStepRunOpts{
CancelledAt: &now,
CancelledReason: repository.StringPtr("JOB_RUN_TIMED_OUT"),
Status: repository.StepRunStatusPtr(db.StepRunStatusCancelled),
})

defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)

if err != nil {
return fmt.Errorf("could not update step run: %w", err)
}
Expand Down Expand Up @@ -421,12 +423,14 @@ func (ec *JobsControllerImpl) handleStepRunRequeue(ctx context.Context, task *ta

// if the current time is after the scheduleTimeoutAt, then mark this as timed out
if scheduleTimeoutAt, ok := stepRunCp.ScheduleTimeoutAt(); ok && scheduleTimeoutAt.Before(now) {
_, err = ec.repo.StepRun().UpdateStepRun(payload.TenantId, stepRunCp.ID, &repository.UpdateStepRunOpts{
stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(payload.TenantId, stepRunCp.ID, &repository.UpdateStepRunOpts{
CancelledAt: &now,
CancelledReason: repository.StringPtr("SCHEDULING_TIMED_OUT"),
Status: repository.StepRunStatusPtr(db.StepRunStatusCancelled),
})

defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)

if err != nil {
return fmt.Errorf("could not update step run %s: %w", stepRunCp.ID, err)
}
Expand All @@ -436,7 +440,7 @@ func (ec *JobsControllerImpl) handleStepRunRequeue(ctx context.Context, task *ta

requeueAfter := time.Now().UTC().Add(time.Second * 5)

stepRun, err := ec.repo.StepRun().UpdateStepRun(payload.TenantId, stepRunCp.ID, &repository.UpdateStepRunOpts{
stepRun, _, err := ec.repo.StepRun().UpdateStepRun(payload.TenantId, stepRunCp.ID, &repository.UpdateStepRunOpts{
RequeueAfter: &requeueAfter,
})

Expand Down Expand Up @@ -679,11 +683,13 @@ func (ec *JobsControllerImpl) handleStepRunStarted(ctx context.Context, task *ta
return fmt.Errorf("could not parse started at: %w", err)
}

_, err = ec.repo.StepRun().UpdateStepRun(metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
StartedAt: &startedAt,
Status: repository.StepRunStatusPtr(db.StepRunStatusRunning),
})

defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)

return err
}

Expand Down Expand Up @@ -725,12 +731,14 @@ func (ec *JobsControllerImpl) handleStepRunFinished(ctx context.Context, task *t
stepOutput = []byte(stepOutputStr)
}

stepRun, err := ec.repo.StepRun().UpdateStepRun(metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
FinishedAt: &finishedAt,
Status: repository.StepRunStatusPtr(db.StepRunStatusSucceeded),
Output: stepOutput,
})

defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)

if err != nil {
return fmt.Errorf("could not update step run: %w", err)
}
Expand Down Expand Up @@ -803,12 +811,14 @@ func (ec *JobsControllerImpl) handleStepRunFailed(ctx context.Context, task *tas
return fmt.Errorf("could not parse started at: %w", err)
}

stepRun, err := ec.repo.StepRun().UpdateStepRun(metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(metadata.TenantId, payload.StepRunId, &repository.UpdateStepRunOpts{
FinishedAt: &failedAt,
Error: &payload.Error,
Status: repository.StepRunStatusPtr(db.StepRunStatusFailed),
})

defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)

if err != nil {
return fmt.Errorf("could not update step run: %w", err)
}
Expand Down Expand Up @@ -884,12 +894,14 @@ func (ec *JobsControllerImpl) cancelStepRun(ctx context.Context, tenantId, stepR
// cancel current step run
now := time.Now().UTC()

stepRun, err := ec.repo.StepRun().UpdateStepRun(tenantId, stepRunId, &repository.UpdateStepRunOpts{
stepRun, updateInfo, err := ec.repo.StepRun().UpdateStepRun(tenantId, stepRunId, &repository.UpdateStepRunOpts{
CancelledAt: &now,
CancelledReason: repository.StringPtr(reason),
Status: repository.StepRunStatusPtr(db.StepRunStatusCancelled),
})

defer ec.handleStepRunUpdateInfo(stepRun, updateInfo)

if err != nil {
return fmt.Errorf("could not update step run: %w", err)
}
Expand Down Expand Up @@ -922,6 +934,20 @@ func (ec *JobsControllerImpl) cancelStepRun(ctx context.Context, tenantId, stepR
return nil
}

// handleStepRunUpdateInfo enqueues a "workflow run finished" task on the
// workflow processing queue when the step run update moved the parent workflow
// run into a final state.
//
// Callers defer this method immediately after UpdateStepRun, before checking
// the returned error. UpdateStepRun returns a nil step run and nil update info
// on failure, so both arguments may be nil here; in that case there is nothing
// to report and the method returns without queueing anything.
//
// A failure to enqueue the task is logged and not propagated.
func (ec *JobsControllerImpl) handleStepRunUpdateInfo(stepRun *db.StepRunModel, updateInfo *repository.StepRunUpdateInfo) {
	// Guard against the failed-update path: dereferencing either pointer
	// would panic inside a deferred call.
	if stepRun == nil || updateInfo == nil {
		return
	}

	if !updateInfo.WorkflowRunFinalState {
		return
	}

	err := ec.tq.AddTask(
		context.Background(),
		taskqueue.WORKFLOW_PROCESSING_QUEUE,
		tasktypes.WorkflowRunFinishedToTask(stepRun.TenantID, updateInfo.WorkflowRunId, updateInfo.WorkflowRunStatus),
	)

	if err != nil {
		ec.l.Error().Err(err).Msg("could not add workflow run finished task to task queue")
	}
}

func (ec *JobsControllerImpl) handleTickerRemoved(ctx context.Context, task *taskqueue.Task) error {
ctx, span := telemetry.NewSpan(ctx, "handle-ticker-removed")
defer span.End()
Expand Down
Loading

0 comments on commit f8e9c8b

Please sign in to comment.