Skip to content

Commit

Permalink
moved cache ignore input vars check to HashLiteralMap function
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw committed Dec 18, 2023
1 parent 30d534c commit cfa1e00
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func hashInputs(ctx context.Context, key Key) (string, error) {
}
inputs = retInputs
}
return HashLiteralMap(ctx, inputs)
return HashLiteralMap(ctx, inputs, key.CacheIgnoreInputVars)
}

func (c AsyncClientImpl) Download(ctx context.Context, requests ...DownloadRequest) (outputFuture DownloadFuture, err error) {
Expand Down
8 changes: 6 additions & 2 deletions flyteplugins/go/tasks/pluginmachinery/catalog/hashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/base64"

"k8s.io/utils/strings/slices"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytestdlib/pbhash"
)
Expand Down Expand Up @@ -55,7 +57,7 @@ func hashify(literal *core.Literal) *core.Literal {
return literal
}

func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap) (string, error) {
func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap, cacheIgnoreInputVars []string) (string, error) {
if literalMap == nil || len(literalMap.Literals) == 0 {
literalMap = &emptyLiteralMap
}
Expand All @@ -64,7 +66,9 @@ func HashLiteralMap(ctx context.Context, literalMap *core.LiteralMap) (string, e
// in case the corresponding hash is set.
hashifiedLiteralMap := make(map[string]*core.Literal, len(literalMap.Literals))
for name, literal := range literalMap.Literals {
hashifiedLiteralMap[name] = hashify(literal)
if !slices.Contains(cacheIgnoreInputVars, name) {
hashifiedLiteralMap[name] = hashify(literal)
}
}
hashifiedInputs := &core.LiteralMap{
Literals: hashifiedLiteralMap,
Expand Down
27 changes: 22 additions & 5 deletions flyteplugins/go/tasks/pluginmachinery/catalog/hashing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ func TestHashLiteralMap_LiteralsWithHashSet(t *testing.T) {

// Double-check that generating a tag is successful
literalMap := &core.LiteralMap{Literals: map[string]*core.Literal{"o0": tt.literal}}
hash, err := HashLiteralMap(context.TODO(), literalMap)
hash, err := HashLiteralMap(context.TODO(), literalMap, nil)
assert.NoError(t, err)
assert.NotEmpty(t, hash)
})
Expand All @@ -629,26 +629,43 @@ func TestInputValueSorted(t *testing.T) {
literalMap, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2})
assert.NoError(t, err)

hash, err := HashLiteralMap(context.TODO(), literalMap)
hash, err := HashLiteralMap(context.TODO(), literalMap, nil)
assert.NoError(t, err)
assert.Equal(t, "GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", hash)

literalMap, err = coreutils.MakeLiteralMap(map[string]interface{}{"2": 2, "1": 1})
assert.NoError(t, err)

hashDupe, err := HashLiteralMap(context.TODO(), literalMap)
hashDupe, err := HashLiteralMap(context.TODO(), literalMap, nil)
assert.NoError(t, err)
assert.Equal(t, hashDupe, hash)
}

// Ensure that empty inputs are hashed the same way
func TestNoInputValues(t *testing.T) {
hash, err := HashLiteralMap(context.TODO(), nil)
hash, err := HashLiteralMap(context.TODO(), nil, nil)
assert.NoError(t, err)
assert.Equal(t, "GKw-c0PwFokMUQ6T-TUmEWnZ4_VlQ2Qpgw-vCTT0-OQ", hash)

hashDupe, err := HashLiteralMap(context.TODO(), &core.LiteralMap{Literals: nil})
hashDupe, err := HashLiteralMap(context.TODO(), &core.LiteralMap{Literals: nil}, nil)
assert.NoError(t, err)
assert.Equal(t, "GKw-c0PwFokMUQ6T-TUmEWnZ4_VlQ2Qpgw-vCTT0-OQ", hashDupe)
assert.Equal(t, hashDupe, hash)
}

// Ensure that empty inputs are hashed the same way
func TestCacheIgnoreInputVars(t *testing.T) {
literalMap, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2})
assert.NoError(t, err)

hash, err := HashLiteralMap(context.TODO(), literalMap, nil)
assert.NoError(t, err)
assert.Equal(t, "GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", hash)

