From 58dee5bd5838b1d69aac5a14649cad2d9ef4acf9 Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Thu, 29 Oct 2020 20:28:01 -0700 Subject: [PATCH 1/4] add support for hpo on custom training; add tests --- .../k8s/sagemaker/hyperparameter_tuning.go | 178 +++++++++++------- .../sagemaker/hyperparameter_tuning_test.go | 80 +++++--- .../k8s/sagemaker/plugin_test_utils.go | 108 ++++++++--- go/tasks/plugins/k8s/sagemaker/utils.go | 27 ++- go/tasks/plugins/k8s/sagemaker/utils_test.go | 48 +++-- 5 files changed, 309 insertions(+), 132 deletions(-) diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go index 6df79accd..dce4e5352 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go @@ -57,6 +57,8 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TrainingJobResourceConfig] of the HyperparameterTuningJob's underlying TrainingJob does not exist") } + trainingJobType := sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetAlgorithmName() + taskInput, err := taskCtx.InputReader().Get(ctx) if err != nil { return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "unable to fetch task inputs") @@ -64,56 +66,131 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( // Get inputs from literals inputLiterals := taskInput.GetLiterals() - err = checkIfRequiredInputLiteralsExist(inputLiterals, - []string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, StaticHyperparametersPredefinedInputVariable}) + + hpoJobConfigLiteral := inputLiterals["hyperparameter_tuning_job_config"] + // hyperparameter_tuning_job_config is marshaled into a struct in flytekit, so will have to unmarshal it back + hpoJobConfig, err := convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral) if err != nil { - return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Error occurred when checking if all the required inputs exist") + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to convert hyperparameter tuning job config literal to spec type") } + logger.Infof(ctx, "hyperparameter tuning job config = [%v]", hpoJobConfig) - trainPathLiteral := inputLiterals[TrainPredefinedInputVariable] - validatePathLiteral := inputLiterals[ValidationPredefinedInputVariable] - staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable] - hpoJobConfigLiteral := inputLiterals["hyperparameter_tuning_job_config"] - if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil { - return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", TrainPredefinedInputVariable) + // Extracting the tunable hyperparameters from the input literals + hpoJobParameterRanges := buildParameterRanges(ctx, inputLiterals) + + for _, catpr := range hpoJobParameterRanges.CategoricalParameterRanges { + logger.Infof(ctx, "CategoricalParameterRange: [%v]: %v", *catpr.Name, catpr.Values) } - if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil { - return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", ValidationPredefinedInputVariable) + for _, intpr := range 
hpoJobParameterRanges.IntegerParameterRanges {
+		logger.Infof(ctx, "IntegerParameterRange: [%v]: (max:%v, min:%v, scaling:%v)", *intpr.Name, *intpr.MaxValue, *intpr.MinValue, intpr.ScalingType)
 	}
-	// Convert the hyperparameters to the spec value
-	staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral)
-	if err != nil {
-		return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "could not convert static hyperparameters to spec type")
+	for _, conpr := range hpoJobParameterRanges.ContinuousParameterRanges {
+		logger.Infof(ctx, "ContinuousParameterRange [%v]: (max:%v, min:%v, scaling:%v)", *conpr.Name, *conpr.MaxValue, *conpr.MinValue, conpr.ScalingType)
 	}
-	// hyperparameter_tuning_job_config is marshaled into a struct in flytekit, so will have to unmarshal it back
-	hpoJobConfig, err := convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral)
-	if err != nil {
-		return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to convert hyperparameter tuning job config literal to spec type")
+	inputModeString := strings.Title(strings.ToLower(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputMode().String()))
+
+	var staticHyperparams []*commonv1.KeyValuePair
+	var inputChannels []commonv1.Channel
+	var trainingImageStr string
+	if trainingJobType != flyteSageMakerIdl.AlgorithmName_CUSTOM {
+		logger.Infof(ctx, "The hyperparameter tuning job is wrapping around a built-in algorithm training job")
+		requiredInputs := []string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, StaticHyperparametersPredefinedInputVariable}
+		logger.Infof(ctx, "Checking if required inputs exist [%v]", requiredInputs)
+		// train, validation, and static_hyperparameters are the default required inputs for an HPO job that wraps
+		// around a built-in algorithm training job
+		err = checkIfRequiredInputLiteralsExist(inputLiterals, requiredInputs)
+		if err != nil {
+			return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Error occurred when checking if all the required inputs exist")
+		}
+
+		trainPathLiteral := inputLiterals[TrainPredefinedInputVariable]
+		validatePathLiteral := inputLiterals[ValidationPredefinedInputVariable]
+		staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable]
+
+		if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil {
+			return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", TrainPredefinedInputVariable)
+		}
+		if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil {
+			return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", ValidationPredefinedInputVariable)
+		}
+		// Convert the hyperparameters to the spec value (plain assignment so the staticHyperparams declared above is populated, not shadowed)
+		staticHyperparams, err = convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral)
+		if err != nil {
+			return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "could not convert static hyperparameters to spec type")
+		}
+
+		// Deleting the conflicting static hyperparameters: if a hyperparameter exists in both the map of static hyperparameters
+		// and the map of tunable hyperparameters inside the Hyperparameter Tuning Job Config, we delete the entry
+		// in the static map and let the one in the map of the tunable hyperparameters take precedence
+		staticHyperparams = 
deleteConflictingStaticHyperparameters(ctx, staticHyperparams, hpoJobParameterRanges) + logger.Infof(ctx, "Sagemaker HyperparameterTuningJob Task plugin will proceed with the following static hyperparameters:") + for _, shp := range staticHyperparams { + logger.Infof(ctx, "(%v, %v)", shp.Name, shp.Value) + } + + apiContentType, err := getAPIContentType(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType()) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unsupported input file type [%v]", + sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType().String()) + } + + inputChannels = []commonv1.Channel{ + { + ChannelName: ToStringPtr(TrainPredefinedInputVariable), + DataSource: &commonv1.DataSource{ + S3DataSource: &commonv1.S3DataSource{ + S3DataType: "S3Prefix", + S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()), + }, + }, + ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata + InputMode: inputModeString, + }, + { + ChannelName: ToStringPtr(ValidationPredefinedInputVariable), + DataSource: &commonv1.DataSource{ + S3DataSource: &commonv1.S3DataSource{ + S3DataType: "S3Prefix", + S3Uri: ToStringPtr(validatePathLiteral.GetScalar().GetBlob().GetUri()), + }, + }, + ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata + InputMode: inputModeString, + }, + } + + trainingImageStr, err = getTrainingJobImage(ctx, taskCtx, sagemakerHPOJob.GetTrainingJob()) + if err != nil { + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to find the training image") + } + } else { + // For hpo job that wraps around a custom training job, there has to be at least one tunable hyperparameter in + // the input list + if len(inputLiterals) < 1 || + (len(hpoJobParameterRanges.ContinuousParameterRanges) < 1 && len(hpoJobParameterRanges.IntegerParameterRanges) < 1 && len(hpoJobParameterRanges.CategoricalParameterRanges) < 1) { + + return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "There has to be at least one input for a hyperparameter tuning job wrapping around a custom-training job") + } + + if taskTemplate.GetContainer().GetImage() == "" { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Invalid image of the container") + } + inputChannels = nil + trainingImageStr = taskTemplate.GetContainer().GetImage() } outputPath := createOutputPath(taskCtx.OutputWriter().GetRawOutputPrefix().String(), HyperparameterOutputPathSubDir) + if hpoJobConfig.GetTuningObjective() == nil { return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TuningObjective] does not exist") } - // Deleting the conflicting static hyperparameters: if a hyperparameter exist in both the map of static hyperparameter - // and the map of the tunable hyperparameter inside the Hyperparameter Tuning Job Config, we delete the entry - // in the static map and let the one in the map of the tunable hyperparameters take precedence - staticHyperparams = deleteConflictingStaticHyperparameters(ctx, staticHyperparams, hpoJobConfig.GetHyperparameterRanges().GetParameterRangeMap()) - jobName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - trainingImageStr, err := getTrainingJobImage(ctx, taskCtx, sagemakerHPOJob.GetTrainingJob()) - if err != nil { - return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to find the training image") 
- } + if len(hpoJobParameterRanges.CategoricalParameterRanges) == 0 && len(hpoJobParameterRanges.ContinuousParameterRanges) == 0 { - hpoJobParameterRanges := buildParameterRanges(ctx, inputLiterals) - logger.Infof(ctx, "The Sagemaker HyperparameterTuningJob Task plugin received the following inputs: \n"+ - "static hyperparameters: [%v]\n"+ - "hyperparameter tuning job config: [%v]\n"+ - "parameter ranges: [%v]", staticHyperparams, hpoJobConfig, hpoJobParameterRanges) + } cfg := config.GetSagemakerConfig() @@ -124,13 +201,6 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)}) } - apiContentType, err := getAPIContentType(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType()) - if err != nil { - return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unsupported input file type [%v]", - sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType().String()) - } - - inputModeString := strings.Title(strings.ToLower(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputMode().String())) tuningStrategyString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningStrategy().String())) tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String())) trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String())) @@ -157,6 +227,7 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( TrainingJobEarlyStoppingType: commonv1.TrainingJobEarlyStoppingType(trainingJobEarlyStoppingTypeString), }, TrainingJobDefinition: &commonv1.HyperParameterTrainingJobDefinition{ + // If the underlying training job is a custom training job, this will be nil StaticHyperParameters: staticHyperparams, AlgorithmSpecification: &commonv1.HyperParameterAlgorithmSpecification{ TrainingImage: ToStringPtr(trainingImageStr), @@ -164,30 +235,7 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( MetricDefinitions: metricDefinitions, AlgorithmName: nil, }, - InputDataConfig: []commonv1.Channel{ - { - ChannelName: ToStringPtr(TrainPredefinedInputVariable), - DataSource: &commonv1.DataSource{ - S3DataSource: &commonv1.S3DataSource{ - S3DataType: "S3Prefix", - S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()), - }, - }, - ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata - InputMode: inputModeString, - }, - { - ChannelName: ToStringPtr(ValidationPredefinedInputVariable), - DataSource: &commonv1.DataSource{ - S3DataSource: &commonv1.S3DataSource{ - S3DataType: "S3Prefix", - S3Uri: ToStringPtr(validatePathLiteral.GetScalar().GetBlob().GetUri()), - }, - }, - ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata - InputMode: inputModeString, - }, - }, + InputDataConfig: inputChannels, OutputDataConfig: &commonv1.OutputDataConfig{ S3OutputPath: ToStringPtr(outputPath), }, @@ -200,7 +248,7 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( RoleArn: ToStringPtr(role), StoppingCondition: &commonv1.StoppingCondition{ MaxRuntimeInSeconds: ToInt64Ptr(86400), - MaxWaitTimeInSeconds: nil, + MaxWaitTimeInSeconds: nil, // We currently don't have a conclusion how to set a value for this }, }, Region: ToStringPtr(cfg.Region), diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go 
b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go index 45992321d..11d411115 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go @@ -25,30 +25,60 @@ func Test_awsSagemakerPlugin_BuildResourceForHyperparameterTuningJob(t *testing. panic(err) } defaultCfg := config.GetSagemakerConfig() - awsSageMakerHPOJobHandler := awsSagemakerPlugin{TaskType: hyperparameterTuningJobTaskType} - - tjObj := generateMockTrainingJobCustomObj( - sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, - sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED) - htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5) - taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj) - hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, generateMockHyperparameterTuningJobTaskContext(taskTemplate)) - assert.NoError(t, err) - assert.NotNil(t, hpoJobResource) - - hpoJob, ok := hpoJobResource.(*hpojobv1.HyperparameterTuningJob) - assert.True(t, ok) - assert.NotNil(t, hpoJob.Spec.TrainingJobDefinition) - assert.Equal(t, 1, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.IntegerParameterRanges)) - assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.ContinuousParameterRanges)) - assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.CategoricalParameterRanges)) - assert.Equal(t, "us-east-1", *hpoJob.Spec.Region) - assert.Equal(t, "default_role", *hpoJob.Spec.TrainingJobDefinition.RoleArn) - - err = config.SetSagemakerConfig(defaultCfg) - if err != nil { - panic(err) - } + defer func() { + _ = config.SetSagemakerConfig(defaultCfg) + }() + + t.Run("hpo on built-in algorithm training", func(t *testing.T) { + awsSageMakerHPOJobHandler := awsSagemakerPlugin{TaskType: hyperparameterTuningJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_XGBOOST, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED) + htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5) + taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj) + hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, generateMockHyperparameterTuningJobTaskContext(taskTemplate, trainingJobTaskType)) + assert.NoError(t, err) + assert.NotNil(t, hpoJobResource) + + hpoJob, ok := hpoJobResource.(*hpojobv1.HyperparameterTuningJob) + assert.True(t, ok) + assert.NotNil(t, hpoJob.Spec.TrainingJobDefinition) + assert.Equal(t, 1, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.IntegerParameterRanges)) + assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.ContinuousParameterRanges)) + assert.Equal(t, 0, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.CategoricalParameterRanges)) + assert.Equal(t, "us-east-1", *hpoJob.Spec.Region) + assert.Equal(t, "default_role", *hpoJob.Spec.TrainingJobDefinition.RoleArn) + assert.NotNil(t, hpoJob.Spec.TrainingJobDefinition.InputDataConfig) + // Image uri should come from config + assert.Equal(t, defaultCfg.PrebuiltAlgorithms[0].RegionalConfig[0].VersionConfigs[0].Image, *hpoJob.Spec.TrainingJobDefinition.AlgorithmSpecification.TrainingImage) + }) + + t.Run("hpo on 
custom training", func(t *testing.T) { + awsSageMakerHPOJobHandler := awsSagemakerPlugin{TaskType: hyperparameterTuningJobTaskType} + + tjObj := generateMockTrainingJobCustomObj( + sagemakerIdl.InputMode_FILE, sagemakerIdl.AlgorithmName_CUSTOM, "0.90", []*sagemakerIdl.MetricDefinition{}, + sagemakerIdl.InputContentType_TEXT_CSV, 2, "ml.p3.2xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED) + htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5) + taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj) + taskContext := generateMockHyperparameterTuningJobTaskContext(taskTemplate, customTrainingJobTaskType) + hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, taskContext) + assert.NoError(t, err) + assert.NotNil(t, hpoJobResource) + + hpoJob, ok := hpoJobResource.(*hpojobv1.HyperparameterTuningJob) + assert.True(t, ok) + assert.NotNil(t, hpoJob.Spec.TrainingJobDefinition) + assert.Equal(t, 1, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.IntegerParameterRanges)) + assert.Equal(t, 1, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.ContinuousParameterRanges)) + assert.Equal(t, 1, len(hpoJob.Spec.HyperParameterTuningJobConfig.ParameterRanges.CategoricalParameterRanges)) + assert.Equal(t, "us-east-1", *hpoJob.Spec.Region) + assert.Equal(t, "default_role", *hpoJob.Spec.TrainingJobDefinition.RoleArn) + assert.Nil(t, hpoJob.Spec.TrainingJobDefinition.InputDataConfig) + // Image uri should come from taskContext + assert.Equal(t, testImage, *hpoJob.Spec.TrainingJobDefinition.AlgorithmSpecification.TrainingImage) + }) } func Test_awsSagemakerPlugin_getEventInfoForHyperparameterTuningJob(t *testing.T) { @@ -77,7 +107,7 @@ func Test_awsSagemakerPlugin_getEventInfoForHyperparameterTuningJob(t *testing.T sagemakerIdl.InputContentType_TEXT_CSV, 1, "ml.m4.xlarge", 25, sagemakerIdl.DistributedProtocol_UNSPECIFIED) htObj := generateMockHyperparameterTuningJobCustomObj(tjObj, 10, 5) taskTemplate := generateMockHyperparameterTuningJobTaskTemplate("the job", htObj) - taskCtx := generateMockHyperparameterTuningJobTaskContext(taskTemplate) + taskCtx := generateMockHyperparameterTuningJobTaskContext(taskTemplate, trainingJobTaskType) hpoJobResource, err := awsSageMakerHPOJobHandler.BuildResource(ctx, taskCtx) assert.NoError(t, err) assert.NotNil(t, hpoJobResource) diff --git a/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go b/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go index b23047cef..01427a0e6 100644 --- a/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go +++ b/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go @@ -302,11 +302,8 @@ func generateMockBlobLiteral(loc storage.DataReference) *flyteIdlCore.Literal { } } -func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate) pluginsCore.TaskExecutionContext { +func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate, taskType pluginsCore.TaskType) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} - inputReader := &pluginIOMocks.InputReader{} - inputReader.OnGetInputPrefixPath().Return(storage.DataReference("/input/prefix")) - inputReader.OnGetInputPath().Return(storage.DataReference("/input")) trainBlobLoc := storage.DataReference("train-blob-loc") validationBlobLoc := storage.DataReference("validation-blob-loc") @@ -322,31 +319,52 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T } hpoJobConfigByteArray, _ := 
proto.Marshal(&hpoJobConfig) - intParamRange := &structpb.Struct{} - err := utils.MarshalStruct(&sagemakerIdl.ParameterRangeOneOf{ - ParameterRangeType: &sagemakerIdl.ParameterRangeOneOf_IntegerParameterRange{ - IntegerParameterRange: &sagemakerIdl.IntegerParameterRange{ - MaxValue: 2, - MinValue: 1, - ScalingType: sagemakerIdl.HyperparameterScalingType_LINEAR, - }, - }, - }, intParamRange) + intParamRange, err := generateMockIntegerParameterRange(100, 10, sagemakerIdl.HyperparameterScalingType_LINEAR) if err != nil { panic(err) } - inputReader.OnGetMatch(mock.Anything).Return( - &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{ - "train": generateMockBlobLiteral(trainBlobLoc), - "validation": generateMockBlobLiteral(validationBlobLoc), - "static_hyperparameters": utils.MakeGenericLiteral(shpStructObj), - "hyperparameter_tuning_job_config": utils.MakeBinaryLiteral(hpoJobConfigByteArray), - "a": utils.MakeGenericLiteral(intParamRange), - }, - }, nil) + catPR1, err := generateMockCategoricalParameterRange([]string{"aaa", "bbb"}) + if err != nil { + panic(err) + } + + conPR1, err := generateMockContinuousParameterRange(5.7, 2.0, sagemakerIdl.HyperparameterScalingType_LOGARITHMIC) + if err != nil { + panic(err) + } + + inputReader := &pluginIOMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return(storage.DataReference("/input/prefix")) + inputReader.OnGetInputPath().Return(storage.DataReference("/input")) + if taskType == trainingJobTaskType { + inputReader.OnGetMatch(mock.Anything).Return( + &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "train": generateMockBlobLiteral(trainBlobLoc), + "validation": generateMockBlobLiteral(validationBlobLoc), + "static_hyperparameters": utils.MakeGenericLiteral(shpStructObj), + "hyperparameter_tuning_job_config": utils.MakeBinaryLiteral(hpoJobConfigByteArray), + "a": utils.MakeGenericLiteral(intParamRange), + }, + }, nil) + + } else if taskType == customTrainingJobTaskType { + + inputReader.OnGetMatch(mock.Anything).Return( + &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "hyperparameter_tuning_job_config": utils.MakeBinaryLiteral(hpoJobConfigByteArray), + "cat_hp1": utils.MakeGenericLiteral(catPR1), + "val": generateMockBlobLiteral(validationBlobLoc), + "input_1": utils.MustMakeLiteral("123"), + "int_hp1": utils.MakeGenericLiteral(intParamRange), + "con_hp1": utils.MakeGenericLiteral(conPR1), + }, + }, nil) + } + taskCtx.OnInputReader().Return(inputReader) outputReader := &pluginIOMocks.OutputWriter{} @@ -363,6 +381,48 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T return taskCtx } +func generateMockIntegerParameterRange( + maxValue, minValue int64, scaleType sagemakerIdl.HyperparameterScalingType_Value) (*structpb.Struct, error) { + intParamRange := &structpb.Struct{} + err := utils.MarshalStruct(&sagemakerIdl.ParameterRangeOneOf{ + ParameterRangeType: &sagemakerIdl.ParameterRangeOneOf_IntegerParameterRange{ + IntegerParameterRange: &sagemakerIdl.IntegerParameterRange{ + MaxValue: maxValue, + MinValue: minValue, + ScalingType: scaleType, + }, + }, + }, intParamRange) + return intParamRange, err +} + +func generateMockCategoricalParameterRange(values []string) (*structpb.Struct, error) { + catParamRange := &structpb.Struct{} + err := utils.MarshalStruct(&sagemakerIdl.ParameterRangeOneOf{ + ParameterRangeType: &sagemakerIdl.ParameterRangeOneOf_CategoricalParameterRange{ + CategoricalParameterRange: &sagemakerIdl.CategoricalParameterRange{ 
+ Values: values, + }, + }, + }, catParamRange) + return catParamRange, err +} + +func generateMockContinuousParameterRange( + maxValue, minValue float64, scaleType sagemakerIdl.HyperparameterScalingType_Value) (*structpb.Struct, error) { + conParamRange := &structpb.Struct{} + err := utils.MarshalStruct(&sagemakerIdl.ParameterRangeOneOf{ + ParameterRangeType: &sagemakerIdl.ParameterRangeOneOf_ContinuousParameterRange{ + ContinuousParameterRange: &sagemakerIdl.ContinuousParameterRange{ + MaxValue: maxValue, + MinValue: minValue, + ScalingType: scaleType, + }, + }, + }, conParamRange) + return conParamRange, err +} + func genMockTaskExecutionMetadata() *mocks.TaskExecutionMetadata { tID := &mocks.TaskExecutionID{} tID.OnGetID().Return(flyteIdlCore.TaskExecutionIdentifier{ diff --git a/go/tasks/plugins/k8s/sagemaker/utils.go b/go/tasks/plugins/k8s/sagemaker/utils.go index 6c32cbeaf..1be8feaa6 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils.go +++ b/go/tasks/plugins/k8s/sagemaker/utils.go @@ -275,16 +275,33 @@ func ToFloat64Ptr(f float64) *float64 { func deleteConflictingStaticHyperparameters( ctx context.Context, staticHPs []*commonv1.KeyValuePair, - tunableHPMap map[string]*flyteSagemakerIdl.ParameterRangeOneOf) []*commonv1.KeyValuePair { + tunableHPs *commonv1.ParameterRanges) []*commonv1.KeyValuePair { + //tunableHPMap map[string]*flyteSagemakerIdl.ParameterRangeOneOf resolvedStaticHPs := make([]*commonv1.KeyValuePair, 0, len(staticHPs)) - for _, hp := range staticHPs { - if _, found := tunableHPMap[hp.Name]; !found { - resolvedStaticHPs = append(resolvedStaticHPs, hp) + for _, staticHP := range staticHPs { + conflict := false + for _, tunableHP := range tunableHPs.ContinuousParameterRanges { + if staticHP.Name == *tunableHP.Name { + conflict = true + } + } + for _, tunableHP := range tunableHPs.IntegerParameterRanges { + if staticHP.Name == *tunableHP.Name { + conflict = true + } + } + for _, tunableHP := range tunableHPs.CategoricalParameterRanges { + if staticHP.Name == *tunableHP.Name { + conflict = true + } + } + if !conflict { + resolvedStaticHPs = append(resolvedStaticHPs, staticHP) } else { logger.Infof(ctx, - "Static hyperparameter [%v] is removed because the same hyperparameter can be found in the map of tunable hyperparameters", hp.Name) + "Static hyperparameter [%v] is removed because the same hyperparameter can be found in the map of tunable hyperparameters", staticHP.Name) } } return resolvedStaticHPs diff --git a/go/tasks/plugins/k8s/sagemaker/utils_test.go b/go/tasks/plugins/k8s/sagemaker/utils_test.go index 43f660cfa..d7a7feea4 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils_test.go +++ b/go/tasks/plugins/k8s/sagemaker/utils_test.go @@ -52,6 +52,28 @@ func generateParameterRangeInputs() map[string]*core.Literal { return res } +func generateMockTunableHPs() *commonv1.ParameterRanges { + return &commonv1.ParameterRanges{ + IntegerParameterRanges: []commonv1.IntegerParameterRange{ + { + Name: ToStringPtr("hp1"), + MaxValue: ToStringPtr("10"), + MinValue: ToStringPtr("0"), + ScalingType: commonv1.HyperParameterScalingType(sagemakerSpec.HyperparameterScalingType_AUTO.String()), + }, + }, + ContinuousParameterRanges: []commonv1.ContinuousParameterRange{ + { + Name: ToStringPtr("hp2"), + MaxValue: ToStringPtr("5.0"), + MinValue: ToStringPtr("3.0"), + ScalingType: commonv1.HyperParameterScalingType(sagemakerSpec.HyperparameterScalingType_LINEAR.String()), + }, + }, + CategoricalParameterRanges: []commonv1.CategoricalParameterRange{{Name: ToStringPtr("hp3"), Values: 
[]string{"AAA", "BBB", "CCC"}}}, + } +} + func generateMockTunableHPMap() map[string]*sagemakerSpec.ParameterRangeOneOf { ret := map[string]*sagemakerSpec.ParameterRangeOneOf{ "hp1": {ParameterRangeType: &sagemakerSpec.ParameterRangeOneOf_IntegerParameterRange{ @@ -122,9 +144,9 @@ func generateMockSageMakerConfig() *sagemakerConfig.Config { func Test_deleteConflictingStaticHyperparameters(t *testing.T) { mockCtx := context.TODO() type args struct { - ctx context.Context - staticHPs []*commonv1.KeyValuePair - tunableHPMap map[string]*sagemakerSpec.ParameterRangeOneOf + ctx context.Context + staticHPs []*commonv1.KeyValuePair + tunableHPs *commonv1.ParameterRanges } tests := []struct { name string @@ -132,24 +154,24 @@ func Test_deleteConflictingStaticHyperparameters(t *testing.T) { want []*commonv1.KeyValuePair }{ {name: "Partially conflicting hyperparameter list", args: args{ - ctx: mockCtx, - staticHPs: generatePartiallyConflictingStaticHPs(), - tunableHPMap: generateMockTunableHPMap(), + ctx: mockCtx, + staticHPs: generatePartiallyConflictingStaticHPs(), + tunableHPs: generateMockTunableHPs(), }, want: []*commonv1.KeyValuePair{{Name: "hp4", Value: "0.5"}}}, {name: "Totally conflicting hyperparameter list", args: args{ - ctx: mockCtx, - staticHPs: generateTotallyConflictingStaticHPs(), - tunableHPMap: generateMockTunableHPMap(), + ctx: mockCtx, + staticHPs: generateTotallyConflictingStaticHPs(), + tunableHPs: generateMockTunableHPs(), }, want: []*commonv1.KeyValuePair{}}, {name: "Non-conflicting hyperparameter list", args: args{ - ctx: mockCtx, - staticHPs: generateNonConflictingStaticHPs(), - tunableHPMap: generateMockTunableHPMap(), + ctx: mockCtx, + staticHPs: generateNonConflictingStaticHPs(), + tunableHPs: generateMockTunableHPs(), }, want: []*commonv1.KeyValuePair{{Name: "hp5", Value: "100"}, {Name: "hp4", Value: "0.5"}, {Name: "hp7", Value: "ddd,eee"}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := deleteConflictingStaticHyperparameters(tt.args.ctx, tt.args.staticHPs, tt.args.tunableHPMap); !reflect.DeepEqual(got, tt.want) { + if got := deleteConflictingStaticHyperparameters(tt.args.ctx, tt.args.staticHPs, tt.args.tunableHPs); !reflect.DeepEqual(got, tt.want) { t.Errorf("deleteConflictingStaticHyperparameters() = %v, want %v", got, tt.want) } }) From b8dd1bc5914dac69559a62ef96fa47df551c3bba Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Thu, 29 Oct 2020 20:36:12 -0700 Subject: [PATCH 2/4] lint --- go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go index dce4e5352..b819d9dbe 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go @@ -188,10 +188,6 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( jobName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - if len(hpoJobParameterRanges.CategoricalParameterRanges) == 0 && len(hpoJobParameterRanges.ContinuousParameterRanges) == 0 { - - } - cfg := config.GetSagemakerConfig() var metricDefinitions []commonv1.MetricDefinition From 4fafedf200ae4340ea0ccf0bf25aae849956b601 Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Sun, 1 Nov 2020 21:41:07 -0800 Subject: [PATCH 3/4] propagate timeout to HPO Job k8s resource --- go.mod | 1 + go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go | 8 +++++++- 2 files changed, 
8 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 6aebceacb..6b3424b1b 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( replace ( github.com/GoogleCloudPlatform/spark-on-k8s-operator => github.com/lyft/spark-on-k8s-operator v0.1.3 github.com/googleapis/gnostic => github.com/googleapis/gnostic v0.3.1 + github.com/lyft/flyteidl => /Users/changhonghsu/src/go/src/github.com/lyft/flyteidl k8s.io/api => github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 k8s.io/apimachinery => github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f k8s.io/client-go => k8s.io/client-go v0.0.0-20191016111102-bec269661e48 diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go index b819d9dbe..6e41053e5 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go @@ -173,6 +173,10 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "There has to be at least one input for a hyperparameter tuning job wrapping around a custom-training job") } + if taskTemplate.GetContainer() == nil { + return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "The task template points to a nil container") + } + if taskTemplate.GetContainer().GetImage() == "" { return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Invalid image of the container") } @@ -206,6 +210,8 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( role = cfg.RoleArn } + maxTrainingJobRuntimeInSeconds := sagemakerHPOJob.GetTrainingJobMetadata().GetTimeout().GetSeconds() + hpoJob := &hpojobv1.HyperparameterTuningJob{ Spec: hpojobv1.HyperparameterTuningJobSpec{ HyperParameterTuningJobName: &jobName, @@ -243,7 +249,7 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( }, RoleArn: ToStringPtr(role), StoppingCondition: &commonv1.StoppingCondition{ - MaxRuntimeInSeconds: ToInt64Ptr(86400), + MaxRuntimeInSeconds: ToInt64Ptr(maxTrainingJobRuntimeInSeconds), MaxWaitTimeInSeconds: nil, // We currently don't have a conclusion how to set a value for this }, }, From 65af097040a1290c1bf8ed16aa0fb4a4ea32feb4 Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Fri, 18 Dec 2020 11:12:55 -0800 Subject: [PATCH 4/4] add output handling logic --- .../k8s/sagemaker/hyperparameter_tuning.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go index 6e41053e5..606157fa8 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go @@ -295,6 +295,22 @@ func (m awsSagemakerPlugin) getTaskPhaseForHyperparameterTuningJob( case sagemaker.HyperParameterTuningJobStatusCompleted: // Now that it is a success we will set the outputs as expected by the task + // 11/01/2020: how do I tell if it is a custom training job or not in this function? Do I need to know? 
+
+		logger.Infof(ctx, "Looking for the output.pb under %s", pluginContext.OutputWriter().GetOutputPrefixPath())
+		outputReader := ioutils.NewRemoteFileOutputReader(ctx, pluginContext.DataStore(), pluginContext.OutputWriter(), pluginContext.MaxDatasetSizeBytes())
+
+		// TODO: retrieve the best training job's output (e.g. via a retrieveBestTrainingJobOutput helper), using the
+		// model output path built by createModelOutputPath(hpoJob, pluginContext.OutputWriter().GetRawOutputPrefix().String(),
+		// 	*hpoJob.Status.BestTrainingJob.TrainingJobName)
+
+		// Instantiate an output reader and write the output to the remote location referred to by the OutputWriter
+		if err := pluginContext.OutputWriter().Put(ctx, outputReader); err != nil {
+			return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Failed to write output to the remote location")
+		}
+		logger.Debugf(ctx, "Successfully produced and returned outputs")
+		return pluginsCore.PhaseInfoSuccess(info), nil
+
 		// TODO:
 		// Check task template -> custom training job -> if custom: assume output.pb exist, and fail if it doesn't. If it exists, then
 		// -> if not custom: check model.tar.gz
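
Reviewer note: below is a minimal sketch, not part of this patch series, of how the Completed branch could handle the two cases described in the TODO above. It only reuses calls that already appear in these patches (ioutils.NewRemoteFileOutputReader, createModelOutputPath, the OutputWriter on the plugin context, pluginsCore phase helpers) and assumes the same imports and aliases as hyperparameter_tuning.go; the function name handleCompletedHPOJob, the isCustom flag, and the exact PluginContext signature are illustrative assumptions.

// Sketch only: how the Completed case could branch on whether the wrapped training job is custom.
// isCustom is assumed to be derived elsewhere, e.g. from the task template's algorithm name.
func handleCompletedHPOJob(ctx context.Context, pluginContext k8s.PluginContext,
	hpoJob *hpojobv1.HyperparameterTuningJob, isCustom bool, info *pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) {

	if isCustom {
		// Custom training job: the user container is expected to have written output.pb under the
		// output prefix, so read it back and hand it to the OutputWriter as-is.
		outputReader := ioutils.NewRemoteFileOutputReader(
			ctx, pluginContext.DataStore(), pluginContext.OutputWriter(), pluginContext.MaxDatasetSizeBytes())
		if err := pluginContext.OutputWriter().Put(ctx, outputReader); err != nil {
			return pluginsCore.PhaseInfoUndefined, err
		}
		return pluginsCore.PhaseInfoSuccess(info), nil
	}

	// Built-in algorithm: there is no output.pb; the artifact to surface is the best training job's
	// model.tar.gz under the raw output prefix.
	modelPath := createModelOutputPath(hpoJob, pluginContext.OutputWriter().GetRawOutputPrefix().String(),
		*hpoJob.Status.BestTrainingJob.TrainingJobName)
	logger.Infof(ctx, "Best training job model artifact expected at [%v]", modelPath)
	// TODO (assumption): wrap modelPath in a Blob literal and write it through the OutputWriter.
	return pluginsCore.PhaseInfoSuccess(info), nil
}

The split mirrors the TODO at the end of the patch: custom training jobs produce output.pb through flytekit, while built-in algorithms only leave a model.tar.gz from the best training job, so the two cases need different output handling.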