Skip to content

Commit

Permalink
Update ParseFlyteURL to no longer parse or return the :tag, since…
Browse files Browse the repository at this point in the history
…we're no longer supporting that feature (yet); change the parsing logic to handle versions with slashes; update tests to match; update the events handler to read the tracking tag and construct artifact IDs out of it; install a tracking string of the form project/domain/name@version into the literal metadata.

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Jan 5, 2024
1 parent d650b58 commit c446f2e
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 61 deletions.
6 changes: 2 additions & 4 deletions flyteartifacts/pkg/db/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,12 @@ func (r *RDSStorage) CreateArtifact(ctx context.Context, serviceModel models.Art
}

func (r *RDSStorage) handleUriGet(ctx context.Context, uri string) (models.Artifact, error) {
artifactID, tag, err := lib.ParseFlyteURL(uri)
artifactID, err := lib.ParseFlyteURL(uri)
if err != nil {
logger.Errorf(ctx, "Failed to parse uri [%s]: %+v", uri, err)
return models.Artifact{}, err
}
if tag != "" {
return models.Artifact{}, fmt.Errorf("tag not implemented yet")
}

logger.Debugf(ctx, "Extracted artifact id [%v] from uri [%s], using id handler", artifactID, uri)
return r.handleArtifactIdGet(ctx, artifactID)
}
Expand Down
5 changes: 4 additions & 1 deletion flyteartifacts/pkg/lib/constants.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
package lib

// ArtifactKey - This is used to tag Literals as a tracking bit.
// ArtifactKey is the string used to identify Artifacts when all you have
// is the underlying Literal. Look for this key under the literal's metadata field. This situation can arise
// when a user fetches an artifact, using something like flyte remote or flyte console, and then kicks
// off an execution using that literal.
const ArtifactKey = "_ua"
50 changes: 24 additions & 26 deletions flyteartifacts/pkg/lib/url_parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,48 @@ package lib

import (
"errors"
"net/url"
"regexp"
"strings"

"fmt"
"github.com/flyteorg/flyte/flyteartifacts/pkg/models"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"net/url"
"regexp"
)

var flyteURLNameRe = regexp.MustCompile(`(?P<name>[\w/-]+)(?:(:(?P<tag>\w+))?)(?:(@(?P<version>\w+))?)`)
var flyteURLNameRe = regexp.MustCompile(`(?P<project>[\w-]+)/(?P<domain>[\w-]+)/(?P<name>[\w/-]+)(@(?P<version>[\w/-]+))?`)

