Skip to content

Commit

Permalink
Added literal offloading for array node map tasks
Browse files Browse the repository at this point in the history
Signed-off-by: pmahindrakar-oss <[email protected]>
  • Loading branch information
pmahindrakar-oss committed Sep 10, 2024
1 parent 89efcc6 commit 3ee2f0f
Show file tree
Hide file tree
Showing 20 changed files with 575 additions and 89 deletions.
22 changes: 22 additions & 0 deletions flyteadmin/pkg/manager/impl/testutils/mock_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,28 @@ func GetExecutionRequest() *admin.ExecutionCreateRequest {
}
}

func GetExecutionRequestWithOffloadedInputs(inputParam string, literalValue *core.Literal) *admin.ExecutionCreateRequest {
execReq := GetExecutionRequest()
execReq.Inputs = &core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": {
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "s3://bucket/key",
SizeBytes: 100,
InferredType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_STRING,
},
},
},
},
},
},
}
return execReq
}

func GetSampleWorkflowSpecForTest() *admin.WorkflowSpec {
return &admin.WorkflowSpec{
Template: &core.WorkflowTemplate{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,13 @@ func CheckAndFetchInputsForExecution(
}
executionInputMap[name] = expectedInput.GetDefault()
} else {
inputType := validators.LiteralTypeForLiteral(executionInputMap[name])
var inputType *core.LiteralType
switch executionInputMap[name].GetValue().(type) {
case *core.Literal_OffloadedMetadata:
inputType = executionInputMap[name].GetOffloadedMetadata().GetInferredType()
default:
inputType = validators.LiteralTypeForLiteral(executionInputMap[name])
}
if !validators.AreTypesCastable(inputType, expectedInput.GetVar().GetType()) {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got %s", name, expectedInput.GetVar().GetType(), inputType)
}
Expand Down
34 changes: 34 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/execution_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,40 @@ func TestGetExecutionInputs(t *testing.T) {
assert.EqualValues(t, expectedMap, actualInputs)
}

func TestGetExecutionWithOffloadedInputs(t *testing.T) {
execLiteral := &core.Literal{
Value: &core.Literal_OffloadedMetadata{
OffloadedMetadata: &core.LiteralOffloadedMetadata{
Uri: "s3://bucket/key",
SizeBytes: 100,
InferredType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_STRING,
},
},
},
},
}
executionRequest := testutils.GetExecutionRequestWithOffloadedInputs("foo", execLiteral)
lpRequest := testutils.GetLaunchPlanRequest()

actualInputs, err := CheckAndFetchInputsForExecution(
executionRequest.Inputs,
lpRequest.Spec.FixedInputs,
lpRequest.Spec.DefaultInputs,
)
expectedMap := core.LiteralMap{
Literals: map[string]*core.Literal{
"foo": execLiteral,
"bar": coreutils.MustMakeLiteral("bar-value"),
},
}
assert.Nil(t, err)
assert.NotNil(t, actualInputs)
assert.EqualValues(t, expectedMap.GetLiterals()["foo"], actualInputs.Literals["foo"])
assert.EqualValues(t, expectedMap.GetLiterals()["bar"], actualInputs.Literals["bar"])
}

func TestValidateExecInputsWrongType(t *testing.T) {
executionRequest := testutils.GetExecutionRequest()
lpRequest := testutils.GetLaunchPlanRequest()
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/manager/impl/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func validateLiteralMap(inputMap *core.LiteralMap, fieldName string) error {
if name == "" {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing key in %s", fieldName)
}
if fixedInput == nil || fixedInput.GetValue() == nil {
if fixedInput.GetValue() == nil && fixedInput.GetOffloadedMetadata() == nil {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing valid literal in %s %s", fieldName, name)
}
if isDateTime(fixedInput) {
Expand Down
1 change: 1 addition & 0 deletions flytepropeller/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.22

require (
github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295
github.com/Masterminds/semver v1.5.0
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1
github.com/fatih/color v1.13.0
github.com/flyteorg/flyte/flyteidl v0.0.0-00010101000000-000000000000
Expand Down
2 changes: 2 additions & 0 deletions flytepropeller/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 h1:xJ0dAkuxJXf
github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295/go.mod h1:e0aH495YLkrsIe9fhedd6aSR6fgU/qhKvtroi6y7G/M=
github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0=
github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625/go.mod h1:6PnrZv6zUDkrNMw0mIoGRmGBR7i9LulhKPmxFq4rUiM=
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/aws/aws-sdk-go v1.44.2 h1:5VBk5r06bgxgRKVaUtm1/4NT/rtrnH2E4cnAYv5zgQc=
Expand Down
4 changes: 4 additions & 0 deletions flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,10 @@ func GetOutputsFile(outputDir DataReference) DataReference {
return outputDir + "/outputs.pb"
}

func GetOutputsLiteralMetadataFile(literalKey string, outputDir DataReference) DataReference {
return outputDir + DataReference(fmt.Sprintf("/%s_offloaded_metadata.pb", literalKey))
}

func GetInputsFile(inputDir DataReference) DataReference {
return inputDir + "/inputs.pb"
}
Expand Down
8 changes: 7 additions & 1 deletion flytepropeller/pkg/compiler/transformers/k8s/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor
continue
}

inputType := validators.LiteralTypeForLiteral(inputVal)
var inputType *core.LiteralType
switch inputVal.GetValue().(type) {
case *core.Literal_OffloadedMetadata:
inputType = inputVal.GetOffloadedMetadata().GetInferredType()
default:
inputType = validators.LiteralTypeForLiteral(inputVal)
}
if !validators.AreTypesCastable(inputType, v.Type) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String()))
continue
Expand Down
119 changes: 85 additions & 34 deletions flytepropeller/pkg/controller/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@
package config

import (
"context"
"fmt"
"time"

"github.com/Masterminds/semver"
"k8s.io/apimachinery/pkg/types"

"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/contextutils"
"github.com/flyteorg/flyte/flytestdlib/logger"
)

//go:generate pflags Config --default-var=defaultConfig
Expand Down Expand Up @@ -120,47 +124,94 @@ var (
EventVersion: 0,
DefaultParallelismBehavior: ParallelismBehaviorUnlimited,
},
LiteralOffloadingConfig: LiteralOffloadingConfig{
Enabled: true, // Default keep this enabled and in case of issues allow users to disable using config map.
SupportedSDKVersions: map[string]string{ // The key is the SDK name (matches the supported SDK in core.RuntimeMetadata_RuntimeType) and the value is the minimum supported version
"FLYTE_SDK": "1.13.5", // Expected release number with flytekit support from this PR https://github.com/flyteorg/flytekit/pull/2685
},
MinSizeInMBForOffloading: 10, // 10 MB is the default size for offloading
MaxSizeInMBForOffloading: 1000, // 1 GB is the default size before failing fast.
},
}
)

// Config that uses the flytestdlib Config module to generate commandline and load config files. This configuration is
// the base configuration to start propeller
// NOTE: when adding new fields, do not mark them as "omitempty" if it's desirable to read the value from env variables.
type Config struct {
KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."`
MasterURL string `json:"master"`
Workers int `json:"workers" pflag:",Number of threads to process workflows"`
WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"`
DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"`
LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"`
ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"`
MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."`
DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."`
Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."`
MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."`
MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."`
EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"`
MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"`
MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"`
GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"`
LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."`
PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."`
MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"`
EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. Note Histograms metrics can be expensive on Prometheus servers."`
KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"`
NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"`
MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."`
EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."`
IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"`
ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"`
IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"`
ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"`
IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"`
ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"`
ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"`
CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"`
NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"`
ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"`
KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."`
MasterURL string `json:"master"`
Workers int `json:"workers" pflag:",Number of threads to process workflows"`
WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"`
DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"`
LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"`
ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"`
MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."`
DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."`
Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."`
MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."`
MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."`
EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"`
MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"`
MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"`
GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"`
LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."`
PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."`
MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"`
EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. Note Histograms metrics can be expensive on Prometheus servers."`
KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"`
NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"`
MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."`
EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."`
IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"`
ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"`
IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"`
ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"`
IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"`
ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"`
ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"`
CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"`
NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"`
ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"`
LiteralOffloadingConfig LiteralOffloadingConfig `json:"literalOffloadingConfig" pflag:",config used for literal offloading."`
}

type LiteralOffloadingConfig struct {
Enabled bool
// Maps flytekit and union SDK names to minimum supported version that can handle reading offloaded literals.
SupportedSDKVersions map[string]string
// Default, 10Mbs. Determines the size of a literal at which to trigger offloading
MinSizeInMBForOffloading int64
// Fail fast threshold
MaxSizeInMBForOffloading int64
}

// IsSupportedSDKVersion returns true if the provided SDK and version are supported by the literal offloading config.
func (l LiteralOffloadingConfig) IsSupportedSDKVersion(sdk string, versionString string) bool {
if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok {
c, err := semver.NewConstraint(fmt.Sprintf(">= %s", leastSupportedVersion))
if err != nil {
// This should never happen
logger.Warnf(context.TODO(), "Failed to parse version constraint %s", leastSupportedVersion)
return false
}
version, err := semver.NewVersion(versionString)
if err != nil {
// This should never happen
logger.Warnf(context.TODO(), "Failed to parse version %s", versionString)
return false
}
return c.Check(version)
}
return false
}

// GetSupportedSDKVersion returns the least supported version for the provided SDK.
func (l LiteralOffloadingConfig) GetSupportedSDKVersion(sdk string) string {
if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok {
return leastSupportedVersion
}
return ""
}

// KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client.
Expand Down
54 changes: 54 additions & 0 deletions flytepropeller/pkg/controller/config/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package config

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestIsSupportedSDKVersion(t *testing.T) {
t.Run("supported version", func(t *testing.T) {
config := LiteralOffloadingConfig{
SupportedSDKVersions: map[string]string{
"flytekit": "0.16.0",
},
}
assert.True(t, config.IsSupportedSDKVersion("flytekit", "0.16.0"))
})

t.Run("unsupported version", func(t *testing.T) {
config := LiteralOffloadingConfig{
SupportedSDKVersions: map[string]string{
"flytekit": "0.16.0",
},
}
assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.15.0"))
})

t.Run("unsupported SDK", func(t *testing.T) {
config := LiteralOffloadingConfig{
SupportedSDKVersions: map[string]string{
"flytekit": "0.16.0",
},
}
assert.False(t, config.IsSupportedSDKVersion("unknown", "0.16.0"))
})

t.Run("invalid version", func(t *testing.T) {
config := LiteralOffloadingConfig{
SupportedSDKVersions: map[string]string{
"flytekit": "0.16.0",
},
}
assert.False(t, config.IsSupportedSDKVersion("flytekit", "invalid"))
})

t.Run("invalid constraint", func(t *testing.T) {
config := LiteralOffloadingConfig{
SupportedSDKVersions: map[string]string{
"flytekit": "invalid",
},
}
assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.16.0"))
})
}
Loading

0 comments on commit 3ee2f0f

Please sign in to comment.