diff --git a/flyteadmin/pkg/repositories/gormimpl/named_entity_repo.go b/flyteadmin/pkg/repositories/gormimpl/named_entity_repo.go index 40ec860f68..bfb55e07bd 100644 --- a/flyteadmin/pkg/repositories/gormimpl/named_entity_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/named_entity_repo.go @@ -182,7 +182,7 @@ func (r *NamedEntityRepo) List(ctx context.Context, input interfaces.ListNamedEn "Cannot list entity names for resource type: %v", input.ResourceType) } - tx := getSubQueryJoin(r.db, tableName, input) + tx := getSubQueryJoin(r.db.WithContext(ctx), tableName, input) // Apply filters tx, err := applyScopedFilters(tx, input.InlineFilters, input.MapFilters) diff --git a/flyteadmin/pkg/repositories/gormimpl/signal_repo.go b/flyteadmin/pkg/repositories/gormimpl/signal_repo.go index dccfbda748..a42c3ec463 100644 --- a/flyteadmin/pkg/repositories/gormimpl/signal_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/signal_repo.go @@ -25,7 +25,7 @@ type SignalRepo struct { func (s *SignalRepo) Get(ctx context.Context, input models.SignalKey) (models.Signal, error) { var signal models.Signal timer := s.metrics.GetDuration.Start() - tx := s.db.Where(&models.Signal{ + tx := s.db.WithContext(ctx).Where(&models.Signal{ SignalKey: input, }).Take(&signal) timer.Stop() @@ -41,7 +41,7 @@ func (s *SignalRepo) Get(ctx context.Context, input models.SignalKey) (models.Si // GetOrCreate returns a signal if it already exists, if not it creates a new one given the input func (s *SignalRepo) GetOrCreate(ctx context.Context, input *models.Signal) error { timer := s.metrics.CreateDuration.Start() - tx := s.db.FirstOrCreate(&input, input) + tx := s.db.WithContext(ctx).FirstOrCreate(&input, input) timer.Stop() if tx.Error != nil { return s.errorTransformer.ToFlyteAdminError(tx.Error) @@ -56,7 +56,7 @@ func (s *SignalRepo) List(ctx context.Context, input interfaces.ListResourceInpu return nil, err } var signals []models.Signal - tx := s.db.Limit(input.Limit).Offset(input.Offset) + tx := s.db.WithContext(ctx).Limit(input.Limit).Offset(input.Offset) // Apply filters tx, err := applyFilters(tx, input.InlineFilters, input.MapFilters) @@ -85,7 +85,7 @@ func (s *SignalRepo) Update(ctx context.Context, input models.SignalKey, value [ } timer := s.metrics.GetDuration.Start() - tx := s.db.Model(&signal).Select("value").Updates(signal) + tx := s.db.WithContext(ctx).Model(&signal).Select("value").Updates(signal) timer.Stop() if tx.Error != nil { return s.errorTransformer.ToFlyteAdminError(tx.Error) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go index ae4c300ce4..f2ac2adf52 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go @@ -80,7 +80,7 @@ func (r *TaskExecutionRepo) Get(ctx context.Context, input interfaces.GetTaskExe func (r *TaskExecutionRepo) Update(ctx context.Context, execution models.TaskExecution) error { timer := r.metrics.UpdateDuration.Start() - tx := r.db.WithContext(ctx).WithContext(ctx).Save(&execution) // TODO @hmaersaw - need to add WithContext to all db calls to link otel spans + tx := r.db.WithContext(ctx).WithContext(ctx).Save(&execution) timer.Stop() if err := tx.Error; err != nil {