func ParseFlyteURL(urlStr string) (core.ArtifactID, string, error) {
func ParseFlyteURL(urlStr string) (core.ArtifactID, error) {
if len(urlStr) == 0 {
return core.ArtifactID{}, "", errors.New("URL cannot be empty")
return core.ArtifactID{}, errors.New("URL cannot be empty")
}

parsed, err := url.Parse(urlStr)
if err != nil {
return core.ArtifactID{}, "", err
return core.ArtifactID{}, err
}
queryValues, err := url.ParseQuery(parsed.RawQuery)
if err != nil {
return core.ArtifactID{}, "", err
}
projectDomainName := strings.Split(strings.Trim(parsed.Path, "/"), "/")
if len(projectDomainName) < 3 {
return core.ArtifactID{}, "", errors.New("invalid URL format")
return core.ArtifactID{}, err
}
project, domain, name := projectDomainName[0], projectDomainName[1], strings.Join(projectDomainName[2:], "/")
version := ""
tag := ""
//projectDomainName := strings.Split(strings.Trim(parsed.Path, "/"), "/")
//if len(projectDomainName) < 3 {
// return core.ArtifactID{}, errors.New("invalid URL format")
//}
//project, domain, name := projectDomainName[0], projectDomainName[1], strings.Join(projectDomainName[2:], "/")
var project, domain, name, version string
queryDict := make(map[string]string)

if match := flyteURLNameRe.FindStringSubmatch(name); match != nil {
name = match[1]
if match[3] != "" {
tag = match[3]
if match := flyteURLNameRe.FindStringSubmatch(parsed.Path); match != nil {
if len(match) < 4 {
return core.ArtifactID{}, fmt.Errorf("insufficient components specified %s", parsed.Path)
}
if match[5] != "" {
project = match[1]
domain = match[2]
name = match[3]
if len(match) > 5 {
version = match[5]
}

if tag != "" && (version != "" || len(queryValues) > 0) {
return core.ArtifactID{}, "", errors.New("cannot specify tag with version or querydict")
}
} else {
return core.ArtifactID{}, fmt.Errorf("unable to parse %s", parsed.Path)
}

for key, values := range queryValues {
Expand All @@ -68,5 +66,5 @@ func ParseFlyteURL(urlStr string) (core.ArtifactID, string, error) {
Partitions: p,
}

return a, tag, nil
return a, nil
}
36 changes: 13 additions & 23 deletions flyteartifacts/pkg/lib/url_parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,42 @@ import (
"github.com/flyteorg/flyte/flyteartifacts/pkg/models"
)

func TestURLParseWithTag(t *testing.T) {
artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag")
func TestURLParseWithVersionAndPartitions(t *testing.T) {
artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name@version?foo=bar&ham=spam")
expPartitions := map[string]string{"foo": "bar", "ham": "spam"}
assert.NoError(t, err)

assert.Equal(t, "project", artifactID.ArtifactKey.Project)
assert.Equal(t, "domain", artifactID.ArtifactKey.Domain)
assert.Equal(t, "name", artifactID.ArtifactKey.Name)
assert.Equal(t, "", artifactID.Version)
assert.Equal(t, "tag", tag)
assert.Nil(t, artifactID.GetPartitions())
assert.Equal(t, "version", artifactID.Version)
p := artifactID.GetPartitions()
mapP := models.PartitionsFromIdl(context.TODO(), p)
assert.Equal(t, expPartitions, mapP)
}

func TestURLParseWithVersionAndPartitions(t *testing.T) {
artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name@version?foo=bar&ham=spam")
func TestURLParseWithSlashVersionAndPartitions(t *testing.T) {
artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name/more@version/abc/0/o0?foo=bar&ham=spam")
expPartitions := map[string]string{"foo": "bar", "ham": "spam"}
assert.NoError(t, err)

assert.Equal(t, "project", artifactID.ArtifactKey.Project)
assert.Equal(t, "domain", artifactID.ArtifactKey.Domain)
assert.Equal(t, "name", artifactID.ArtifactKey.Name)
assert.Equal(t, "version", artifactID.Version)
assert.Equal(t, "", tag)
assert.Equal(t, "name/more", artifactID.ArtifactKey.Name)
assert.Equal(t, "version/abc/0/o0", artifactID.Version)
p := artifactID.GetPartitions()
mapP := models.PartitionsFromIdl(context.TODO(), p)
assert.Equal(t, expPartitions, mapP)
}

func TestURLParseFailsWithBothTagAndPartitions(t *testing.T) {
_, _, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag?foo=bar&ham=spam")
assert.Error(t, err)
}

func TestURLParseWithBothTagAndVersion(t *testing.T) {
_, _, err := ParseFlyteURL("flyte://av0.1/project/domain/name:tag@version")
assert.Error(t, err)
}

func TestURLParseNameWithSlashes(t *testing.T) {
artifactID, tag, err := ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes")
artifactID, err := ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes")
assert.NoError(t, err)
assert.Equal(t, "project", artifactID.ArtifactKey.Project)
assert.Equal(t, "domain", artifactID.ArtifactKey.Domain)
assert.Equal(t, "name/with/slashes", artifactID.ArtifactKey.Name)
assert.Equal(t, "", tag)

artifactID, _, err = ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes?ds=2020-01-01")
artifactID, err = ParseFlyteURL("flyte://av0.1/project/domain/name/with/slashes?ds=2020-01-01")
assert.NoError(t, err)
assert.Equal(t, "name/with/slashes", artifactID.ArtifactKey.Name)
assert.Equal(t, "project", artifactID.ArtifactKey.Project)
Expand Down
20 changes: 16 additions & 4 deletions flyteartifacts/pkg/server/processor/events_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,25 @@ func (s *ServiceCallHandler) HandleEvent(ctx context.Context, cloudEvent *event2

func (s *ServiceCallHandler) HandleEventExecStart(ctx context.Context, evt *event.CloudEventExecutionStart) error {

if len(evt.ArtifactIds) > 0 {
var inputsUsed []*core.ArtifactID

inputsUsed = append(inputsUsed, evt.ArtifactIds...)
for _, x := range evt.ArtifactTrackers {

dummyURI := fmt.Sprintf("flyte://av0.1/%s", x)
idWithVersion, err := lib.ParseFlyteURL(dummyURI)
if err != nil {
logger.Errorf(ctx, "Error parsing input %s for execution start: %v", x, err)
return err
}
inputsUsed = append(inputsUsed, &idWithVersion)
}

if len(inputsUsed) > 0 {
// metric
req := &artifact.ExecutionInputsRequest{
ExecutionId: evt.ExecutionId,
Inputs: evt.ArtifactIds,
Inputs: inputsUsed,
}
_, err := s.service.SetExecutionInputs(ctx, req)
if err != nil {
Expand Down Expand Up @@ -172,8 +186,6 @@ func getPartitionsAndTag(ctx context.Context, partialID core.ArtifactID, variabl
}

var partitions map[string]string
// todo: consider updating idl to make CreateArtifactRequest just take a full Partitions
// object rather than a mapstrstr @eapolinario @enghabu
if partialID.GetPartitions().GetValue() != nil && len(partialID.GetPartitions().GetValue()) > 0 {
partitions = make(map[string]string, len(partialID.GetPartitions().GetValue()))
for k, lv := range partialID.GetPartitions().GetValue() {
Expand Down
3 changes: 3 additions & 0 deletions flyteartifacts/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type ArtifactService struct {
}

func (a *ArtifactService) CreateArtifact(ctx context.Context, req *artifact.CreateArtifactRequest) (*artifact.CreateArtifactResponse, error) {

// todo: add a request validating section, check for nils, etc.

resp, err := a.Service.CreateArtifact(ctx, req)
if err != nil {
return resp, err
Expand Down
21 changes: 18 additions & 3 deletions flyteartifacts/pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"github.com/flyteorg/flyte/flyteartifacts/pkg/lib"

"github.com/flyteorg/flyte/flyteartifacts/pkg/models"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/artifact"
Expand All @@ -14,16 +15,30 @@ import (
type CoreService struct {
Storage StorageInterface
BlobStore BlobStoreInterface
// SearchHandler SearchHandlerInterface
}

// getTrackingString builds the tracker that is installed in the metadata of the
// literal, in the form project/domain/name@version. See the lib.ArtifactKey
// constant for more information on where the tracker is stored.
//
// The generated protobuf getters are used instead of direct field access so
// that a request whose ArtifactKey is nil yields empty components rather than
// a nil pointer dereference.
func (c *CoreService) getTrackingString(request artifact.CreateArtifactRequest) string {
	key := request.GetArtifactKey()
	return fmt.Sprintf("%s/%s/%s@%s", key.GetProject(), key.GetDomain(), key.GetName(), request.GetVersion())
}

func (c *CoreService) CreateArtifact(ctx context.Context, request *artifact.CreateArtifactRequest) (*artifact.CreateArtifactResponse, error) {

// todo: gatepr _ua tracking bit to be installed
if request == nil {
// todo: move one layer higher to server.go
if request == nil || request.GetArtifactKey() == nil {
logger.Errorf(ctx, "Ignoring nil or partially nil request")
return nil, nil
}

if request.GetSpec().GetValue().Metadata == nil {
request.GetSpec().GetValue().Metadata = make(map[string]string, 1)
}
trackingStr := c.getTrackingString(*request)
request.GetSpec().GetValue().Metadata[lib.ArtifactKey] = trackingStr

artifactObj, err := models.CreateArtifactModelFromRequest(ctx, request.ArtifactKey, request.Spec, request.Version, request.Partitions, request.Tag, request.Source)
if err != nil {
logger.Errorf(ctx, "Failed to validate Create request: %v", err)
Expand Down

0 comments on commit c446f2e

Please sign in to comment.