From c446f2ebf11c51ca5dad4db802547a4a891d8d99 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 5 Jan 2024 13:49:07 +0800 Subject: [PATCH] update ParseFlyteURL to not return or parse for the :tag anymore since we're no longer supporting that feature (yet), change the parsing logic to handle versions with slashes, update tests to match, update events handler to read the tracking tag and construct artifact IDs out of it, install tracking string into literal metadata that's just project/domain/name@version Signed-off-by: Yee Hing Tong --- flyteartifacts/pkg/db/storage.go | 6 +-- flyteartifacts/pkg/lib/constants.go | 5 +- flyteartifacts/pkg/lib/url_parse.go | 50 +++++++++---------- flyteartifacts/pkg/lib/url_parse_test.go | 36 +++++-------- .../pkg/server/processor/events_handler.go | 20 ++++++-- flyteartifacts/pkg/server/server.go | 3 ++ flyteartifacts/pkg/server/service.go | 21 ++++++-- 7 files changed, 80 insertions(+), 61 deletions(-) diff --git a/flyteartifacts/pkg/db/storage.go b/flyteartifacts/pkg/db/storage.go index a7387312fe..ca352be656 100644 --- a/flyteartifacts/pkg/db/storage.go +++ b/flyteartifacts/pkg/db/storage.go @@ -99,14 +99,12 @@ func (r *RDSStorage) CreateArtifact(ctx context.Context, serviceModel models.Art } func (r *RDSStorage) handleUriGet(ctx context.Context, uri string) (models.Artifact, error) { - artifactID, tag, err := lib.ParseFlyteURL(uri) + artifactID, err := lib.ParseFlyteURL(uri) if err != nil { logger.Errorf(ctx, "Failed to parse uri [%s]: %+v", uri, err) return models.Artifact{}, err } - if tag != "" { - return models.Artifact{}, fmt.Errorf("tag not implemented yet") - } + logger.Debugf(ctx, "Extracted artifact id [%v] from uri [%s], using id handler", artifactID, uri) return r.handleArtifactIdGet(ctx, artifactID) } diff --git a/flyteartifacts/pkg/lib/constants.go b/flyteartifacts/pkg/lib/constants.go index 845b570525..767d5146db 100644 --- a/flyteartifacts/pkg/lib/constants.go +++ b/flyteartifacts/pkg/lib/constants.go @@ -1,4 
+1,7 @@ package lib -// ArtifactKey - This is used to tag Literals as a tracking bit. +// ArtifactKey - This string is used to identify Artifacts when all you have +// is the underlying Literal. Look for this key under the literal's metadata field. This situation can arise +// when a user fetches an artifact, using something like flyte remote or flyte console, and then kicks +// off an execution using that literal. const ArtifactKey = "_ua" diff --git a/flyteartifacts/pkg/lib/url_parse.go b/flyteartifacts/pkg/lib/url_parse.go index 489e416d64..cf4f1023ca 100644 --- a/flyteartifacts/pkg/lib/url_parse.go +++ b/flyteartifacts/pkg/lib/url_parse.go @@ -2,50 +2,48 @@ package lib import ( "errors" - "net/url" - "regexp" - "strings" - + "fmt" "github.com/flyteorg/flyte/flyteartifacts/pkg/models" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "net/url" + "regexp" ) -var flyteURLNameRe = regexp.MustCompile(`(?P<name>[\w/-]+)(?:(:(?P<tag>\w+))?)(?:(@(?P<version>\w+))?)`) +var flyteURLNameRe = regexp.MustCompile(`(?P<project>[\w-]+)/(?P<domain>[\w-]+)/(?P<name>[\w/-]+)(@(?P<version>[\w/-]+))?`) -func ParseFlyteURL(urlStr string) (core.ArtifactID, string, error) { +func ParseFlyteURL(urlStr string) (core.ArtifactID, error) { if len(urlStr) == 0 { - return core.ArtifactID{}, "", errors.New("URL cannot be empty") + return core.ArtifactID{}, errors.New("URL cannot be empty") } parsed, err := url.Parse(urlStr) if err != nil { - return core.ArtifactID{}, "", err + return core.ArtifactID{}, err } queryValues, err := url.ParseQuery(parsed.RawQuery) if err != nil { - return core.ArtifactID{}, "", err - } - projectDomainName := strings.Split(strings.Trim(parsed.Path, "/"), "/") - if len(projectDomainName) < 3 { - return core.ArtifactID{}, "", errors.New("invalid URL format") + return core.ArtifactID{}, err } - project, domain, name := projectDomainName[0], projectDomainName[1], strings.Join(projectDomainName[2:], "/") - version := "" - tag := "" + //projectDomainName := strings.Split(strings.Trim(parsed.Path, "/"), "/") +
//if len(projectDomainName) < 3 { + // return core.ArtifactID{}, errors.New("invalid URL format") + //} + //project, domain, name := projectDomainName[0], projectDomainName[1], strings.Join(projectDomainName[2:], "/") + var project, domain, name, version string queryDict := make(map[string]string) - if match := flyteURLNameRe.FindStringSubmatch(name); match != nil { - name = match[1] - if match[3] != "" { - tag = match[3] + if match := flyteURLNameRe.FindStringSubmatch(parsed.Path); match != nil { + if len(match) < 4 { + return core.ArtifactID{}, fmt.Errorf("insufficient components specified %s", parsed.Path) } - if match[5] != "" { + project = match[1] + domain = match[2] + name = match[3] + if len(match) > 5 { version = match[5] } - - if tag != "" && (version != "" || len(queryValues) > 0) { - return core.ArtifactID{}, "", errors.New("cannot specify tag with version or querydict") - } + } else { + return core.ArtifactID{}, fmt.Errorf("unable to parse %s", parsed.Path) } for key, values := range queryValues { @@ -68,5 +66,5 @@ func ParseFlyteURL(urlStr string) (core.ArtifactID, string, error) { Partitions: p, } - return a, tag, nil + return a, nil } diff --git a/flyteartifacts/pkg/lib/url_parse_test.go b/flyteartifacts/pkg/lib/url_parse_test.go index 2340d07252..d106c1cb8b 100644 --- a/flyteartifacts/pkg/lib/url_parse_test.go +++ b/flyteartifacts/pkg/lib/url_parse_test.go @@ -9,52 +9,42 @@ import ( "github.com/flyteorg/flyte/flyteartifacts/pkg/models" ) -func TestURLParseWithTag(t *testing.T) { - artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag") +func TestURLParseWithVersionAndPartitions(t *testing.T) { + artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name@version?foo=bar&ham=spam") + expPartitions := map[string]string{"foo": "bar", "ham": "spam"} assert.NoError(t, err) assert.Equal(t, "project", artifactID.ArtifactKey.Project) assert.Equal(t, "domain", artifactID.ArtifactKey.Domain) assert.Equal(t, "name", 
artifactID.ArtifactKey.Name) - assert.Equal(t, "", artifactID.Version) - assert.Equal(t, "tag", tag) - assert.Nil(t, artifactID.GetPartitions()) + assert.Equal(t, "version", artifactID.Version) + p := artifactID.GetPartitions() + mapP := models.PartitionsFromIdl(context.TODO(), p) + assert.Equal(t, expPartitions, mapP) } -func TestURLParseWithVersionAndPartitions(t *testing.T) { - artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name@version?foo=bar&ham=spam") +func TestURLParseWithSlashVersionAndPartitions(t *testing.T) { + artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name/more@version/abc/0/o0?foo=bar&ham=spam") expPartitions := map[string]string{"foo": "bar", "ham": "spam"} assert.NoError(t, err) assert.Equal(t, "project", artifactID.ArtifactKey.Project) assert.Equal(t, "domain", artifactID.ArtifactKey.Domain) - assert.Equal(t, "name", artifactID.ArtifactKey.Name) - assert.Equal(t, "version", artifactID.Version) - assert.Equal(t, "", tag) + assert.Equal(t, "name/more", artifactID.ArtifactKey.Name) + assert.Equal(t, "version/abc/0/o0", artifactID.Version) p := artifactID.GetPartitions() mapP := models.PartitionsFromIdl(context.TODO(), p) assert.Equal(t, expPartitions, mapP) } -func TestURLParseFailsWithBothTagAndPartitions(t *testing.T) { - _, _, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag?foo=bar&ham=spam") - assert.Error(t, err) -} - -func TestURLParseWithBothTagAndVersion(t *testing.T) { - _, _, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag@version") - assert.Error(t, err) -} - func TestURLParseNameWithSlashes(t *testing.T) { - artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes") + artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes") assert.NoError(t, err) assert.Equal(t, "project", artifactID.ArtifactKey.Project) assert.Equal(t, "domain", artifactID.ArtifactKey.Domain) assert.Equal(t, "name/with/slashes", 
artifactID.ArtifactKey.Name) - assert.Equal(t, "", tag) - artifactID, _, err = ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes?ds=2020-01-01") + artifactID, err = ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes?ds=2020-01-01") assert.NoError(t, err) assert.Equal(t, "name/with/slashes", artifactID.ArtifactKey.Name) assert.Equal(t, "project", artifactID.ArtifactKey.Project) diff --git a/flyteartifacts/pkg/server/processor/events_handler.go b/flyteartifacts/pkg/server/processor/events_handler.go index 71f7813268..929844194b 100644 --- a/flyteartifacts/pkg/server/processor/events_handler.go +++ b/flyteartifacts/pkg/server/processor/events_handler.go @@ -46,11 +46,25 @@ func (s *ServiceCallHandler) HandleEvent(ctx context.Context, cloudEvent *event2 func (s *ServiceCallHandler) HandleEventExecStart(ctx context.Context, evt *event.CloudEventExecutionStart) error { - if len(evt.ArtifactIds) > 0 { + var inputsUsed []*core.ArtifactID + + inputsUsed = append(inputsUsed, evt.ArtifactIds...) 
+ for _, x := range evt.ArtifactTrackers { + + dummyURI := fmt.Sprintf("flyte://av0.1/%s", x) + idWithVersion, err := lib.ParseFlyteURL(dummyURI) + if err != nil { + logger.Errorf(ctx, "Error parsing input %s for execution start: %v", x, err) + return err + } + inputsUsed = append(inputsUsed, &idWithVersion) + } + + if len(inputsUsed) > 0 { // metric req := &artifact.ExecutionInputsRequest{ ExecutionId: evt.ExecutionId, - Inputs: evt.ArtifactIds, + Inputs: inputsUsed, } _, err := s.service.SetExecutionInputs(ctx, req) if err != nil { @@ -172,8 +186,6 @@ func getPartitionsAndTag(ctx context.Context, partialID core.ArtifactID, variabl } var partitions map[string]string - // todo: consider updating idl to make CreateArtifactRequest just take a full Partitions - // object rather than a mapstrstr @eapolinario @enghabu if partialID.GetPartitions().GetValue() != nil && len(partialID.GetPartitions().GetValue()) > 0 { partitions = make(map[string]string, len(partialID.GetPartitions().GetValue())) for k, lv := range partialID.GetPartitions().GetValue() { diff --git a/flyteartifacts/pkg/server/server.go b/flyteartifacts/pkg/server/server.go index 89e9448205..253b0fb4e2 100644 --- a/flyteartifacts/pkg/server/server.go +++ b/flyteartifacts/pkg/server/server.go @@ -31,6 +31,9 @@ type ArtifactService struct { } func (a *ArtifactService) CreateArtifact(ctx context.Context, req *artifact.CreateArtifactRequest) (*artifact.CreateArtifactResponse, error) { + + // todo: add a request validating section, check for nils, etc. 
+ resp, err := a.Service.CreateArtifact(ctx, req) if err != nil { return resp, err diff --git a/flyteartifacts/pkg/server/service.go b/flyteartifacts/pkg/server/service.go index abc1b81b6a..570902fda7 100644 --- a/flyteartifacts/pkg/server/service.go +++ b/flyteartifacts/pkg/server/service.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "github.com/flyteorg/flyte/flyteartifacts/pkg/lib" "github.com/flyteorg/flyte/flyteartifacts/pkg/models" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/artifact" @@ -14,16 +15,30 @@ import ( type CoreService struct { Storage StorageInterface BlobStore BlobStoreInterface - // SearchHandler SearchHandlerInterface +} + +// This string is a tracker basically that will be installed in the metadata of the literal. See the ArtifactKey constant for more information. +func (c *CoreService) getTrackingString(request artifact.CreateArtifactRequest) string { + ak := request.ArtifactKey + t := fmt.Sprintf("%s/%s/%s@%s", ak.Project, ak.Domain, ak.Name, request.Version) + + return t } func (c *CoreService) CreateArtifact(ctx context.Context, request *artifact.CreateArtifactRequest) (*artifact.CreateArtifactResponse, error) { - // todo: gatepr _ua tracking bit to be installed - if request == nil { + // todo: move one layer higher to server.go + if request == nil || request.GetArtifactKey() == nil { + logger.Errorf(ctx, "Ignoring nil or partially nil request") return nil, nil } + if request.GetSpec().GetValue().Metadata == nil { + request.GetSpec().GetValue().Metadata = make(map[string]string, 1) + } + trackingStr := c.getTrackingString(*request) + request.GetSpec().GetValue().Metadata[lib.ArtifactKey] = trackingStr + artifactObj, err := models.CreateArtifactModelFromRequest(ctx, request.ArtifactKey, request.Spec, request.Version, request.Partitions, request.Tag, request.Source) if err != nil { logger.Errorf(ctx, "Failed to validate Create request: %v", err)