diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc index 303cb8003e..187cc04817 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc @@ -83,6 +83,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fkubeflow_2ften PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, ps_replicas_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, chief_replicas_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, run_policy_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, evaluator_replicas_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec, _internal_metadata_), ~0u, // no _extensions_ @@ -95,7 +96,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fkubeflow_2ften }; static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { { 0, -1, sizeof(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask)}, - { 9, -1, sizeof(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec)}, + { 10, -1, sizeof(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec)}, }; static ::google::protobuf::Message const * const file_default_instances[] = { @@ -113,7 +114,7 @@ const char descriptor_table_protodef_flyteidl_2fplugins_2fkubeflow_2ftensorflow_ "\n*flyteidl/plugins/kubeflow/tensorflow.p" "roto\022\031flyteidl.plugins.kubeflow\032\031flyteid" "l/core/tasks.proto\032&flyteidl/plugins/kub" - "eflow/common.proto\"\362\002\n!DistributedTensor" + "eflow/common.proto\"\323\003\n!DistributedTensor" "flowTrainingTask\022\\\n\017worker_replicas\030\001 \001(" "\0132C.flyteidl.plugins.kubeflow.Distribute" "dTensorflowTrainingReplicaSpec\022X\n\013ps_rep" @@ -122,19 +123,21 @@ const char descriptor_table_protodef_flyteidl_2fplugins_2fkubeflow_2ftensorflow_ "\022[\n\016chief_replicas\030\003 \001(\0132C.flyteidl.plug" "ins.kubeflow.DistributedTensorflowTraini" "ngReplicaSpec\0228\n\nrun_policy\030\004 \001(\0132$.flyt" - "eidl.plugins.kubeflow.RunPolicy\"\272\001\n(Dist" - "ributedTensorflowTrainingReplicaSpec\022\020\n\010" - "replicas\030\001 \001(\005\022\r\n\005image\030\002 \001(\t\022+\n\tresourc" - "es\030\003 \001(\0132\030.flyteidl.core.Resources\022@\n\016re" - "start_policy\030\004 \001(\0162(.flyteidl.plugins.ku" - "beflow.RestartPolicyB\?Z=github.com/flyte" - "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" - "uginsb\006proto3" + "eidl.plugins.kubeflow.RunPolicy\022_\n\022evalu" + "ator_replicas\030\005 \001(\0132C.flyteidl.plugins.k" + "ubeflow.DistributedTensorflowTrainingRep" + "licaSpec\"\272\001\n(DistributedTensorflowTraini" + "ngReplicaSpec\022\020\n\010replicas\030\001 \001(\005\022\r\n\005image" + "\030\002 \001(\t\022+\n\tresources\030\003 \001(\0132\030.flyteidl.cor" + "e.Resources\022@\n\016restart_policy\030\004 \001(\0162(.fl" + "yteidl.plugins.kubeflow.RestartPolicyB\?Z" + "=github.com/flyteorg/flyte/flyteidl/gen/" + "pb-go/flyteidl/pluginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto = { false, InitDefaults_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, - "flyteidl/plugins/kubeflow/tensorflow.proto", &assign_descriptors_table_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, 773, + "flyteidl/plugins/kubeflow/tensorflow.proto", &assign_descriptors_table_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, 870, }; void AddDescriptors_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto() { @@ -163,6 +166,8 @@ void DistributedTensorflowTrainingTask::InitAsDefaultInstance() { ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::internal_default_instance()); ::flyteidl::plugins::kubeflow::_DistributedTensorflowTrainingTask_default_instance_._instance.get_mutable()->run_policy_ = const_cast< ::flyteidl::plugins::kubeflow::RunPolicy*>( ::flyteidl::plugins::kubeflow::RunPolicy::internal_default_instance()); + ::flyteidl::plugins::kubeflow::_DistributedTensorflowTrainingTask_default_instance_._instance.get_mutable()->evaluator_replicas_ = const_cast< ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec*>( + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::internal_default_instance()); } class DistributedTensorflowTrainingTask::HasBitSetters { public: @@ -170,6 +175,7 @@ class DistributedTensorflowTrainingTask::HasBitSetters { static const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& ps_replicas(const DistributedTensorflowTrainingTask* msg); static const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& chief_replicas(const DistributedTensorflowTrainingTask* msg); static const ::flyteidl::plugins::kubeflow::RunPolicy& run_policy(const DistributedTensorflowTrainingTask* msg); + static const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& evaluator_replicas(const DistributedTensorflowTrainingTask* msg); }; const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& @@ -188,6 +194,10 @@ const ::flyteidl::plugins::kubeflow::RunPolicy& DistributedTensorflowTrainingTask::HasBitSetters::run_policy(const DistributedTensorflowTrainingTask* msg) { return *msg->run_policy_; } +const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& +DistributedTensorflowTrainingTask::HasBitSetters::evaluator_replicas(const DistributedTensorflowTrainingTask* msg) { + return *msg->evaluator_replicas_; +} void DistributedTensorflowTrainingTask::clear_run_policy() { if (GetArenaNoVirtual() == nullptr && run_policy_ != nullptr) { delete run_policy_; @@ -199,6 +209,7 @@ const int DistributedTensorflowTrainingTask::kWorkerReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kPsReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kChiefReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kRunPolicyFieldNumber; +const int DistributedTensorflowTrainingTask::kEvaluatorReplicasFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask() @@ -230,6 +241,11 @@ DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask(const Distr } else { run_policy_ = nullptr; } + if (from.has_evaluator_replicas()) { + evaluator_replicas_ = new ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec(*from.evaluator_replicas_); + } else { + evaluator_replicas_ = nullptr; + } // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask) } @@ -237,8 +253,8 @@ void DistributedTensorflowTrainingTask::SharedCtor() { ::google::protobuf::internal::InitSCC( &scc_info_DistributedTensorflowTrainingTask_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto.base); ::memset(&worker_replicas_, 0, static_cast( - reinterpret_cast(&run_policy_) - - reinterpret_cast(&worker_replicas_)) + sizeof(run_policy_)); + reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&worker_replicas_)) + sizeof(evaluator_replicas_)); } DistributedTensorflowTrainingTask::~DistributedTensorflowTrainingTask() { @@ -251,6 +267,7 @@ void DistributedTensorflowTrainingTask::SharedDtor() { if (this != internal_default_instance()) delete ps_replicas_; if (this != internal_default_instance()) delete chief_replicas_; if (this != internal_default_instance()) delete run_policy_; + if (this != internal_default_instance()) delete evaluator_replicas_; } void DistributedTensorflowTrainingTask::SetCachedSize(int size) const { @@ -284,6 +301,10 @@ void DistributedTensorflowTrainingTask::Clear() { delete run_policy_; } run_policy_ = nullptr; + if (GetArenaNoVirtual() == nullptr && evaluator_replicas_ != nullptr) { + delete evaluator_replicas_; + } + evaluator_replicas_ = nullptr; _internal_metadata_.Clear(); } @@ -352,6 +373,19 @@ const char* DistributedTensorflowTrainingTask::_InternalParse(const char* begin, {parser_till_end, object}, ptr - size, ptr)); break; } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + case 5: { + if (static_cast<::google::protobuf::uint8>(tag) != 42) goto handle_unusual; + ptr = ::google::protobuf::io::ReadSize(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + parser_till_end = ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::_InternalParse; + object = msg->mutable_evaluator_replicas(); + if (size > end - ptr) goto len_delim_till_end; + ptr += size; + GOOGLE_PROTOBUF_PARSER_ASSERT(ctx->ParseExactRange( + {parser_till_end, object}, ptr - size, ptr)); + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -426,6 +460,17 @@ bool DistributedTensorflowTrainingTask::MergePartialFromCodedStream( break; } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + case 5: { + if (static_cast< ::google::protobuf::uint8>(tag) == (42 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessage( + input, mutable_evaluator_replicas())); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -477,6 +522,12 @@ void DistributedTensorflowTrainingTask::SerializeWithCachedSizes( 4, HasBitSetters::run_policy(this), output); } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + if (this->has_evaluator_replicas()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 5, HasBitSetters::evaluator_replicas(this), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -518,6 +569,13 @@ ::google::protobuf::uint8* DistributedTensorflowTrainingTask::InternalSerializeW 4, HasBitSetters::run_policy(this), target); } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + if (this->has_evaluator_replicas()) { + target = ::google::protobuf::internal::WireFormatLite:: + InternalWriteMessageToArray( + 5, HasBitSetters::evaluator_replicas(this), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -567,6 +625,13 @@ size_t DistributedTensorflowTrainingTask::ByteSizeLong() const { *run_policy_); } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + if (this->has_evaluator_replicas()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSize( + *evaluator_replicas_); + } + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); SetCachedSize(cached_size); return total_size; @@ -606,6 +671,9 @@ void DistributedTensorflowTrainingTask::MergeFrom(const DistributedTensorflowTra if (from.has_run_policy()) { mutable_run_policy()->::flyteidl::plugins::kubeflow::RunPolicy::MergeFrom(from.run_policy()); } + if (from.has_evaluator_replicas()) { + mutable_evaluator_replicas()->::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::MergeFrom(from.evaluator_replicas()); + } } void DistributedTensorflowTrainingTask::CopyFrom(const ::google::protobuf::Message& from) { @@ -637,6 +705,7 @@ void DistributedTensorflowTrainingTask::InternalSwap(DistributedTensorflowTraini swap(ps_replicas_, other->ps_replicas_); swap(chief_replicas_, other->chief_replicas_); swap(run_policy_, other->run_policy_); + swap(evaluator_replicas_, other->evaluator_replicas_); } ::google::protobuf::Metadata DistributedTensorflowTrainingTask::GetMetadata() const { diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h index 9839ca4817..4a100b0233 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h @@ -205,6 +205,15 @@ class DistributedTensorflowTrainingTask final : ::flyteidl::plugins::kubeflow::RunPolicy* mutable_run_policy(); void set_allocated_run_policy(::flyteidl::plugins::kubeflow::RunPolicy* run_policy); + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + bool has_evaluator_replicas() const; + void clear_evaluator_replicas(); + static const int kEvaluatorReplicasFieldNumber = 5; + const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& evaluator_replicas() const; + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* release_evaluator_replicas(); + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* mutable_evaluator_replicas(); + void set_allocated_evaluator_replicas(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* evaluator_replicas); + // @@protoc_insertion_point(class_scope:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask) private: class HasBitSetters; @@ -214,6 +223,7 @@ class DistributedTensorflowTrainingTask final : ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* ps_replicas_; ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* chief_replicas_; ::flyteidl::plugins::kubeflow::RunPolicy* run_policy_; + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* evaluator_replicas_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto; }; @@ -570,6 +580,57 @@ inline void DistributedTensorflowTrainingTask::set_allocated_run_policy(::flytei // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.run_policy) } +// .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; +inline bool DistributedTensorflowTrainingTask::has_evaluator_replicas() const { + return this != internal_default_instance() && evaluator_replicas_ != nullptr; +} +inline void DistributedTensorflowTrainingTask::clear_evaluator_replicas() { + if (GetArenaNoVirtual() == nullptr && evaluator_replicas_ != nullptr) { + delete evaluator_replicas_; + } + evaluator_replicas_ = nullptr; +} +inline const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& DistributedTensorflowTrainingTask::evaluator_replicas() const { + const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* p = evaluator_replicas_; + // @@protoc_insertion_point(field_get:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) + return p != nullptr ? *p : *reinterpret_cast( + &::flyteidl::plugins::kubeflow::_DistributedTensorflowTrainingReplicaSpec_default_instance_); +} +inline ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* DistributedTensorflowTrainingTask::release_evaluator_replicas() { + // @@protoc_insertion_point(field_release:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) + + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* temp = evaluator_replicas_; + evaluator_replicas_ = nullptr; + return temp; +} +inline ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* DistributedTensorflowTrainingTask::mutable_evaluator_replicas() { + + if (evaluator_replicas_ == nullptr) { + auto* p = CreateMaybeMessage<::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec>(GetArenaNoVirtual()); + evaluator_replicas_ = p; + } + // @@protoc_insertion_point(field_mutable:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) + return evaluator_replicas_; +} +inline void DistributedTensorflowTrainingTask::set_allocated_evaluator_replicas(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* evaluator_replicas) { + ::google::protobuf::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete evaluator_replicas_; + } + if (evaluator_replicas) { + ::google::protobuf::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + evaluator_replicas = ::google::protobuf::internal::GetOwnedMessage( + message_arena, evaluator_replicas, submessage_arena); + } + + } else { + + } + evaluator_replicas_ = evaluator_replicas; + // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) +} + // ------------------------------------------------------------------- // DistributedTensorflowTrainingReplicaSpec diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc index 7d01fad6e2..23d343be1a 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc @@ -55,6 +55,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2ftensorflow_2ep PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, workers_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, ps_replicas_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, chief_replicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, evaluator_replicas_), }; static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { { 0, -1, sizeof(::flyteidl::plugins::DistributedTensorflowTrainingTask)}, @@ -72,16 +73,17 @@ ::google::protobuf::internal::AssignDescriptorsTable assign_descriptors_table_fl const char descriptor_table_protodef_flyteidl_2fplugins_2ftensorflow_2eproto[] = "\n!flyteidl/plugins/tensorflow.proto\022\020fly" - "teidl.plugins\"a\n!DistributedTensorflowTr" + "teidl.plugins\"}\n!DistributedTensorflowTr" "ainingTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013ps_replic" - "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005B\?Z=gith" - "ub.com/flyteorg/flyte/flyteidl/gen/pb-go" - "/flyteidl/pluginsb\006proto3" + "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005\022\032\n\022eval" + "uator_replicas\030\004 \001(\005B\?Z=github.com/flyte" + "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" + "uginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2ftensorflow_2eproto = { false, InitDefaults_flyteidl_2fplugins_2ftensorflow_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2ftensorflow_2eproto, - "flyteidl/plugins/tensorflow.proto", &assign_descriptors_table_flyteidl_2fplugins_2ftensorflow_2eproto, 225, + "flyteidl/plugins/tensorflow.proto", &assign_descriptors_table_flyteidl_2fplugins_2ftensorflow_2eproto, 253, }; void AddDescriptors_flyteidl_2fplugins_2ftensorflow_2eproto() { @@ -108,6 +110,7 @@ class DistributedTensorflowTrainingTask::HasBitSetters { const int DistributedTensorflowTrainingTask::kWorkersFieldNumber; const int DistributedTensorflowTrainingTask::kPsReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kChiefReplicasFieldNumber; +const int DistributedTensorflowTrainingTask::kEvaluatorReplicasFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask() @@ -120,15 +123,15 @@ DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask(const Distr _internal_metadata_(nullptr) { _internal_metadata_.MergeFrom(from._internal_metadata_); ::memcpy(&workers_, &from.workers_, - static_cast(reinterpret_cast(&chief_replicas_) - - reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + static_cast(reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&workers_)) + sizeof(evaluator_replicas_)); // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.DistributedTensorflowTrainingTask) } void DistributedTensorflowTrainingTask::SharedCtor() { ::memset(&workers_, 0, static_cast( - reinterpret_cast(&chief_replicas_) - - reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&workers_)) + sizeof(evaluator_replicas_)); } DistributedTensorflowTrainingTask::~DistributedTensorflowTrainingTask() { @@ -155,8 +158,8 @@ void DistributedTensorflowTrainingTask::Clear() { (void) cached_has_bits; ::memset(&workers_, 0, static_cast( - reinterpret_cast(&chief_replicas_) - - reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&workers_)) + sizeof(evaluator_replicas_)); _internal_metadata_.Clear(); } @@ -194,6 +197,13 @@ const char* DistributedTensorflowTrainingTask::_InternalParse(const char* begin, GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } + // int32 evaluator_replicas = 4; + case 4: { + if (static_cast<::google::protobuf::uint8>(tag) != 32) goto handle_unusual; + msg->set_evaluator_replicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -260,6 +270,19 @@ bool DistributedTensorflowTrainingTask::MergePartialFromCodedStream( break; } + // int32 evaluator_replicas = 4; + case 4: { + if (static_cast< ::google::protobuf::uint8>(tag) == (32 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &evaluator_replicas_))); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -302,6 +325,11 @@ void DistributedTensorflowTrainingTask::SerializeWithCachedSizes( ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->chief_replicas(), output); } + // int32 evaluator_replicas = 4; + if (this->evaluator_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->evaluator_replicas(), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -330,6 +358,11 @@ ::google::protobuf::uint8* DistributedTensorflowTrainingTask::InternalSerializeW target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->chief_replicas(), target); } + // int32 evaluator_replicas = 4; + if (this->evaluator_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->evaluator_replicas(), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -372,6 +405,13 @@ size_t DistributedTensorflowTrainingTask::ByteSizeLong() const { this->chief_replicas()); } + // int32 evaluator_replicas = 4; + if (this->evaluator_replicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->evaluator_replicas()); + } + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); SetCachedSize(cached_size); return total_size; @@ -408,6 +448,9 @@ void DistributedTensorflowTrainingTask::MergeFrom(const DistributedTensorflowTra if (from.chief_replicas() != 0) { set_chief_replicas(from.chief_replicas()); } + if (from.evaluator_replicas() != 0) { + set_evaluator_replicas(from.evaluator_replicas()); + } } void DistributedTensorflowTrainingTask::CopyFrom(const ::google::protobuf::Message& from) { @@ -438,6 +481,7 @@ void DistributedTensorflowTrainingTask::InternalSwap(DistributedTensorflowTraini swap(workers_, other->workers_); swap(ps_replicas_, other->ps_replicas_); swap(chief_replicas_, other->chief_replicas_); + swap(evaluator_replicas_, other->evaluator_replicas_); } ::google::protobuf::Metadata DistributedTensorflowTrainingTask::GetMetadata() const { diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h index 613ed31d80..4150592a60 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h @@ -178,6 +178,12 @@ class DistributedTensorflowTrainingTask final : ::google::protobuf::int32 chief_replicas() const; void set_chief_replicas(::google::protobuf::int32 value); + // int32 evaluator_replicas = 4; + void clear_evaluator_replicas(); + static const int kEvaluatorReplicasFieldNumber = 4; + ::google::protobuf::int32 evaluator_replicas() const; + void set_evaluator_replicas(::google::protobuf::int32 value); + // @@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedTensorflowTrainingTask) private: class HasBitSetters; @@ -186,6 +192,7 @@ class DistributedTensorflowTrainingTask final : ::google::protobuf::int32 workers_; ::google::protobuf::int32 ps_replicas_; ::google::protobuf::int32 chief_replicas_; + ::google::protobuf::int32 evaluator_replicas_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2ftensorflow_2eproto; }; @@ -242,6 +249,20 @@ inline void DistributedTensorflowTrainingTask::set_chief_replicas(::google::prot // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedTensorflowTrainingTask.chief_replicas) } +// int32 evaluator_replicas = 4; +inline void DistributedTensorflowTrainingTask::clear_evaluator_replicas() { + evaluator_replicas_ = 0; +} +inline ::google::protobuf::int32 DistributedTensorflowTrainingTask::evaluator_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedTensorflowTrainingTask.evaluator_replicas) + return evaluator_replicas_; +} +inline void DistributedTensorflowTrainingTask::set_evaluator_replicas(::google::protobuf::int32 value) { + + evaluator_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedTensorflowTrainingTask.evaluator_replicas) +} + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif // __GNUC__ diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go index ccb4eff2ac..04243dec6e 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go @@ -32,10 +32,12 @@ type DistributedTensorflowTrainingTask struct { // RunPolicy encapsulates various runtime policies of the distributed training // job, for example how to clean up resources and how long the job can stay // active. - RunPolicy *RunPolicy `protobuf:"bytes,4,opt,name=run_policy,json=runPolicy,proto3" json:"run_policy,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + RunPolicy *RunPolicy `protobuf:"bytes,4,opt,name=run_policy,json=runPolicy,proto3" json:"run_policy,omitempty"` + // Evaluator replicas spec + EvaluatorReplicas *DistributedTensorflowTrainingReplicaSpec `protobuf:"bytes,5,opt,name=evaluator_replicas,json=evaluatorReplicas,proto3" json:"evaluator_replicas,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *DistributedTensorflowTrainingTask) Reset() { *m = DistributedTensorflowTrainingTask{} } @@ -91,6 +93,13 @@ func (m *DistributedTensorflowTrainingTask) GetRunPolicy() *RunPolicy { return nil } +func (m *DistributedTensorflowTrainingTask) GetEvaluatorReplicas() *DistributedTensorflowTrainingReplicaSpec { + if m != nil { + return m.EvaluatorReplicas + } + return nil +} + type DistributedTensorflowTrainingReplicaSpec struct { // Number of replicas Replicas int32 `protobuf:"varint,1,opt,name=replicas,proto3" json:"replicas,omitempty"` @@ -170,28 +179,29 @@ func init() { } var fileDescriptor_93de2bd764ddf01a = []byte{ - // 358 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x92, 0x41, 0x4b, 0xc3, 0x30, - 0x14, 0xc7, 0x99, 0x73, 0xe2, 0x32, 0x56, 0xa1, 0x78, 0x98, 0x3b, 0xe9, 0x10, 0x19, 0x82, 0x0d, - 0x4c, 0xf0, 0x26, 0x82, 0xf3, 0xae, 0xc4, 0x9d, 0xbc, 0x8c, 0x36, 0x7b, 0xeb, 0x62, 0xdb, 0x24, - 0xbc, 0xa4, 0x8c, 0x7d, 0x23, 0xbf, 0x98, 0xdf, 0x43, 0xd6, 0xac, 0xed, 0x14, 0x36, 0x3c, 0xec, - 0xf6, 0x5e, 0xf3, 0xcf, 0xff, 0xf7, 0x5e, 0xfa, 0x27, 0xb7, 0xf3, 0x74, 0x65, 0x41, 0xcc, 0x52, - 0xaa, 0xd3, 0x3c, 0x16, 0xd2, 0xd0, 0x24, 0x8f, 0x60, 0x9e, 0xaa, 0x25, 0xb5, 0x20, 0x8d, 0xc2, - 0x75, 0x19, 0x68, 0x54, 0x56, 0xf9, 0x17, 0xa5, 0x36, 0xd8, 0x68, 0x83, 0x52, 0xdb, 0xaf, 0x8e, - 0x28, 0x57, 0x08, 0xd4, 0x86, 0x26, 0x31, 0xee, 0x56, 0xff, 0x66, 0x37, 0x81, 0xab, 0x2c, 0x53, - 0xd2, 0xe9, 0x06, 0x5f, 0x4d, 0x72, 0xf5, 0x22, 0x8c, 0x45, 0x11, 0xe5, 0x16, 0x66, 0x93, 0x8a, - 0x3e, 0xc1, 0x50, 0x48, 0x21, 0xe3, 0x49, 0x68, 0x12, 0x3f, 0x25, 0x67, 0x4b, 0x85, 0x09, 0xe0, - 0x14, 0x41, 0xa7, 0x82, 0x87, 0xa6, 0xd7, 0xb8, 0x6c, 0x0c, 0x3b, 0xa3, 0x71, 0xb0, 0x73, 0xba, - 0x60, 0xaf, 0x2d, 0x73, 0x3e, 0xef, 0x1a, 0x38, 0xf3, 0x9c, 0xf7, 0xe6, 0x93, 0xf1, 0x67, 0xa4, - 0xa3, 0x4d, 0x4d, 0x3a, 0x3a, 0x1c, 0x89, 0x68, 0x53, 0x51, 0x3e, 0x89, 0xc7, 0x17, 0x02, 0xe6, - 0x35, 0xa8, 0x79, 0x38, 0x50, 0xb7, 0xb0, 0xae, 0x58, 0x63, 0x42, 0x30, 0x97, 0x53, 0xad, 0x52, - 0xc1, 0x57, 0xbd, 0xe3, 0x82, 0x73, 0xbd, 0x87, 0xc3, 0x72, 0xf9, 0x56, 0x68, 0x59, 0x1b, 0xcb, - 0x72, 0xf0, 0xdd, 0x20, 0xc3, 0xff, 0x0e, 0xe0, 0xf7, 0xc9, 0xe9, 0xaf, 0x5f, 0xd5, 0x62, 0x55, - 0xef, 0x9f, 0x93, 0x96, 0xc8, 0xc2, 0x18, 0x8a, 0x97, 0x6d, 0x33, 0xd7, 0xf8, 0x0f, 0xa4, 0x8d, - 0x60, 0x54, 0x8e, 0x1c, 0xca, 0xa7, 0xe8, 0xd5, 0x23, 0xae, 0x03, 0x16, 0xb0, 0xf2, 0x9c, 0xd5, - 0x52, 0xff, 0x95, 0x78, 0x08, 0xc6, 0x86, 0x68, 0xb7, 0xf7, 0xf3, 0x46, 0xc3, 0x7d, 0xfb, 0xb9, - 0x0b, 0x9b, 0x1d, 0xbb, 0xb8, 0xdd, 0x3e, 0x3f, 0x7d, 0x3c, 0xc6, 0xc2, 0x2e, 0xf2, 0x28, 0xe0, - 0x2a, 0xa3, 0x85, 0x89, 0xc2, 0xd8, 0x15, 0xb4, 0x8a, 0x75, 0x0c, 0x92, 0xea, 0xe8, 0x2e, 0x56, - 0xf4, 0x6f, 0xd2, 0xa3, 0x93, 0x22, 0xda, 0xf7, 0x3f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x15, 0x1b, - 0x83, 0x24, 0x66, 0x03, 0x00, 0x00, + // 382 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x93, 0xc1, 0x6a, 0xe3, 0x30, + 0x10, 0x86, 0xc9, 0x6e, 0xb2, 0x6c, 0x14, 0xe2, 0x65, 0xc5, 0x1e, 0xb2, 0x39, 0xed, 0x86, 0x65, + 0x09, 0x85, 0x5a, 0x90, 0x42, 0x6f, 0xa5, 0xd0, 0xf4, 0xde, 0xa2, 0xe6, 0xd4, 0x4b, 0x90, 0x15, + 0xc5, 0x51, 0x2d, 0x4b, 0x62, 0x24, 0x35, 0xe4, 0x5d, 0xfa, 0x7a, 0x7d, 0x8f, 0x12, 0x3b, 0xb6, + 0xd3, 0x42, 0x42, 0x0f, 0xb9, 0xcd, 0x58, 0xff, 0xfc, 0xdf, 0x78, 0xa4, 0x41, 0x67, 0x4b, 0xb5, + 0xf1, 0x42, 0x2e, 0x14, 0xb1, 0x2a, 0xa4, 0x52, 0x3b, 0x92, 0x85, 0x44, 0x2c, 0x95, 0x59, 0x13, + 0x2f, 0xb4, 0x33, 0xb0, 0x0d, 0x63, 0x0b, 0xc6, 0x1b, 0xfc, 0xbb, 0xd2, 0xc6, 0x3b, 0x6d, 0x5c, + 0x69, 0x87, 0xf5, 0x11, 0xe1, 0x06, 0x04, 0xf1, 0xcc, 0x65, 0xae, 0xac, 0x1a, 0xfe, 0x3f, 0x4c, + 0xe0, 0x26, 0xcf, 0x8d, 0x2e, 0x75, 0xa3, 0x97, 0x36, 0xfa, 0x7b, 0x2b, 0x9d, 0x07, 0x99, 0x04, + 0x2f, 0x16, 0xb3, 0x9a, 0x3e, 0x03, 0x26, 0xb5, 0xd4, 0xe9, 0x8c, 0xb9, 0x0c, 0x2b, 0xf4, 0x63, + 0x6d, 0x20, 0x13, 0x30, 0x07, 0x61, 0x95, 0xe4, 0xcc, 0x0d, 0x5a, 0x7f, 0x5a, 0xe3, 0xde, 0x64, + 0x1a, 0x1f, 0xec, 0x2e, 0x3e, 0x6a, 0x4b, 0x4b, 0x9f, 0x07, 0x2b, 0x38, 0x8d, 0x4a, 0xef, 0xdd, + 0x27, 0x87, 0x17, 0xa8, 0x67, 0x5d, 0x43, 0xfa, 0x72, 0x3a, 0x12, 0xb2, 0xae, 0xa6, 0x3c, 0xa1, + 0x88, 0xaf, 0xa4, 0x58, 0x36, 0xa0, 0xaf, 0xa7, 0x03, 0xf5, 0x0b, 0xeb, 0x9a, 0x35, 0x45, 0x08, + 0x82, 0x9e, 0x5b, 0xa3, 0x24, 0xdf, 0x0c, 0xda, 0x05, 0xe7, 0xdf, 0x11, 0x0e, 0x0d, 0xfa, 0xbe, + 0xd0, 0xd2, 0x2e, 0x54, 0x21, 0x06, 0x84, 0xc5, 0x33, 0x53, 0x81, 0x79, 0xb3, 0x77, 0x0f, 0x9d, + 0xd3, 0x35, 0xfd, 0xb3, 0xb6, 0xaf, 0x1a, 0x1f, 0xbd, 0xb6, 0xd0, 0xf8, 0xb3, 0xf5, 0x78, 0x88, + 0xbe, 0xbf, 0x7b, 0x1e, 0x1d, 0x5a, 0xe7, 0xf8, 0x17, 0xea, 0xc8, 0x9c, 0xa5, 0xa2, 0xb8, 0xcd, + 0x2e, 0x2d, 0x13, 0x7c, 0x89, 0xba, 0x20, 0x9c, 0x09, 0xc0, 0x45, 0x35, 0xfe, 0x41, 0xf3, 0x27, + 0xdb, 0x47, 0x1d, 0xd3, 0xea, 0x9c, 0x36, 0x52, 0x7c, 0x87, 0x22, 0x10, 0xce, 0x33, 0xf0, 0xfb, + 0x33, 0x8d, 0x26, 0xe3, 0x63, 0x33, 0x2d, 0x0b, 0x76, 0x73, 0xed, 0xc3, 0x7e, 0x7a, 0x73, 0xfd, + 0x78, 0x95, 0x4a, 0xbf, 0x0a, 0x49, 0xcc, 0x4d, 0x4e, 0x0a, 0x13, 0x03, 0x69, 0x19, 0x90, 0x7a, + 0x95, 0x52, 0xa1, 0x89, 0x4d, 0xce, 0x53, 0x43, 0x3e, 0x6e, 0x57, 0xf2, 0xad, 0x58, 0xa7, 0x8b, + 0xb7, 0x00, 0x00, 0x00, 0xff, 0xff, 0x74, 0xcf, 0x60, 0xf4, 0xda, 0x03, 0x00, 0x00, } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go index 098b4dc7cf..397b3a813b 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go @@ -84,6 +84,16 @@ func (m *DistributedTensorflowTrainingTask) Validate() error { } } + if v, ok := interface{}(m.GetEvaluatorReplicas()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return DistributedTensorflowTrainingTaskValidationError{ + field: "EvaluatorReplicas", + reason: "embedded message failed validation", + cause: err, + } + } + } + return nil } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go index d9f1006792..a07ff3feeb 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go @@ -22,11 +22,15 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package // Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator type DistributedTensorflowTrainingTask struct { - // number of worker, ps, chief replicas spawned in the cluster for this job + // number of worker replicas spawned in the cluster for this job Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` // PS -> Parameter server - PsReplicas int32 `protobuf:"varint,2,opt,name=ps_replicas,json=psReplicas,proto3" json:"ps_replicas,omitempty"` - ChiefReplicas int32 `protobuf:"varint,3,opt,name=chief_replicas,json=chiefReplicas,proto3" json:"chief_replicas,omitempty"` + // number of ps replicas spawned in the cluster for this job + PsReplicas int32 `protobuf:"varint,2,opt,name=ps_replicas,json=psReplicas,proto3" json:"ps_replicas,omitempty"` + // number of chief replicas spawned in the cluster for this job + ChiefReplicas int32 `protobuf:"varint,3,opt,name=chief_replicas,json=chiefReplicas,proto3" json:"chief_replicas,omitempty"` + // number of evaluator replicas spawned in the cluster for this job + EvaluatorReplicas int32 `protobuf:"varint,4,opt,name=evaluator_replicas,json=evaluatorReplicas,proto3" json:"evaluator_replicas,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -78,6 +82,13 @@ func (m *DistributedTensorflowTrainingTask) GetChiefReplicas() int32 { return 0 } +func (m *DistributedTensorflowTrainingTask) GetEvaluatorReplicas() int32 { + if m != nil { + return m.EvaluatorReplicas + } + return 0 +} + func init() { proto.RegisterType((*DistributedTensorflowTrainingTask)(nil), "flyteidl.plugins.DistributedTensorflowTrainingTask") } @@ -85,18 +96,19 @@ func init() { func init() { proto.RegisterFile("flyteidl/plugins/tensorflow.proto", fileDescriptor_8da02783614e1bcc) } var fileDescriptor_8da02783614e1bcc = []byte{ - // 203 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x8f, 0xc1, 0x4a, 0xc4, 0x30, - 0x10, 0x86, 0xa9, 0xa2, 0x42, 0x44, 0x91, 0x9c, 0x7a, 0xd3, 0x0a, 0x82, 0x17, 0x9b, 0x83, 0x67, - 0x11, 0xc4, 0x27, 0x28, 0x3d, 0x79, 0x91, 0xa6, 0x4d, 0xd3, 0xa1, 0xd9, 0x4c, 0x98, 0x49, 0x29, - 0xfb, 0x00, 0xfb, 0xde, 0x0b, 0xd9, 0xb6, 0x0b, 0x7b, 0x9b, 0xf9, 0xe7, 0x1b, 0xf8, 0x3f, 0x51, - 0xf4, 0x6e, 0x1f, 0x0d, 0x74, 0x4e, 0x05, 0x37, 0x59, 0xf0, 0xac, 0xa2, 0xf1, 0x8c, 0xd4, 0x3b, - 0x9c, 0xcb, 0x40, 0x18, 0x51, 0x3e, 0xad, 0x48, 0xb9, 0x20, 0xaf, 0x87, 0x4c, 0x14, 0xbf, 0xc0, - 0x91, 0x40, 0x4f, 0xd1, 0x74, 0xf5, 0xf6, 0x51, 0x53, 0x03, 0x1e, 0xbc, 0xad, 0x1b, 0x1e, 0x65, - 0x2e, 0xee, 0x66, 0xa4, 0xd1, 0x10, 0xe7, 0xd9, 0x4b, 0xf6, 0x7e, 0x53, 0xad, 0xab, 0x7c, 0x16, - 0xf7, 0x81, 0xff, 0xc9, 0x04, 0x07, 0x6d, 0xc3, 0xf9, 0x55, 0xba, 0x8a, 0xc0, 0xd5, 0x92, 0xc8, - 0x37, 0xf1, 0xd8, 0x0e, 0x60, 0xfa, 0x33, 0x73, 0x9d, 0x98, 0x87, 0x94, 0xae, 0xd8, 0xcf, 0xf7, - 0xdf, 0x97, 0x85, 0x38, 0x4c, 0xba, 0x6c, 0x71, 0xa7, 0x52, 0x4d, 0x24, 0x7b, 0x1a, 0xd4, 0x26, - 0x66, 0x8d, 0x57, 0x41, 0x7f, 0x58, 0x54, 0x97, 0xae, 0xfa, 0x36, 0x19, 0x7e, 0x1e, 0x03, 0x00, - 0x00, 0xff, 0xff, 0x55, 0x40, 0x42, 0x06, 0x06, 0x01, 0x00, 0x00, + // 220 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0xd0, 0xbf, 0x4a, 0x04, 0x31, + 0x10, 0x06, 0x70, 0xd6, 0xbf, 0x10, 0x51, 0x34, 0xd5, 0x76, 0x7a, 0x82, 0x60, 0x73, 0x9b, 0xc2, + 0x5a, 0x04, 0xf1, 0x09, 0x8e, 0xad, 0x6c, 0x24, 0xbb, 0x37, 0x9b, 0x1b, 0x2e, 0x66, 0xc2, 0x4c, + 0xe2, 0xe1, 0x7b, 0xf9, 0x80, 0x42, 0xbc, 0xcd, 0x81, 0x5d, 0xf2, 0x7d, 0xbf, 0x14, 0xf9, 0xd4, + 0x62, 0xf2, 0xdf, 0x09, 0x70, 0xed, 0x4d, 0xf4, 0xd9, 0x61, 0x10, 0x93, 0x20, 0x08, 0xf1, 0xe4, + 0x69, 0xd7, 0x45, 0xa6, 0x44, 0xfa, 0x7a, 0x26, 0xdd, 0x9e, 0xdc, 0xff, 0x34, 0x6a, 0xf1, 0x86, + 0x92, 0x18, 0x87, 0x9c, 0x60, 0xdd, 0xd7, 0x17, 0x3d, 0x5b, 0x0c, 0x18, 0x5c, 0x6f, 0x65, 0xab, + 0x5b, 0x75, 0xbe, 0x23, 0xde, 0x02, 0x4b, 0xdb, 0xdc, 0x35, 0x8f, 0xa7, 0xab, 0xf9, 0xaa, 0x6f, + 0xd5, 0x45, 0x94, 0x0f, 0x86, 0xe8, 0x71, 0xb4, 0xd2, 0x1e, 0x95, 0x56, 0x45, 0x59, 0xed, 0x13, + 0xfd, 0xa0, 0xae, 0xc6, 0x0d, 0xc2, 0x74, 0x30, 0xc7, 0xc5, 0x5c, 0x96, 0xb4, 0xb2, 0xa5, 0xd2, + 0xf0, 0x65, 0x7d, 0xb6, 0x89, 0xf8, 0x40, 0x4f, 0x0a, 0xbd, 0xa9, 0xcd, 0xcc, 0x5f, 0x5f, 0xde, + 0x9f, 0x1d, 0xa6, 0x4d, 0x1e, 0xba, 0x91, 0x3e, 0x4d, 0xf9, 0x15, 0xb1, 0xfb, 0x3b, 0x98, 0xba, + 0x83, 0x83, 0x60, 0xe2, 0xb0, 0x74, 0x64, 0xfe, 0x4f, 0x33, 0x9c, 0x95, 0x41, 0x9e, 0x7e, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xa6, 0x22, 0x34, 0xcf, 0x35, 0x01, 0x00, 0x00, } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go index ed7a8eeb80..00db969ce6 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go @@ -50,6 +50,8 @@ func (m *DistributedTensorflowTrainingTask) Validate() error { // no validation rules for ChiefReplicas + // no validation rules for EvaluatorReplicas + return nil } diff --git a/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java b/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java index 14ab5db6bd..2353f8b78b 100644 --- a/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java +++ b/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java @@ -20,7 +20,7 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends /** *
-     * number of worker, ps, chief replicas spawned in the cluster for this job
+     * number of worker replicas spawned in the cluster for this job
      * 
