diff --git a/flyteartifacts/pkg/db/storage.go b/flyteartifacts/pkg/db/storage.go index a7387312fe..ca352be656 100644 --- a/flyteartifacts/pkg/db/storage.go +++ b/flyteartifacts/pkg/db/storage.go @@ -99,14 +99,12 @@ func (r *RDSStorage) CreateArtifact(ctx context.Context, serviceModel models.Art } func (r *RDSStorage) handleUriGet(ctx context.Context, uri string) (models.Artifact, error) { - artifactID, tag, err := lib.ParseFlyteURL(uri) + artifactID, err := lib.ParseFlyteURL(uri) if err != nil { logger.Errorf(ctx, "Failed to parse uri [%s]: %+v", uri, err) return models.Artifact{}, err } - if tag != "" { - return models.Artifact{}, fmt.Errorf("tag not implemented yet") - } + logger.Debugf(ctx, "Extracted artifact id [%v] from uri [%s], using id handler", artifactID, uri) return r.handleArtifactIdGet(ctx, artifactID) } diff --git a/flyteartifacts/pkg/lib/constants.go b/flyteartifacts/pkg/lib/constants.go index 845b570525..767d5146db 100644 --- a/flyteartifacts/pkg/lib/constants.go +++ b/flyteartifacts/pkg/lib/constants.go @@ -1,4 +1,7 @@ package lib -// ArtifactKey - This is used to tag Literals as a tracking bit. +// ArtifactKey - This string is used to identify Artifacts when all you have +// is the underlying Literal. Look for this key under the literal's metadata field. This situation can arise +// when a user fetches an artifact, using something like flyte remote or flyte console, and then kicks +// off an execution using that literal. 
const ArtifactKey = "_ua" diff --git a/flyteartifacts/pkg/lib/url_parse.go b/flyteartifacts/pkg/lib/url_parse.go index 489e416d64..cf4f1023ca 100644 --- a/flyteartifacts/pkg/lib/url_parse.go +++ b/flyteartifacts/pkg/lib/url_parse.go @@ -2,50 +2,48 @@ package lib import ( "errors" - "net/url" - "regexp" - "strings" - + "fmt" "github.com/flyteorg/flyte/flyteartifacts/pkg/models" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "net/url" + "regexp" ) -var flyteURLNameRe = regexp.MustCompile(`(?P<name>[\w/-]+)(?:(:(?P<tag>\w+))?)(?:(@(?P<version>\w+))?)`) +var flyteURLNameRe = regexp.MustCompile(`(?P<project>[\w-]+)/(?P<domain>[\w-]+)/(?P<name>[\w/-]+)(@(?P<version>[\w/-]+))?`) -func ParseFlyteURL(urlStr string) (core.ArtifactID, string, error) { +func ParseFlyteURL(urlStr string) (core.ArtifactID, error) { if len(urlStr) == 0 { - return core.ArtifactID{}, "", errors.New("URL cannot be empty") + return core.ArtifactID{}, errors.New("URL cannot be empty") } parsed, err := url.Parse(urlStr) if err != nil { - return core.ArtifactID{}, "", err + return core.ArtifactID{}, err } queryValues, err := url.ParseQuery(parsed.RawQuery) if err != nil { - return core.ArtifactID{}, "", err - } - projectDomainName := strings.Split(strings.Trim(parsed.Path, "/"), "/") - if len(projectDomainName) < 3 { - return core.ArtifactID{}, "", errors.New("invalid URL format") + return core.ArtifactID{}, err } - project, domain, name := projectDomainName[0], projectDomainName[1], strings.Join(projectDomainName[2:], "/") - version := "" - tag := "" + //projectDomainName := strings.Split(strings.Trim(parsed.Path, "/"), "/") + //if len(projectDomainName) < 3 { + // return core.ArtifactID{}, errors.New("invalid URL format") + //} + //project, domain, name := projectDomainName[0], projectDomainName[1], strings.Join(projectDomainName[2:], "/") + var project, domain, name, version string queryDict := make(map[string]string) - if match := flyteURLNameRe.FindStringSubmatch(name); match != nil { - name = match[1] - if match[3] != "" { - tag = 
match[3] + if match := flyteURLNameRe.FindStringSubmatch(parsed.Path); match != nil { + if len(match) < 4 { + return core.ArtifactID{}, fmt.Errorf("insufficient components specified %s", parsed.Path) } - if match[5] != "" { + project = match[1] + domain = match[2] + name = match[3] + if len(match) > 5 { version = match[5] } - - if tag != "" && (version != "" || len(queryValues) > 0) { - return core.ArtifactID{}, "", errors.New("cannot specify tag with version or querydict") - } + } else { + return core.ArtifactID{}, fmt.Errorf("unable to parse %s", parsed.Path) } for key, values := range queryValues { @@ -68,5 +66,5 @@ func ParseFlyteURL(urlStr string) (core.ArtifactID, string, error) { Partitions: p, } - return a, tag, nil + return a, nil } diff --git a/flyteartifacts/pkg/lib/url_parse_test.go b/flyteartifacts/pkg/lib/url_parse_test.go index 2340d07252..d106c1cb8b 100644 --- a/flyteartifacts/pkg/lib/url_parse_test.go +++ b/flyteartifacts/pkg/lib/url_parse_test.go @@ -9,52 +9,42 @@ import ( "github.com/flyteorg/flyte/flyteartifacts/pkg/models" ) -func TestURLParseWithTag(t *testing.T) { - artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag") +func TestURLParseWithVersionAndPartitions(t *testing.T) { + artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name@version?foo=bar&ham=spam") + expPartitions := map[string]string{"foo": "bar", "ham": "spam"} assert.NoError(t, err) assert.Equal(t, "project", artifactID.ArtifactKey.Project) assert.Equal(t, "domain", artifactID.ArtifactKey.Domain) assert.Equal(t, "name", artifactID.ArtifactKey.Name) - assert.Equal(t, "", artifactID.Version) - assert.Equal(t, "tag", tag) - assert.Nil(t, artifactID.GetPartitions()) + assert.Equal(t, "version", artifactID.Version) + p := artifactID.GetPartitions() + mapP := models.PartitionsFromIdl(context.TODO(), p) + assert.Equal(t, expPartitions, mapP) } -func TestURLParseWithVersionAndPartitions(t *testing.T) { - artifactID, tag, err := 
ParseFlyteURL("flyte://av0.1/project/domain/name@version?foo=bar&ham=spam") +func TestURLParseWithSlashVersionAndPartitions(t *testing.T) { + artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name/more@version/abc/0/o0?foo=bar&ham=spam") expPartitions := map[string]string{"foo": "bar", "ham": "spam"} assert.NoError(t, err) assert.Equal(t, "project", artifactID.ArtifactKey.Project) assert.Equal(t, "domain", artifactID.ArtifactKey.Domain) - assert.Equal(t, "name", artifactID.ArtifactKey.Name) - assert.Equal(t, "version", artifactID.Version) - assert.Equal(t, "", tag) + assert.Equal(t, "name/more", artifactID.ArtifactKey.Name) + assert.Equal(t, "version/abc/0/o0", artifactID.Version) p := artifactID.GetPartitions() mapP := models.PartitionsFromIdl(context.TODO(), p) assert.Equal(t, expPartitions, mapP) } -func TestURLParseFailsWithBothTagAndPartitions(t *testing.T) { - _, _, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag?foo=bar&ham=spam") - assert.Error(t, err) -} - -func TestURLParseWithBothTagAndVersion(t *testing.T) { - _, _, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag@version") - assert.Error(t, err) -} - func TestURLParseNameWithSlashes(t *testing.T) { - artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes") + artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes") assert.NoError(t, err) assert.Equal(t, "project", artifactID.ArtifactKey.Project) assert.Equal(t, "domain", artifactID.ArtifactKey.Domain) assert.Equal(t, "name/with/slashes", artifactID.ArtifactKey.Name) - assert.Equal(t, "", tag) - artifactID, _, err = ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes?ds=2020-01-01") + artifactID, err = ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes?ds=2020-01-01") assert.NoError(t, err) assert.Equal(t, "name/with/slashes", artifactID.ArtifactKey.Name) assert.Equal(t, "project", artifactID.ArtifactKey.Project) diff --git 
a/flyteartifacts/pkg/server/processor/events_handler.go b/flyteartifacts/pkg/server/processor/events_handler.go index 71f7813268..929844194b 100644 --- a/flyteartifacts/pkg/server/processor/events_handler.go +++ b/flyteartifacts/pkg/server/processor/events_handler.go @@ -46,11 +46,25 @@ func (s *ServiceCallHandler) HandleEvent(ctx context.Context, cloudEvent *event2 func (s *ServiceCallHandler) HandleEventExecStart(ctx context.Context, evt *event.CloudEventExecutionStart) error { - if len(evt.ArtifactIds) > 0 { + var inputsUsed []*core.ArtifactID + + inputsUsed = append(inputsUsed, evt.ArtifactIds...) + for _, x := range evt.ArtifactTrackers { + + dummyURI := fmt.Sprintf("flyte://av0.1/%s", x) + idWithVersion, err := lib.ParseFlyteURL(dummyURI) + if err != nil { + logger.Errorf(ctx, "Error parsing input %s for execution start: %v", x, err) + return err + } + inputsUsed = append(inputsUsed, &idWithVersion) + } + + if len(inputsUsed) > 0 { // metric req := &artifact.ExecutionInputsRequest{ ExecutionId: evt.ExecutionId, - Inputs: evt.ArtifactIds, + Inputs: inputsUsed, } _, err := s.service.SetExecutionInputs(ctx, req) if err != nil { @@ -172,8 +186,6 @@ func getPartitionsAndTag(ctx context.Context, partialID core.ArtifactID, variabl } var partitions map[string]string - // todo: consider updating idl to make CreateArtifactRequest just take a full Partitions - // object rather than a mapstrstr @eapolinario @enghabu if partialID.GetPartitions().GetValue() != nil && len(partialID.GetPartitions().GetValue()) > 0 { partitions = make(map[string]string, len(partialID.GetPartitions().GetValue())) for k, lv := range partialID.GetPartitions().GetValue() { diff --git a/flyteartifacts/pkg/server/server.go b/flyteartifacts/pkg/server/server.go index 89e9448205..253b0fb4e2 100644 --- a/flyteartifacts/pkg/server/server.go +++ b/flyteartifacts/pkg/server/server.go @@ -31,6 +31,9 @@ type ArtifactService struct { } func (a *ArtifactService) CreateArtifact(ctx context.Context, req 
*artifact.CreateArtifactRequest) (*artifact.CreateArtifactResponse, error) { + + // todo: add a request validating section, check for nils, etc. + resp, err := a.Service.CreateArtifact(ctx, req) if err != nil { return resp, err diff --git a/flyteartifacts/pkg/server/service.go b/flyteartifacts/pkg/server/service.go index abc1b81b6a..570902fda7 100644 --- a/flyteartifacts/pkg/server/service.go +++ b/flyteartifacts/pkg/server/service.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "github.com/flyteorg/flyte/flyteartifacts/pkg/lib" "github.com/flyteorg/flyte/flyteartifacts/pkg/models" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/artifact" @@ -14,16 +15,30 @@ import ( type CoreService struct { Storage StorageInterface BlobStore BlobStoreInterface - // SearchHandler SearchHandlerInterface +} + +// getTrackingString returns the "project/domain/name@version" tracking string that CreateArtifact stores in the literal's metadata under lib.ArtifactKey. See the ArtifactKey constant for more information. +func (c *CoreService) getTrackingString(request artifact.CreateArtifactRequest) string { + ak := request.ArtifactKey + t := fmt.Sprintf("%s/%s/%s@%s", ak.Project, ak.Domain, ak.Name, request.Version) + + return t } func (c *CoreService) CreateArtifact(ctx context.Context, request *artifact.CreateArtifactRequest) (*artifact.CreateArtifactResponse, error) { - // todo: gatepr _ua tracking bit to be installed - if request == nil { + // todo: move one layer higher to server.go + if request == nil || request.GetArtifactKey() == nil { + logger.Errorf(ctx, "Ignoring nil or partially nil request") return nil, nil } + if request.GetSpec().GetValue().Metadata == nil { + request.GetSpec().GetValue().Metadata = make(map[string]string, 1) + } + trackingStr := c.getTrackingString(*request) + request.GetSpec().GetValue().Metadata[lib.ArtifactKey] = trackingStr + artifactObj, err := models.CreateArtifactModelFromRequest(ctx, request.ArtifactKey, request.Spec, request.Version, request.Partitions, request.Tag, request.Source) if err 
!= nil { logger.Errorf(ctx, "Failed to validate Create request: %v", err)