diff --git a/flyteplugins/go/tasks/pluginmachinery/catalog/async_client_impl.go b/flyteplugins/go/tasks/pluginmachinery/catalog/async_client_impl.go index d9e886d0eb..c056989905 100644 --- a/flyteplugins/go/tasks/pluginmachinery/catalog/async_client_impl.go +++ b/flyteplugins/go/tasks/pluginmachinery/catalog/async_client_impl.go @@ -48,7 +48,7 @@ func hashInputs(ctx context.Context, key Key) (string, error) { } inputs = retInputs } - return HashLiteralMap(ctx, inputs) + return HashLiteralMap(ctx, inputs, key.CacheIgnoreInputVars) } func (c AsyncClientImpl) Download(ctx context.Context, requests ...DownloadRequest) (outputFuture DownloadFuture, err error) { diff --git a/flyteplugins/go/tasks/pluginmachinery/catalog/hashing.go b/flyteplugins/go/tasks/pluginmachinery/catalog/hashing.go index 3c73710eac..4cc2fbd5cd 100644 --- a/flyteplugins/go/tasks/pluginmachinery/catalog/hashing.go +++ b/flyteplugins/go/tasks/pluginmachinery/catalog/hashing.go @@ -4,6 +4,8 @@ import ( "context" "encoding/base64" + "k8s.io/utils/strings/slices" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytestdlib/pbhash" ) @@ -55,7 +57,7 @@ func hashify(literal *core.Literal) *core.Literal { return literal } -func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap) (string, error) { +func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap, cacheIgnoreInputVars []string) (string, error) { if literalMap == nil || len(literalMap.Literals) == 0 { literalMap = &emptyLiteralMap } @@ -64,7 +66,9 @@ func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap) (string, e // in case the corresponding hash is set. 
hashifiedLiteralMap := make(map[string]*core.Literal, len(literalMap.Literals)) for name, literal := range literalMap.Literals { - hashifiedLiteralMap[name] = hashify(literal) + if !slices.Contains(cacheIgnoreInputVars, name) { + hashifiedLiteralMap[name] = hashify(literal) + } } hashifiedInputs := &core.LiteralMap{ Literals: hashifiedLiteralMap, diff --git a/flyteplugins/go/tasks/pluginmachinery/catalog/hashing_test.go b/flyteplugins/go/tasks/pluginmachinery/catalog/hashing_test.go index 43b6754c02..51f2cdfe27 100644 --- a/flyteplugins/go/tasks/pluginmachinery/catalog/hashing_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/catalog/hashing_test.go @@ -617,7 +617,7 @@ func TestHashLiteralMap_LiteralsWithHashSet(t *testing.T) { // Double-check that generating a tag is successful literalMap := &core.LiteralMap{Literals: map[string]*core.Literal{"o0": tt.literal}} - hash, err := HashLiteralMap(context.TODO(), literalMap) + hash, err := HashLiteralMap(context.TODO(), literalMap, nil) assert.NoError(t, err) assert.NotEmpty(t, hash) }) @@ -629,26 +629,43 @@ func TestInputValueSorted(t *testing.T) { literalMap, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2}) assert.NoError(t, err) - hash, err := HashLiteralMap(context.TODO(), literalMap) + hash, err := HashLiteralMap(context.TODO(), literalMap, nil) assert.NoError(t, err) assert.Equal(t, "GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", hash) literalMap, err = coreutils.MakeLiteralMap(map[string]interface{}{"2": 2, "1": 1}) assert.NoError(t, err) - hashDupe, err := HashLiteralMap(context.TODO(), literalMap) + hashDupe, err := HashLiteralMap(context.TODO(), literalMap, nil) assert.NoError(t, err) assert.Equal(t, hashDupe, hash) } // Ensure that empty inputs are hashed the same way func TestNoInputValues(t *testing.T) { - hash, err := HashLiteralMap(context.TODO(), nil) + hash, err := HashLiteralMap(context.TODO(), nil, nil) assert.NoError(t, err) assert.Equal(t, 
"GKw-c0PwFokMUQ6T-TUmEWnZ4_VlQ2Qpgw-vCTT0-OQ", hash) - hashDupe, err := HashLiteralMap(context.TODO(), &core.LiteralMap{Literals: nil}) + hashDupe, err := HashLiteralMap(context.TODO(), &core.LiteralMap{Literals: nil}, nil) assert.NoError(t, err) assert.Equal(t, "GKw-c0PwFokMUQ6T-TUmEWnZ4_VlQ2Qpgw-vCTT0-OQ", hashDupe) assert.Equal(t, hashDupe, hash) } + +// Ensure that inputs listed in cacheIgnoreInputVars are excluded from the hash +func TestCacheIgnoreInputVars(t *testing.T) { + literalMap, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2}) + assert.NoError(t, err) + + hash, err := HashLiteralMap(context.TODO(), literalMap, nil) + assert.NoError(t, err) + assert.Equal(t, "GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", hash) + + literalMap, err = coreutils.MakeLiteralMap(map[string]interface{}{"2": 2, "1": 1, "3": 3}) + assert.NoError(t, err) + + hashDupe, err := HashLiteralMap(context.TODO(), literalMap, []string{"3"}) + assert.NoError(t, err) + assert.Equal(t, hashDupe, hash) +} diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go index c211c19597..a87266170c 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/datacatalog.go @@ -108,7 +108,7 @@ func (m *CatalogClient) Get(ctx context.Context, key catalog.Key) (catalog.Entry inputs = retInputs } - tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars) + tag, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars) if err != nil { logger.Errorf(ctx, "DataCatalog failed to generate tag for inputs %+v, err: %+v", inputs, err) return catalog.Entry{}, err @@ -233,7 +233,7 @@ func (m *CatalogClient) CreateArtifact(ctx context.Context, key catalog.Key, dat logger.Debugf(ctx, "Created artifact: %v, with %v outputs from execution %+v", cachedArtifact.Id, len(artifactDataList), metadata) // Tag the 
artifact since it is the cached artifact - tagName, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars) + tagName, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars) if err != nil { logger.Errorf(ctx, "Failed to generate tag for artifact %+v, err: %+v", cachedArtifact.Id, err) return catalog.Status{}, err @@ -273,7 +273,7 @@ func (m *CatalogClient) UpdateArtifact(ctx context.Context, key catalog.Key, dat artifactDataList = append(artifactDataList, artifactData) } - tagName, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars) + tagName, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars) if err != nil { logger.Errorf(ctx, "Failed to generate artifact tag name for key %+v, dataset %+v and execution %+v, err: %+v", key, datasetID, metadata, err) return catalog.Status{}, err @@ -378,7 +378,7 @@ func (m *CatalogClient) GetOrExtendReservation(ctx context.Context, key catalog. inputs = retInputs } - tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars) + tag, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars) if err != nil { return nil, err } @@ -418,7 +418,7 @@ func (m *CatalogClient) ReleaseReservation(ctx context.Context, key catalog.Key, inputs = retInputs } - tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars) + tag, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars) if err != nil { return err } diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go index 73b6e72089..5c0ac0c30b 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go @@ -13,7 +13,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" 
"github.com/flyteorg/flyte/flytestdlib/pbhash" - "golang.org/x/exp/slices" ) const cachedTaskTag = "flyte_cached" @@ -116,20 +115,8 @@ func generateTaskSignatureHash(ctx context.Context, taskInterface core.TypedInte } // Generate a tag by hashing the input values which are not in cacheIgnoreInputVars -func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap, cacheIgnoreInputVars *[]string) (string, error) { - var inputsAfterIgnore *core.LiteralMap - if cacheIgnoreInputVars != nil { - inputsAfterIgnore = &core.LiteralMap{Literals: make(map[string]*core.Literal)} - for name, literal := range inputs.Literals { - if slices.Contains(*cacheIgnoreInputVars, name) { - continue - } - inputsAfterIgnore.Literals[name] = literal - } - } else { - inputsAfterIgnore = inputs - } - hashString, err := catalog.HashLiteralMap(ctx, inputsAfterIgnore) +func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap, cacheIgnoreInputVars []string) (string, error) { + hashString, err := catalog.HashLiteralMap(ctx, inputs, cacheIgnoreInputVars) if err != nil { return "", err } diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go index a392387927..1c6b9e2e1b 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go @@ -114,7 +114,7 @@ func TestGenerateArtifactTagNameWithIgnore(t *testing.T) { literalMap, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2, "3": 3}) assert.NoError(t, err) cacheIgnoreInputVars := []string{"3"} - tag, err := GenerateArtifactTagName(context.TODO(), literalMap, &cacheIgnoreInputVars) + tag, err := GenerateArtifactTagName(context.TODO(), literalMap, cacheIgnoreInputVars) assert.NoError(t, err) assert.Equal(t, "flyte_cached-GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", tag) }