From 044dd0334f3653e6b1049e3483842f541d150151 Mon Sep 17 00:00:00 2001 From: featherchen Date: Thu, 24 Oct 2024 21:54:14 -0700 Subject: [PATCH] fix: TestGetTask Signed-off-by: featherchen --- flyteadmin/pkg/repositories/gormimpl/task_repo.go | 4 ++-- flyteadmin/pkg/repositories/gormimpl/task_repo_test.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index c8c9a6948f8..82cbd31f406 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -51,8 +51,8 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models timer := r.metrics.GetDuration.Start() var tx *gorm.DB if input.Version == "" { - tx := r.db.WithContext(ctx).Where("project = ? AND domain = ? AND name = ?", input.Project, input.Domain, input.Name).Limit(1) - tx = tx.Order("version DESC") + tx = r.db.WithContext(ctx).Where(`"tasks"."project" = ? AND "tasks"."domain" = ? AND "tasks"."name" = ?`, input.Project, input.Domain, input.Name).Limit(1) + tx = tx.Order(`"tasks"."version" DESC`) tx.Find(&task) } else { tx = r.db.WithContext(ctx).Where(&models.Task{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index 5204c2a1cf2..643dc66e4d2 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -85,7 +85,7 @@ func TestGetTask(t *testing.T) { GlobalMock.Logging = true GlobalMock.NewMock().WithQuery( - `SELECT * FROM "tasks" WHERE project = $1 AND domain = $2 AND name = $3 ORDER BY version DESC LIMIT 1`). + `SELECT * FROM "tasks" WHERE "tasks"."project" = $1 AND "tasks"."domain" = $2 AND "tasks"."name" = $3 ORDER BY "tasks"."version" DESC LIMIT 1`). WithReply(tasks) output, err = taskRepo.Get(context.Background(), interfaces.Identifier{ Project: project, @@ -98,8 +98,8 @@ func TestGetTask(t *testing.T) { assert.Equal(t, project, output.Project) assert.Equal(t, domain, output.Domain) assert.Equal(t, name, output.Name) - assert.Equal(t, "v2", output.Version) - assert.Equal(t, []byte{3, 4}, output.Closure) + assert.Equal(t, version, output.Version) + assert.Equal(t, []byte{1, 2}, output.Closure) assert.Equal(t, pythonTestTaskType, output.Type) }