This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[WIP] Supporting hyperparameter tuning on custom training job task #137

Draft pull request: wants to merge 5 commits into base branch master.

1 change: 1 addition & 0 deletions go.mod
@@ -44,6 +44,7 @@ require (
replace (
github.com/GoogleCloudPlatform/spark-on-k8s-operator => github.com/lyft/spark-on-k8s-operator v0.1.4-0.20201027003055-c76b67e3b6d0
github.com/googleapis/gnostic => github.com/googleapis/gnostic v0.3.1
github.com/lyft/flyteidl => /Users/changhonghsu/src/go/src/github.com/lyft/flyteidl
k8s.io/api => github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0
k8s.io/apimachinery => github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f
k8s.io/client-go => k8s.io/client-go v0.0.0-20191016111102-bec269661e48
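
The flyteidl replace above points at a local checkout on the author's machine, so this module only builds locally; before merging it would presumably be swapped for a pinned upstream version. A hypothetical example of the final entry inside the replace block (the pseudo-version below is a placeholder, not a real release):

github.com/lyft/flyteidl => github.com/lyft/flyteidl v0.0.0-20201101000000-abcdefabcdef // placeholder pseudo-version, illustrative only
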
202 changes: 134 additions & 68 deletions go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go
@@ -57,64 +57,141 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob(
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TrainingJobResourceConfig] of the HyperparameterTuningJob's underlying TrainingJob does not exist")
}

trainingJobType := sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetAlgorithmName()

taskInput, err := taskCtx.InputReader().Get(ctx)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "unable to fetch task inputs")
}

// Get inputs from literals
inputLiterals := taskInput.GetLiterals()
err = checkIfRequiredInputLiteralsExist(inputLiterals,
[]string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, StaticHyperparametersPredefinedInputVariable})

hpoJobConfigLiteral := inputLiterals["hyperparameter_tuning_job_config"]
// hyperparameter_tuning_job_config is marshaled into a struct in flytekit, so will have to unmarshal it back
hpoJobConfig, err := convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Error occurred when checking if all the required inputs exist")
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to convert hyperparameter tuning job config literal to spec type")
}
logger.Infof(ctx, "hyperparameter tuning job config = [%v]", hpoJobConfig)

trainPathLiteral := inputLiterals[TrainPredefinedInputVariable]
validatePathLiteral := inputLiterals[ValidationPredefinedInputVariable]
staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable]
hpoJobConfigLiteral := inputLiterals["hyperparameter_tuning_job_config"]
if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", TrainPredefinedInputVariable)
// Extracting the tunable hyperparameters from the input literals
hpoJobParameterRanges := buildParameterRanges(ctx, inputLiterals)

for _, catpr := range hpoJobParameterRanges.CategoricalParameterRanges {
logger.Infof(ctx, "CategoricalParameterRange: [%v]: %v", *catpr.Name, catpr.Values)
}
if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", ValidationPredefinedInputVariable)
for _, intpr := range hpoJobParameterRanges.IntegerParameterRanges {
logger.Infof(ctx, "IntegerParameterRange: [%v]: (max:%v, min:%v, scaling:%v)", *intpr.Name, *intpr.MaxValue, *intpr.MinValue, intpr.ScalingType)
}
// Convert the hyperparameters to the spec value
staticHyperparams, err := convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "could not convert static hyperparameters to spec type")
for _, conpr := range hpoJobParameterRanges.ContinuousParameterRanges {
logger.Infof(ctx, "ContinuousParameterRange [%v]: (max:%v, min:%v, scaling:%v)", *conpr.Name, *conpr.MaxValue, *conpr.MinValue, conpr.ScalingType)
}

// hyperparameter_tuning_job_config is marshaled into a struct in flytekit, so will have to unmarshal it back
hpoJobConfig, err := convertHyperparameterTuningJobConfigToSpecType(hpoJobConfigLiteral)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to convert hyperparameter tuning job config literal to spec type")
inputModeString := strings.Title(strings.ToLower(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputMode().String()))

var staticHyperparams []*commonv1.KeyValuePair
var inputChannels []commonv1.Channel
var trainingImageStr string
if trainingJobType != flyteSageMakerIdl.AlgorithmName_CUSTOM {
logger.Infof(ctx, "The hyperparameter tuning job is wrapping around a built-in algorithm training job")
requiredInputs := []string{TrainPredefinedInputVariable, ValidationPredefinedInputVariable, StaticHyperparametersPredefinedInputVariable}
logger.Infof(ctx, "Checking if required inputs exist [%v]", requiredInputs)
// train, validation, and static_hyperparameters are the default required inputs for hpo job that wraps
// around a built-in algorithm training job
err = checkIfRequiredInputLiteralsExist(inputLiterals, requiredInputs)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Error occurred when checking if all the required inputs exist")
}

trainPathLiteral := inputLiterals[TrainPredefinedInputVariable]
validatePathLiteral := inputLiterals[ValidationPredefinedInputVariable]
staticHyperparamsLiteral := inputLiterals[StaticHyperparametersPredefinedInputVariable]

if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", TrainPredefinedInputVariable)
}
if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "[%v] Input is required and should be of Type [Scalar.Blob]", ValidationPredefinedInputVariable)
}
// Convert the hyperparameters to the spec value
staticHyperparams, err = convertStaticHyperparamsLiteralToSpecType(staticHyperparamsLiteral)
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "could not convert static hyperparameters to spec type")
}

// Deleting the conflicting static hyperparameters: if a hyperparameter exists in both the map of static hyperparameters
// and the map of tunable hyperparameters inside the Hyperparameter Tuning Job Config, we delete the entry
// in the static map and let the one in the map of tunable hyperparameters take precedence
staticHyperparams = deleteConflictingStaticHyperparameters(ctx, staticHyperparams, hpoJobParameterRanges)
logger.Infof(ctx, "Sagemaker HyperparameterTuningJob Task plugin will proceed with the following static hyperparameters:")
for _, shp := range staticHyperparams {
logger.Infof(ctx, "(%v, %v)", shp.Name, shp.Value)
}

apiContentType, err := getAPIContentType(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType())
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unsupported input file type [%v]",
sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType().String())
}

inputChannels = []commonv1.Channel{
{
ChannelName: ToStringPtr(TrainPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata
InputMode: inputModeString,
},
{
ChannelName: ToStringPtr(ValidationPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(validatePathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata
InputMode: inputModeString,
},
}

trainingImageStr, err = getTrainingJobImage(ctx, taskCtx, sagemakerHPOJob.GetTrainingJob())
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to find the training image")
}
} else {
// For an hpo job that wraps around a custom training job, there has to be at least one tunable hyperparameter in
// the input list
if len(inputLiterals) < 1 ||
(len(hpoJobParameterRanges.ContinuousParameterRanges) < 1 && len(hpoJobParameterRanges.IntegerParameterRanges) < 1 && len(hpoJobParameterRanges.CategoricalParameterRanges) < 1) {

return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "There has to be at least one input for a hyperparameter tuning job wrapping around a custom-training job")
}

if taskTemplate.GetContainer() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "The task template points to a nil container")
}

if taskTemplate.GetContainer().GetImage() == "" {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Invalid image of the container")
}
inputChannels = nil
trainingImageStr = taskTemplate.GetContainer().GetImage()
}

outputPath := createOutputPath(taskCtx.OutputWriter().GetRawOutputPrefix().String(), HyperparameterOutputPathSubDir)

if hpoJobConfig.GetTuningObjective() == nil {
return nil, pluginErrors.Errorf(pluginErrors.BadTaskSpecification, "Required field [TuningObjective] does not exist")
}

// Deleting the conflicting static hyperparameters: if a hyperparameter exists in both the map of static hyperparameters
// and the map of tunable hyperparameters inside the Hyperparameter Tuning Job Config, we delete the entry
// in the static map and let the one in the map of tunable hyperparameters take precedence
staticHyperparams = deleteConflictingStaticHyperparameters(ctx, staticHyperparams, hpoJobConfig.GetHyperparameterRanges().GetParameterRangeMap())

jobName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

trainingImageStr, err := getTrainingJobImage(ctx, taskCtx, sagemakerHPOJob.GetTrainingJob())
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to find the training image")
}

hpoJobParameterRanges := buildParameterRanges(ctx, inputLiterals)
logger.Infof(ctx, "The Sagemaker HyperparameterTuningJob Task plugin received the following inputs: \n"+
"static hyperparameters: [%v]\n"+
"hyperparameter tuning job config: [%v]\n"+
"parameter ranges: [%v]", staticHyperparams, hpoJobConfig, hpoJobParameterRanges)

cfg := config.GetSagemakerConfig()

var metricDefinitions []commonv1.MetricDefinition
@@ -124,13 +201,6 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob(
commonv1.MetricDefinition{Name: ToStringPtr(md.Name), Regex: ToStringPtr(md.Regex)})
}

apiContentType, err := getAPIContentType(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType())
if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unsupported input file type [%v]",
sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputContentType().String())
}

inputModeString := strings.Title(strings.ToLower(sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification().GetInputMode().String()))
tuningStrategyString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningStrategy().String()))
tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String()))
trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String()))
@@ -140,6 +210,8 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob(
role = cfg.RoleArn
}

maxTrainingJobRuntimeInSeconds := sagemakerHPOJob.GetTrainingJobMetadata().GetTimeout().GetSeconds()

hpoJob := &hpojobv1.HyperparameterTuningJob{
Spec: hpojobv1.HyperparameterTuningJobSpec{
HyperParameterTuningJobName: &jobName,
@@ -157,37 +229,15 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob(
TrainingJobEarlyStoppingType: commonv1.TrainingJobEarlyStoppingType(trainingJobEarlyStoppingTypeString),
},
TrainingJobDefinition: &commonv1.HyperParameterTrainingJobDefinition{
// If the underlying training job is a custom training job, this will be nil
StaticHyperParameters: staticHyperparams,
AlgorithmSpecification: &commonv1.HyperParameterAlgorithmSpecification{
TrainingImage: ToStringPtr(trainingImageStr),
TrainingInputMode: commonv1.TrainingInputMode(inputModeString),
MetricDefinitions: metricDefinitions,
AlgorithmName: nil,
},
InputDataConfig: []commonv1.Channel{
{
ChannelName: ToStringPtr(TrainPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(trainPathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata
InputMode: inputModeString,
},
{
ChannelName: ToStringPtr(ValidationPredefinedInputVariable),
DataSource: &commonv1.DataSource{
S3DataSource: &commonv1.S3DataSource{
S3DataType: "S3Prefix",
S3Uri: ToStringPtr(validatePathLiteral.GetScalar().GetBlob().GetUri()),
},
},
ContentType: ToStringPtr(apiContentType), // TODO: can this be derived from the BlobMetadata
InputMode: inputModeString,
},
},
InputDataConfig: inputChannels,
OutputDataConfig: &commonv1.OutputDataConfig{
S3OutputPath: ToStringPtr(outputPath),
},
@@ -199,8 +249,8 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob(
},
RoleArn: ToStringPtr(role),
StoppingCondition: &commonv1.StoppingCondition{
MaxRuntimeInSeconds: ToInt64Ptr(86400),
MaxWaitTimeInSeconds: nil,
MaxRuntimeInSeconds: ToInt64Ptr(maxTrainingJobRuntimeInSeconds),
MaxWaitTimeInSeconds: nil, // We haven't yet decided how to set this value
},
},
Region: ToStringPtr(cfg.Region),
@@ -245,6 +295,22 @@ func (m awsSagemakerPlugin) getTaskPhaseForHyperparameterTuningJob(
case sagemaker.HyperParameterTuningJobStatusCompleted:
// Now that it is a success we will set the outputs as expected by the task

// 11/01/2020: how do I tell if it is a custom training job or not in this function? Do I need to know?

logger.Infof(ctx, "Looking for the output.pb under %s", pluginContext.OutputWriter().GetOutputPrefixPath())
outputReader := ioutils.NewRemoteFileOutputReader(ctx, pluginContext.DataStore(), pluginContext.OutputWriter(), pluginContext.MaxDatasetSizeBytes())

// WIP: retrieveBestTrainingJobOutput is planned here to read the best training job's output from the
// model output path below (the helper is not yet implemented in this diff).
createModelOutputPath(hpoJob, pluginContext.OutputWriter().GetRawOutputPrefix().String(),
*hpoJob.Status.BestTrainingJob.TrainingJobName)

// Instantiate an output reader with the literal map, and write the output to the remote location referred to by the OutputWriter
if err := pluginContext.OutputWriter().Put(ctx, outputReader); err != nil {
return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Failed to write output to the remote location")
}
logger.Debugf(ctx, "Successfully produced and returned outputs")
return pluginsCore.PhaseInfoSuccess(info), nil

// TODO:
// Check the task template -> custom training job -> if custom: assume output.pb exists, and fail if it doesn't; if it exists, proceed as a success
// -> if not custom: check model.tar.gz
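
A minimal sketch of what the TODO above could look like, assuming the caller already knows whether the wrapped training job is a custom one (which is the open question in the 11/01/2020 note). The helper name, its parameters, and the artifact handling below are assumptions for illustration only; none of them exist in this diff.

package sagemaker

import "path"

// expectedOutputArtifact returns the artifact a completed hyperparameter tuning job is expected to
// have produced: output.pb written by the user's container for a custom training job, or model.tar.gz
// written by SageMaker under the best training job's output path for a built-in algorithm.
// Illustrative sketch only, not part of the plugin.
func expectedOutputArtifact(isCustomTrainingJob bool, outputPrefix string) string {
	if isCustomTrainingJob {
		// Custom training job: fail the task if this file is missing.
		return path.Join(outputPrefix, "output.pb")
	}
	// Built-in algorithm: the trained model archive is expected here.
	return path.Join(outputPrefix, "model.tar.gz")
}

The getTaskPhaseForHyperparameterTuningJob handler could then fail the task when the expected artifact is absent, instead of unconditionally writing outputs as the current WIP code does.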