literalMap, err = coreutils.MakeLiteralMap(map[string]interface{}{"2": 2, "1": 1, "3": 3})
assert.NoError(t, err)

hashDupe, err := HashLiteralMap(context.TODO(), literalMap, []string{"3"})
assert.NoError(t, err)
assert.Equal(t, hashDupe, hash)
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (m *CatalogClient) Get(ctx context.Context, key catalog.Key) (catalog.Entry
inputs = retInputs
}

tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
tag, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars)
if err != nil {
logger.Errorf(ctx, "DataCatalog failed to generate tag for inputs %+v, err: %+v", inputs, err)
return catalog.Entry{}, err
Expand Down Expand Up @@ -233,7 +233,7 @@ func (m *CatalogClient) CreateArtifact(ctx context.Context, key catalog.Key, dat
logger.Debugf(ctx, "Created artifact: %v, with %v outputs from execution %+v", cachedArtifact.Id, len(artifactDataList), metadata)

// Tag the artifact since it is the cached artifact
tagName, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
tagName, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars)
if err != nil {
logger.Errorf(ctx, "Failed to generate tag for artifact %+v, err: %+v", cachedArtifact.Id, err)
return catalog.Status{}, err
Expand Down Expand Up @@ -273,7 +273,7 @@ func (m *CatalogClient) UpdateArtifact(ctx context.Context, key catalog.Key, dat
artifactDataList = append(artifactDataList, artifactData)
}

tagName, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
tagName, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars)
if err != nil {
logger.Errorf(ctx, "Failed to generate artifact tag name for key %+v, dataset %+v and execution %+v, err: %+v", key, datasetID, metadata, err)
return catalog.Status{}, err
Expand Down Expand Up @@ -378,7 +378,7 @@ func (m *CatalogClient) GetOrExtendReservation(ctx context.Context, key catalog.
inputs = retInputs
}

tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
tag, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -418,7 +418,7 @@ func (m *CatalogClient) ReleaseReservation(ctx context.Context, key catalog.Key,
inputs = retInputs
}

tag, err := GenerateArtifactTagName(ctx, inputs, &key.CacheIgnoreInputVars)
tag, err := GenerateArtifactTagName(ctx, inputs, key.CacheIgnoreInputVars)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/catalog"
"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators"
"github.com/flyteorg/flyte/flytestdlib/pbhash"
"golang.org/x/exp/slices"
)

const cachedTaskTag = "flyte_cached"
Expand Down Expand Up @@ -116,20 +115,8 @@ func generateTaskSignatureHash(ctx context.Context, taskInterface core.TypedInte
}

// Generate a tag by hashing the input values which are not in cacheIgnoreInputVars
func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap, cacheIgnoreInputVars *[]string) (string, error) {
var inputsAfterIgnore *core.LiteralMap
if cacheIgnoreInputVars != nil {
inputsAfterIgnore = &core.LiteralMap{Literals: make(map[string]*core.Literal)}
for name, literal := range inputs.Literals {
if slices.Contains(*cacheIgnoreInputVars, name) {
continue
}
inputsAfterIgnore.Literals[name] = literal
}
} else {
inputsAfterIgnore = inputs
}
hashString, err := catalog.HashLiteralMap(ctx, inputsAfterIgnore)
func GenerateArtifactTagName(ctx context.Context, inputs *core.LiteralMap, cacheIgnoreInputVars []string) (string, error) {
hashString, err := catalog.HashLiteralMap(ctx, inputs, cacheIgnoreInputVars)
if err != nil {
return "", err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestGenerateArtifactTagNameWithIgnore(t *testing.T) {
literalMap, err := coreutils.MakeLiteralMap(map[string]interface{}{"1": 1, "2": 2, "3": 3})
assert.NoError(t, err)
cacheIgnoreInputVars := []string{"3"}
tag, err := GenerateArtifactTagName(context.TODO(), literalMap, &cacheIgnoreInputVars)
tag, err := GenerateArtifactTagName(context.TODO(), literalMap, cacheIgnoreInputVars)
assert.NoError(t, err)
assert.Equal(t, "flyte_cached-GQid5LjHbakcW68DS3P2jp80QLbiF0olFHF2hTh5bg8", tag)
}
Expand Down

0 comments on commit cfa1e00

Please sign in to comment.