* * int32 workers = 1; @@ -30,6 +30,7 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends /** *
      * PS -> Parameter server
+     * number of ps replicas spawned in the cluster for this job
      * 
* * int32 ps_replicas = 2; @@ -37,9 +38,22 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends int getPsReplicas(); /** + *
+     * number of chief replicas spawned in the cluster for this job
+     * 
+ * * int32 chief_replicas = 3; */ int getChiefReplicas(); + + /** + *
+     * number of evaluator replicas spawned in the cluster for this job
+     * 
+ * + * int32 evaluator_replicas = 4; + */ + int getEvaluatorReplicas(); } /** *
@@ -99,6 +113,11 @@ private DistributedTensorflowTrainingTask(
               chiefReplicas_ = input.readInt32();
               break;
             }
+            case 32: {
+
+              evaluatorReplicas_ = input.readInt32();
+              break;
+            }
             default: {
               if (!parseUnknownField(
                   input, unknownFields, extensionRegistry, tag)) {
@@ -135,7 +154,7 @@ private DistributedTensorflowTrainingTask(
     private int workers_;
     /**
      * 
-     * number of worker, ps, chief replicas spawned in the cluster for this job
+     * number of worker replicas spawned in the cluster for this job
      * 
* * int32 workers = 1; @@ -149,6 +168,7 @@ public int getWorkers() { /** *
      * PS -> Parameter server
+     * number of ps replicas spawned in the cluster for this job
      * 
* * int32 ps_replicas = 2; @@ -160,12 +180,29 @@ public int getPsReplicas() { public static final int CHIEF_REPLICAS_FIELD_NUMBER = 3; private int chiefReplicas_; /** + *
+     * number of chief replicas spawned in the cluster for this job
+     * 
+ * * int32 chief_replicas = 3; */ public int getChiefReplicas() { return chiefReplicas_; } + public static final int EVALUATOR_REPLICAS_FIELD_NUMBER = 4; + private int evaluatorReplicas_; + /** + *
+     * number of evaluator replicas spawned in the cluster for this job
+     * 
+ * + * int32 evaluator_replicas = 4; + */ + public int getEvaluatorReplicas() { + return evaluatorReplicas_; + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -189,6 +226,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (chiefReplicas_ != 0) { output.writeInt32(3, chiefReplicas_); } + if (evaluatorReplicas_ != 0) { + output.writeInt32(4, evaluatorReplicas_); + } unknownFields.writeTo(output); } @@ -210,6 +250,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeInt32Size(3, chiefReplicas_); } + if (evaluatorReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(4, evaluatorReplicas_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -231,6 +275,8 @@ public boolean equals(final java.lang.Object obj) { != other.getPsReplicas()) return false; if (getChiefReplicas() != other.getChiefReplicas()) return false; + if (getEvaluatorReplicas() + != other.getEvaluatorReplicas()) return false; if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -248,6 +294,8 @@ public int hashCode() { hash = (53 * hash) + getPsReplicas(); hash = (37 * hash) + CHIEF_REPLICAS_FIELD_NUMBER; hash = (53 * hash) + getChiefReplicas(); + hash = (37 * hash) + EVALUATOR_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getEvaluatorReplicas(); hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -391,6 +439,8 @@ public Builder clear() { chiefReplicas_ = 0; + evaluatorReplicas_ = 0; + return this; } @@ -420,6 +470,7 @@ public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask buildPartia result.workers_ = workers_; result.psReplicas_ = psReplicas_; result.chiefReplicas_ = chiefReplicas_; + result.evaluatorReplicas_ = evaluatorReplicas_; onBuilt(); return result; } @@ -477,6 +528,9 @@ public Builder mergeFrom(flyteidl.plugins.Tensorflow.DistributedTensorflowTraini if (other.getChiefReplicas() != 0) { setChiefReplicas(other.getChiefReplicas()); } + if (other.getEvaluatorReplicas() != 0) { + setEvaluatorReplicas(other.getEvaluatorReplicas()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -509,7 +563,7 @@ public Builder mergeFrom( private int workers_ ; /** *
-       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * number of worker replicas spawned in the cluster for this job
        * 
* * int32 workers = 1; @@ -519,7 +573,7 @@ public int getWorkers() { } /** *
-       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * number of worker replicas spawned in the cluster for this job
        * 
* * int32 workers = 1; @@ -532,7 +586,7 @@ public Builder setWorkers(int value) { } /** *
-       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * number of worker replicas spawned in the cluster for this job
        * 
* * int32 workers = 1; @@ -548,6 +602,7 @@ public Builder clearWorkers() { /** *
        * PS -> Parameter server
+       * number of ps replicas spawned in the cluster for this job
        * 
* * int32 ps_replicas = 2; @@ -558,6 +613,7 @@ public int getPsReplicas() { /** *
        * PS -> Parameter server
+       * number of ps replicas spawned in the cluster for this job
        * 
* * int32 ps_replicas = 2; @@ -571,6 +627,7 @@ public Builder setPsReplicas(int value) { /** *
        * PS -> Parameter server
+       * number of ps replicas spawned in the cluster for this job
        * 
* * int32 ps_replicas = 2; @@ -584,12 +641,20 @@ public Builder clearPsReplicas() { private int chiefReplicas_ ; /** + *
+       * number of chief replicas spawned in the cluster for this job
+       * 
+ * * int32 chief_replicas = 3; */ public int getChiefReplicas() { return chiefReplicas_; } /** + *
+       * number of chief replicas spawned in the cluster for this job
+       * 
+ * * int32 chief_replicas = 3; */ public Builder setChiefReplicas(int value) { @@ -599,6 +664,10 @@ public Builder setChiefReplicas(int value) { return this; } /** + *
+       * number of chief replicas spawned in the cluster for this job
+       * 
+ * * int32 chief_replicas = 3; */ public Builder clearChiefReplicas() { @@ -607,6 +676,44 @@ public Builder clearChiefReplicas() { onChanged(); return this; } + + private int evaluatorReplicas_ ; + /** + *
+       * number of evaluator replicas spawned in the cluster for this job
+       * 
+ * + * int32 evaluator_replicas = 4; + */ + public int getEvaluatorReplicas() { + return evaluatorReplicas_; + } + /** + *
+       * number of evaluator replicas spawned in the cluster for this job
+       * 
+ * + * int32 evaluator_replicas = 4; + */ + public Builder setEvaluatorReplicas(int value) { + + evaluatorReplicas_ = value; + onChanged(); + return this; + } + /** + *
+       * number of evaluator replicas spawned in the cluster for this job
+       * 
+ * + * int32 evaluator_replicas = 4; + */ + public Builder clearEvaluatorReplicas() { + + evaluatorReplicas_ = 0; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -675,11 +782,12 @@ public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask getDefaultI static { java.lang.String[] descriptorData = { "\n!flyteidl/plugins/tensorflow.proto\022\020fly" + - "teidl.plugins\"a\n!DistributedTensorflowTr" + + "teidl.plugins\"}\n!DistributedTensorflowTr" + "ainingTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013ps_replic" + - "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005B?Z=gith" + - "ub.com/flyteorg/flyte/flyteidl/gen/pb-go" + - "/flyteidl/pluginsb\006proto3" + "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005\022\032\n\022eval" + + "uator_replicas\030\004 \001(\005B?Z=github.com/flyte" + + "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" + + "uginsb\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { @@ -698,7 +806,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor, - new java.lang.String[] { "Workers", "PsReplicas", "ChiefReplicas", }); + new java.lang.String[] { "Workers", "PsReplicas", "ChiefReplicas", "EvaluatorReplicas", }); } // @@protoc_insertion_point(outer_class_scope) diff --git a/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java b/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java index ec763705c8..324f3e5d01 100644 --- a/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java +++ b/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java @@ -123,6 +123,31 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends * .flyteidl.plugins.kubeflow.RunPolicy run_policy = 4; */ flyteidl.plugins.kubeflow.Common.RunPolicyOrBuilder getRunPolicyOrBuilder(); + + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + boolean hasEvaluatorReplicas(); + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec getEvaluatorReplicas(); + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder getEvaluatorReplicasOrBuilder(); } /** *
@@ -219,6 +244,19 @@ private DistributedTensorflowTrainingTask(
 
               break;
             }
+            case 42: {
+              flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder subBuilder = null;
+              if (evaluatorReplicas_ != null) {
+                subBuilder = evaluatorReplicas_.toBuilder();
+              }
+              evaluatorReplicas_ = input.readMessage(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.parser(), extensionRegistry);
+              if (subBuilder != null) {
+                subBuilder.mergeFrom(evaluatorReplicas_);
+                evaluatorReplicas_ = subBuilder.buildPartial();
+              }
+
+              break;
+            }
             default: {
               if (!parseUnknownField(
                   input, unknownFields, extensionRegistry, tag)) {
@@ -389,6 +427,39 @@ public flyteidl.plugins.kubeflow.Common.RunPolicyOrBuilder getRunPolicyOrBuilder
       return getRunPolicy();
     }
 
+    public static final int EVALUATOR_REPLICAS_FIELD_NUMBER = 5;
+    private flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec evaluatorReplicas_;
+    /**
+     * 
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public boolean hasEvaluatorReplicas() { + return evaluatorReplicas_ != null; + } + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec getEvaluatorReplicas() { + return evaluatorReplicas_ == null ? flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.getDefaultInstance() : evaluatorReplicas_; + } + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder getEvaluatorReplicasOrBuilder() { + return getEvaluatorReplicas(); + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -415,6 +486,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (runPolicy_ != null) { output.writeMessage(4, getRunPolicy()); } + if (evaluatorReplicas_ != null) { + output.writeMessage(5, getEvaluatorReplicas()); + } unknownFields.writeTo(output); } @@ -440,6 +514,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(4, getRunPolicy()); } + if (evaluatorReplicas_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(5, getEvaluatorReplicas()); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -475,6 +553,11 @@ public boolean equals(final java.lang.Object obj) { if (!getRunPolicy() .equals(other.getRunPolicy())) return false; } + if (hasEvaluatorReplicas() != other.hasEvaluatorReplicas()) return false; + if (hasEvaluatorReplicas()) { + if (!getEvaluatorReplicas() + .equals(other.getEvaluatorReplicas())) return false; + } if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -502,6 +585,10 @@ public int hashCode() { hash = (37 * hash) + RUN_POLICY_FIELD_NUMBER; hash = (53 * hash) + getRunPolicy().hashCode(); } + if (hasEvaluatorReplicas()) { + hash = (37 * hash) + EVALUATOR_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getEvaluatorReplicas().hashCode(); + } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -663,6 +750,12 @@ public Builder clear() { runPolicy_ = null; runPolicyBuilder_ = null; } + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicas_ = null; + } else { + evaluatorReplicas_ = null; + evaluatorReplicasBuilder_ = null; + } return this; } @@ -709,6 +802,11 @@ public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingTask bu } else { result.runPolicy_ = runPolicyBuilder_.build(); } + if (evaluatorReplicasBuilder_ == null) { + result.evaluatorReplicas_ = evaluatorReplicas_; + } else { + result.evaluatorReplicas_ = evaluatorReplicasBuilder_.build(); + } onBuilt(); return result; } @@ -769,6 +867,9 @@ public Builder mergeFrom(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorf if (other.hasRunPolicy()) { mergeRunPolicy(other.getRunPolicy()); } + if (other.hasEvaluatorReplicas()) { + mergeEvaluatorReplicas(other.getEvaluatorReplicas()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -1427,6 +1528,159 @@ public flyteidl.plugins.kubeflow.Common.RunPolicyOrBuilder getRunPolicyOrBuilder } return runPolicyBuilder_; } + + private flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec evaluatorReplicas_; + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder> evaluatorReplicasBuilder_; + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public boolean hasEvaluatorReplicas() { + return evaluatorReplicasBuilder_ != null || evaluatorReplicas_ != null; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec getEvaluatorReplicas() { + if (evaluatorReplicasBuilder_ == null) { + return evaluatorReplicas_ == null ? flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.getDefaultInstance() : evaluatorReplicas_; + } else { + return evaluatorReplicasBuilder_.getMessage(); + } + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder setEvaluatorReplicas(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec value) { + if (evaluatorReplicasBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + evaluatorReplicas_ = value; + onChanged(); + } else { + evaluatorReplicasBuilder_.setMessage(value); + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder setEvaluatorReplicas( + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder builderForValue) { + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicas_ = builderForValue.build(); + onChanged(); + } else { + evaluatorReplicasBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder mergeEvaluatorReplicas(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec value) { + if (evaluatorReplicasBuilder_ == null) { + if (evaluatorReplicas_ != null) { + evaluatorReplicas_ = + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.newBuilder(evaluatorReplicas_).mergeFrom(value).buildPartial(); + } else { + evaluatorReplicas_ = value; + } + onChanged(); + } else { + evaluatorReplicasBuilder_.mergeFrom(value); + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder clearEvaluatorReplicas() { + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicas_ = null; + onChanged(); + } else { + evaluatorReplicas_ = null; + evaluatorReplicasBuilder_ = null; + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder getEvaluatorReplicasBuilder() { + + onChanged(); + return getEvaluatorReplicasFieldBuilder().getBuilder(); + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder getEvaluatorReplicasOrBuilder() { + if (evaluatorReplicasBuilder_ != null) { + return evaluatorReplicasBuilder_.getMessageOrBuilder(); + } else { + return evaluatorReplicas_ == null ? + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.getDefaultInstance() : evaluatorReplicas_; + } + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder> + getEvaluatorReplicasFieldBuilder() { + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicasBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder>( + getEvaluatorReplicas(), + getParentForChildren(), + isClean()); + evaluatorReplicas_ = null; + } + return evaluatorReplicasBuilder_; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -2553,7 +2807,7 @@ public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplica "\n*flyteidl/plugins/kubeflow/tensorflow.p" + "roto\022\031flyteidl.plugins.kubeflow\032\031flyteid" + "l/core/tasks.proto\032&flyteidl/plugins/kub" + - "eflow/common.proto\"\362\002\n!DistributedTensor" + + "eflow/common.proto\"\323\003\n!DistributedTensor" + "flowTrainingTask\022\\\n\017worker_replicas\030\001 \001(" + "\0132C.flyteidl.plugins.kubeflow.Distribute" + "dTensorflowTrainingReplicaSpec\022X\n\013ps_rep" + @@ -2562,14 +2816,16 @@ public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplica "\022[\n\016chief_replicas\030\003 \001(\0132C.flyteidl.plug" + "ins.kubeflow.DistributedTensorflowTraini" + "ngReplicaSpec\0228\n\nrun_policy\030\004 \001(\0132$.flyt" + - "eidl.plugins.kubeflow.RunPolicy\"\272\001\n(Dist" + - "ributedTensorflowTrainingReplicaSpec\022\020\n\010" + - "replicas\030\001 \001(\005\022\r\n\005image\030\002 \001(\t\022+\n\tresourc" + - "es\030\003 \001(\0132\030.flyteidl.core.Resources\022@\n\016re" + - "start_policy\030\004 \001(\0162(.flyteidl.plugins.ku" + - "beflow.RestartPolicyB?Z=github.com/flyte" + - "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" + - "uginsb\006proto3" + "eidl.plugins.kubeflow.RunPolicy\022_\n\022evalu" + + "ator_replicas\030\005 \001(\0132C.flyteidl.plugins.k" + + "ubeflow.DistributedTensorflowTrainingRep" + + "licaSpec\"\272\001\n(DistributedTensorflowTraini" + + "ngReplicaSpec\022\020\n\010replicas\030\001 \001(\005\022\r\n\005image" + + "\030\002 \001(\t\022+\n\tresources\030\003 \001(\0132\030.flyteidl.cor" + + "e.Resources\022@\n\016restart_policy\030\004 \001(\0162(.fl" + + "yteidl.plugins.kubeflow.RestartPolicyB?Z" + + "=github.com/flyteorg/flyte/flyteidl/gen/" + + "pb-go/flyteidl/pluginsb\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { @@ -2590,7 +2846,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingTask_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingTask_descriptor, - new java.lang.String[] { "WorkerReplicas", "PsReplicas", "ChiefReplicas", "RunPolicy", }); + new java.lang.String[] { "WorkerReplicas", "PsReplicas", "ChiefReplicas", "RunPolicy", "EvaluatorReplicas", }); internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingReplicaSpec_descriptor = getDescriptor().getMessageTypes().get(1); internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingReplicaSpec_fieldAccessorTable = new diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py index 15f0d96558..f0c086f9e7 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py @@ -15,7 +15,7 @@ from flyteidl.plugins.kubeflow import common_pb2 as flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*flyteidl/plugins/kubeflow/tensorflow.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\xa8\x03\n!DistributedTensorflowTrainingTask\x12l\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x0eworkerReplicas\x12\x64\n\x0bps_replicas\x18\x02 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\npsReplicas\x12j\n\x0e\x63hief_replicas\x18\x03 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\rchiefReplicas\x12\x43\n\nrun_policy\x18\x04 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\"\xe5\x01\n(DistributedTensorflowTrainingReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x0erestart_policy\x18\x04 \x01(\x0e\x32(.flyteidl.plugins.kubeflow.RestartPolicyR\rrestartPolicyB\xf5\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*flyteidl/plugins/kubeflow/tensorflow.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\x9c\x04\n!DistributedTensorflowTrainingTask\x12l\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x0eworkerReplicas\x12\x64\n\x0bps_replicas\x18\x02 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\npsReplicas\x12j\n\x0e\x63hief_replicas\x18\x03 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\rchiefReplicas\x12\x43\n\nrun_policy\x18\x04 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12r\n\x12\x65valuator_replicas\x18\x05 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x11\x65valuatorReplicas\"\xe5\x01\n(DistributedTensorflowTrainingReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x0erestart_policy\x18\x04 \x01(\x0e\x32(.flyteidl.plugins.kubeflow.RestartPolicyR\rrestartPolicyB\xf5\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,7 +25,7 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\035com.flyteidl.plugins.kubeflowB\017TensorflowProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_start=141 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=565 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_start=568 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_end=797 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=681 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_start=684 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_end=913 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi index e08a1ff983..4a999f70e8 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi @@ -7,16 +7,18 @@ from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Opti DESCRIPTOR: _descriptor.FileDescriptor class DistributedTensorflowTrainingTask(_message.Message): - __slots__ = ["worker_replicas", "ps_replicas", "chief_replicas", "run_policy"] + __slots__ = ["worker_replicas", "ps_replicas", "chief_replicas", "run_policy", "evaluator_replicas"] WORKER_REPLICAS_FIELD_NUMBER: _ClassVar[int] PS_REPLICAS_FIELD_NUMBER: _ClassVar[int] CHIEF_REPLICAS_FIELD_NUMBER: _ClassVar[int] RUN_POLICY_FIELD_NUMBER: _ClassVar[int] + EVALUATOR_REPLICAS_FIELD_NUMBER: _ClassVar[int] worker_replicas: DistributedTensorflowTrainingReplicaSpec ps_replicas: DistributedTensorflowTrainingReplicaSpec chief_replicas: DistributedTensorflowTrainingReplicaSpec run_policy: _common_pb2.RunPolicy - def __init__(self, worker_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., ps_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., chief_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., run_policy: _Optional[_Union[_common_pb2.RunPolicy, _Mapping]] = ...) -> None: ... + evaluator_replicas: DistributedTensorflowTrainingReplicaSpec + def __init__(self, worker_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., ps_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., chief_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., run_policy: _Optional[_Union[_common_pb2.RunPolicy, _Mapping]] = ..., evaluator_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ...) -> None: ... class DistributedTensorflowTrainingReplicaSpec(_message.Message): __slots__ = ["replicas", "image", "resources", "restart_policy"] diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py index 187526d6d5..ceed4231bb 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n!flyteidl/plugins/tensorflow.proto\x12\x10\x66lyteidl.plugins\"\x85\x01\n!DistributedTensorflowTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12\x1f\n\x0bps_replicas\x18\x02 \x01(\x05R\npsReplicas\x12%\n\x0e\x63hief_replicas\x18\x03 \x01(\x05R\rchiefReplicasB\xc7\x01\n\x14\x63om.flyteidl.pluginsB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n!flyteidl/plugins/tensorflow.proto\x12\x10\x66lyteidl.plugins\"\xb4\x01\n!DistributedTensorflowTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12\x1f\n\x0bps_replicas\x18\x02 \x01(\x05R\npsReplicas\x12%\n\x0e\x63hief_replicas\x18\x03 \x01(\x05R\rchiefReplicas\x12-\n\x12\x65valuator_replicas\x18\x04 \x01(\x05R\x11\x65valuatorReplicasB\xc7\x01\n\x14\x63om.flyteidl.pluginsB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,5 +23,5 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\024com.flyteidl.pluginsB\017TensorflowProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPX\252\002\020Flyteidl.Plugins\312\002\020Flyteidl\\Plugins\342\002\034Flyteidl\\Plugins\\GPBMetadata\352\002\021Flyteidl::Plugins' _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_start=56 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=189 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=236 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi index d3dc028af3..81e2bc30b9 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi @@ -5,11 +5,13 @@ from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor class DistributedTensorflowTrainingTask(_message.Message): - __slots__ = ["workers", "ps_replicas", "chief_replicas"] + __slots__ = ["workers", "ps_replicas", "chief_replicas", "evaluator_replicas"] WORKERS_FIELD_NUMBER: _ClassVar[int] PS_REPLICAS_FIELD_NUMBER: _ClassVar[int] CHIEF_REPLICAS_FIELD_NUMBER: _ClassVar[int] + EVALUATOR_REPLICAS_FIELD_NUMBER: _ClassVar[int] workers: int ps_replicas: int chief_replicas: int - def __init__(self, workers: _Optional[int] = ..., ps_replicas: _Optional[int] = ..., chief_replicas: _Optional[int] = ...) -> None: ... + evaluator_replicas: int + def __init__(self, workers: _Optional[int] = ..., ps_replicas: _Optional[int] = ..., chief_replicas: _Optional[int] = ..., evaluator_replicas: _Optional[int] = ...) -> None: ... diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs index 59c1f681a0..96d46653da 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs @@ -182,6 +182,9 @@ pub struct DistributedTensorflowTrainingTask { /// active. #[prost(message, optional, tag="4")] pub run_policy: ::core::option::Option, + /// Evaluator replicas spec + #[prost(message, optional, tag="5")] + pub evaluator_replicas: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.rs index 5c7873b5d2..11e4ad05af 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.rs @@ -290,14 +290,19 @@ pub struct SparkJob { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistributedTensorflowTrainingTask { - /// number of worker, ps, chief replicas spawned in the cluster for this job + /// number of worker replicas spawned in the cluster for this job #[prost(int32, tag="1")] pub workers: i32, /// PS -> Parameter server + /// number of ps replicas spawned in the cluster for this job #[prost(int32, tag="2")] pub ps_replicas: i32, + /// number of chief replicas spawned in the cluster for this job #[prost(int32, tag="3")] pub chief_replicas: i32, + /// number of evaluator replicas spawned in the cluster for this job + #[prost(int32, tag="4")] + pub evaluator_replicas: i32, } /// Represents an Execution that was launched and could be waited on. #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto index 4cf3153548..789666b989 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto @@ -22,6 +22,9 @@ message DistributedTensorflowTrainingTask { // job, for example how to clean up resources and how long the job can stay // active. RunPolicy run_policy = 4; + + // Evaluator replicas spec + DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; } message DistributedTensorflowTrainingReplicaSpec { diff --git a/flyteidl/protos/flyteidl/plugins/tensorflow.proto b/flyteidl/protos/flyteidl/plugins/tensorflow.proto index e768ad65ad..e494a6cc32 100644 --- a/flyteidl/protos/flyteidl/plugins/tensorflow.proto +++ b/flyteidl/protos/flyteidl/plugins/tensorflow.proto @@ -6,9 +6,13 @@ option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugi // Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator message DistributedTensorflowTrainingTask { - // number of worker, ps, chief replicas spawned in the cluster for this job + // number of worker replicas spawned in the cluster for this job int32 workers = 1; // PS -> Parameter server + // number of ps replicas spawned in the cluster for this job int32 ps_replicas = 2; + // number of chief replicas spawned in the cluster for this job int32 chief_replicas = 3; + // number of evaluator replicas spawned in the cluster for this job + int32 evaluator_replicas = 4; } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 4e38d791c4..f6a9787dbb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -93,7 +93,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim // GetLogs will return the logs for kubeflow job func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v1.ObjectMeta, hasMaster bool, - workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) { + workersCount int32, psReplicasCount int32, chiefReplicasCount int32, evaluatorReplicasCount int32) ([]*core.TaskLog, error) { name := objectMeta.Name namespace := objectMeta.Namespace @@ -181,6 +181,18 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v } taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...) } + // get evaluator log, and the max number of evaluator is 1 + if evaluatorReplicasCount != 0 { + evaluatorReplicasCount, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0), + Namespace: namespace, + TaskExecutionIdentifier: &taskExecID, + }) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, evaluatorReplicasCount.TaskLogs...) + } return taskLogs, nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index 4f5d70dc5c..6c083b6898 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -167,7 +167,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "mpi-namespace", } - jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, false, workers, launcher, 0) + jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 1, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].Uri) @@ -176,7 +176,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "pytorch-namespace", } - jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, workers, launcher, 0) + jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].Uri) @@ -186,7 +186,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "tensorflow-namespace", } - jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, false, workers, launcher, 1) + jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, false, workers, launcher, 1, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].Uri) @@ -209,7 +209,7 @@ func TestGetLogsTemplateUri(t *testing.T) { Time: time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC), }, } - jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, 1, 0, 0) + jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, 1, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("https://console.cloud.google.com/logs/query;query=resource.labels.pod_name=%s-master-0×tamp>%s", "test", "2022-01-01T12:00:00Z"), jobLogs[0].Uri) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 492dd32235..25e45ad727 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -210,7 +210,7 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, false, - *numWorkers, *numLauncherReplicas, 0) + *numWorkers, *numLauncherReplicas, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 96d8efdcb8..b7c66b08f6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -410,7 +410,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots))) - jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0) + jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 07f3df0ef5..2461c7bc18 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -231,7 +231,7 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas - taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0) + taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index a1e37f5124..9407294971 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -455,7 +455,7 @@ func TestGetLogs(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers))) - jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0) + jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) @@ -475,7 +475,7 @@ func TestGetLogsElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers))) - jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0) + jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 578bd2a0d6..1c4c965819 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -80,6 +80,11 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task PodSpec: podSpec.DeepCopy(), RestartPolicy: commonOp.RestartPolicyNever, }, + kubeflowv1.TFJobReplicaTypeEval: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, } runPolicy := commonOp.RunPolicy{} @@ -94,6 +99,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = tensorflowTaskExtraArgs.GetChiefReplicas() replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = tensorflowTaskExtraArgs.GetWorkers() replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = tensorflowTaskExtraArgs.GetPsReplicas() + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].ReplicaNum = tensorflowTaskExtraArgs.GetEvaluatorReplicas() } else if taskTemplate.TaskTypeVersion == 1 { kfTensorflowTaskExtraArgs := kfplugins.DistributedTensorflowTrainingTask{} @@ -151,6 +157,22 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = psReplicaSpec.GetReplicas() } + evaluatorReplicaSpec := kfTensorflowTaskExtraArgs.GetEvaluatorReplicas() + if evaluatorReplicaSpec != nil { + err := common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + evaluatorReplicaSpec.GetImage(), + evaluatorReplicaSpec.GetResources(), + nil, + ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].RestartPolicy = common.ParseRestartPolicy(evaluatorReplicaSpec.GetRestartPolicy()) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].ReplicaNum = evaluatorReplicaSpec.GetReplicas() + } + if kfTensorflowTaskExtraArgs.GetRunPolicy() != nil { runPolicy = common.ParseRunPolicy(*kfTensorflowTaskExtraArgs.GetRunPolicy()) } @@ -207,9 +229,10 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginC workersCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas psReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas chiefCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas + evaluatorReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false, - *workersCount, *psReplicasCount, *chiefCount) + *workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 7ec1983bea..764be38181 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -70,11 +70,12 @@ var ( jobNamespace = "tensorflow-namespace" ) -func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int32) *plugins.DistributedTensorflowTrainingTask { +func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int32, evaluatorReplicas int32) *plugins.DistributedTensorflowTrainingTask { return &plugins.DistributedTensorflowTrainingTask{ - Workers: workers, - PsReplicas: psReplicas, - ChiefReplicas: chiefReplicas, + Workers: workers, + PsReplicas: psReplicas, + ChiefReplicas: chiefReplicas, + EvaluatorReplicas: evaluatorReplicas, } } @@ -175,7 +176,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.Tas } func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorResourceHandler, - workers int32, psReplicas int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { + workers int32, psReplicas int32, chiefReplicas int32, evaluatorReplicas int32, conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { var jobConditions []commonOp.JobCondition now := time.Now() @@ -274,7 +275,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso } } - tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas) + tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) if err != nil { @@ -300,7 +301,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso func TestBuildResourceTensorFlow(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - tfObj := dummyTensorFlowCustomObj(100, 50, 1) + tfObj := dummyTensorFlowCustomObj(100, 50, 1, 1) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) @@ -312,6 +313,7 @@ func TestBuildResourceTensorFlow(t *testing.T) { assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas) // verify TaskExecutionMetadata labels and annotations are copied to the TensorFlowJob for k, v := range dummyAnnotations { @@ -346,10 +348,10 @@ func TestGetTaskPhase(t *testing.T) { ctx := context.TODO() dummyTensorFlowJobResourceCreator := func(conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { - return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, conditionType) + return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, conditionType) } - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1))) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1))) taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -390,18 +392,20 @@ func TestGetLogs(t *testing.T) { workers := int32(2) psReplicas := int32(1) chiefReplicas := int32(1) + evaluatorReplicas := int32(1) tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning) - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas))) + tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, evaluatorReplicas, commonOp.JobRunning) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas))) jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, false, - workers, psReplicas, chiefReplicas) + workers, psReplicas, chiefReplicas, evaluatorReplicas) assert.NoError(t, err) - assert.Equal(t, 4, len(jobLogs)) + assert.Equal(t, 5, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[0].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[1].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-psReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[2].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[3].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-evaluatorReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[4].Uri) } func TestGetProperties(t *testing.T) { @@ -412,26 +416,31 @@ func TestGetProperties(t *testing.T) { func TestReplicaCounts(t *testing.T) { for _, test := range []struct { - name string - chiefReplicaCount int32 - psReplicaCount int32 - workerReplicaCount int32 - expectError bool - contains []commonOp.ReplicaType - notContains []commonOp.ReplicaType + name string + chiefReplicaCount int32 + psReplicaCount int32 + workerReplicaCount int32 + evaluatorReplicaCount int32 + expectError bool + contains []commonOp.ReplicaType + notContains []commonOp.ReplicaType }{ - {"NoWorkers", 1, 1, 0, true, nil, nil}, - {"SingleChief", 1, 0, 1, false, + {"NoWorkers", 1, 1, 0, 1, true, nil, nil}, + {"SingleChief", 1, 0, 1, 0, false, []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief, kubeflowv1.TFJobReplicaTypeWorker}, - []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS}}, - {"SinglePS", 0, 1, 1, false, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS, kubeflowv1.TFJobReplicaTypeEval}}, + {"SinglePS", 0, 1, 1, 0, false, []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS, kubeflowv1.TFJobReplicaTypeWorker}, - []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief}}, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief, kubeflowv1.TFJobReplicaTypeEval}}, + {"AllContains", 1, 1, 1, 1, false, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS, kubeflowv1.TFJobReplicaTypeWorker, kubeflowv1.TFJobReplicaTypeChief, kubeflowv1.TFJobReplicaTypeEval}, + nil, + }, } { t.Run(test.name, func(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount) + tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount, test.evaluatorReplicaCount) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) @@ -499,6 +508,21 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { }, }, }, + EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, RunPolicy: &kfplugins.RunPolicy{ CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, ActiveDeadlineSeconds: int32(100), @@ -534,6 +558,16 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { corev1.ResourceCPU: resource.MustParse("500m"), }, }, + kubeflowv1.TFJobReplicaTypeEval: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, } tensorflowResourceHandler := tensorflowOperatorResourceHandler{} @@ -550,6 +584,7 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas) for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { var hasContainerWithDefaultTensorFlowName = false