From 356ff698ac6d5072872d63b335f24cc1d65c83a0 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 26 Aug 2019 19:46:12 -0700 Subject: [PATCH] Added command line arguments for Horovod knob environment variables, config file, and new knobs for autotuning (#1345) --- docs/mpirun.rst | 20 +++ docs/tensor-fusion.rst | 15 +- docs/timeline.rst | 11 +- horovod/common/common.h | 4 + horovod/common/parameter_manager.cc | 36 +++-- horovod/common/parameter_manager.h | 10 +- horovod/common/utils/env_parser.cc | 5 + horovod/common/utils/env_parser.h | 2 + horovod/run/common/util/config_parser.py | 160 ++++++++++++++++++++ horovod/run/gloo_run.py | 7 +- horovod/run/mpi_run.py | 6 +- horovod/run/run.py | 172 ++++++++++++++++++++- setup.py | 2 +- test/data/config.test.yaml | 30 ++++ test/test_run.py | 184 +++++++++++++++++++++++ 15 files changed, 617 insertions(+), 47 deletions(-) create mode 100644 horovod/run/common/util/config_parser.py create mode 100644 test/data/config.test.yaml create mode 100644 test/test_run.py diff --git a/docs/mpirun.rst b/docs/mpirun.rst index 29b3d57f15..f9684df83b 100644 --- a/docs/mpirun.rst +++ b/docs/mpirun.rst @@ -94,6 +94,26 @@ example below: Other MPI RDMA implementations may or may not benefit from disabling multithreading, so please consult vendor documentation. +Horovod Parameter Knobs +----------------------- + +Many of the configurable parameters available as command line arguments to ``horovodrun`` can be used with ``mpirun`` +through the use of environment variables. + +Tensor Fusion: + +.. code-block:: bash + + $ mpirun -x HOROVOD_FUSION_THRESHOLD=33554432 -x HOROVOD_CYCLE_TIME=3.5 ... python train.py + +Timeline: + +.. code-block:: bash + + $ mpirun -x HOROVOD_TIMELINE=/path/to/timeline.json -x HOROVOD_TIMELINE_MARK_CYCLES=1 ... python train.py + +Note that when using ``horovodrun``, any command line arguments will override values set in the environment. + Hangs due to non-routed network interfaces ------------------------------------------ diff --git a/docs/tensor-fusion.rst b/docs/tensor-fusion.rst index db742c39ff..31b2d758cd 100644 --- a/docs/tensor-fusion.rst +++ b/docs/tensor-fusion.rst @@ -16,25 +16,22 @@ one reduction operation. The algorithm of Tensor Fusion is as follows: 5. Copy data from the fusion buffer into the output tensors. 6. Repeat until there are no more tensors to reduce in this cycle. -The fusion buffer size can be tweaked using the ``HOROVOD_FUSION_THRESHOLD`` environment variable: +The fusion buffer size can be adjusted using the ``--fusion-threshold-mb`` command line argument to ``horovodrun``: .. code-block:: bash - $ HOROVOD_FUSION_THRESHOLD=33554432 horovodrun -np 4 python train.py + $ horovodrun -np 4 --fusion-threshold-mb 32 python train.py - -Setting the ``HOROVOD_FUSION_THRESHOLD`` environment variable to zero disables Tensor Fusion: +Setting ``--fusion-threshold-mb`` to zero disables Tensor Fusion: .. code-block:: bash - $ HOROVOD_FUSION_THRESHOLD=0 horovodrun -np 4 python train.py - + $ horovodrun -np 4 --fusion-threshold-mb 0 python train.py -You can tweak time between cycles (defined in milliseconds) using the ``HOROVOD_CYCLE_TIME`` environment variable: +You can tweak time between cycles (defined in milliseconds) using the ``--cycle-time-ms`` command line argument: .. code-block:: bash - $ HOROVOD_CYCLE_TIME=3.5 horovodrun -np 4 python train.py - + $ horovodrun -np 4 --cycle-time-ms 3.5 python train.py .. 
inclusion-marker-end-do-not-remove diff --git a/docs/timeline.rst b/docs/timeline.rst index 3bb6b762a7..5f94f91036 100644 --- a/docs/timeline.rst +++ b/docs/timeline.rst @@ -9,12 +9,12 @@ Horovod has the ability to record the timeline of its activity, called Horovod T :alt: Horovod Timeline -To record a Horovod Timeline, set the ``HOROVOD_TIMELINE`` environment variable to the location of the timeline +To record a Horovod Timeline, set the ``--timeline-filename`` command line argument to the location of the timeline file to be created. This file is only recorded on rank 0, but it contains information about activity of all workers. .. code-block:: bash - $ HOROVOD_TIMELINE=/path/to/timeline.json horovodrun -np 4 python train.py + $ horovodrun -np 4 --timeline-filename /path/to/timeline.json python train.py You can then open the timeline file using the ``chrome://tracing`` facility of the `Chrome `__ browser. @@ -49,13 +49,10 @@ Horovod performs work in cycles. These cycles are used to aid `Tensor Fusion #include "logging.h" +#include "utils/env_parser.h" namespace horovod { namespace common { -#define WARMUPS 3 -#define CYCLES_PER_SAMPLE 10 -#define BAYES_OPT_MAX_SAMPLES 20 -#define GAUSSIAN_PROCESS_NOISE 0.8 +#define DEFAULT_WARMUPS 3 +#define DEFAULT_STEPS_PER_SAMPLE 10 +#define DEFAULT_BAYES_OPT_MAX_SAMPLES 20 +#define DEFAULT_GAUSSIAN_PROCESS_NOISE 0.8 Eigen::VectorXd CreateVector(double x1, double x2) { Eigen::VectorXd v(2); @@ -38,6 +39,8 @@ Eigen::VectorXd CreateVector(double x1, double x2) { // ParameterManager ParameterManager::ParameterManager() : + warmups_(GetIntEnvOrDefault(HOROVOD_AUTOTUNE_WARMUP_SAMPLES, DEFAULT_WARMUPS)), + steps_per_sample_(GetIntEnvOrDefault(HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE, DEFAULT_STEPS_PER_SAMPLE)), hierarchical_allreduce_(CategoricalParameter(std::vector{false, true})), hierarchical_allgather_(CategoricalParameter(std::vector{false, true})), cache_enabled_(CategoricalParameter(std::vector{false, true})), @@ -45,16 +48,19 @@ ParameterManager::ParameterManager() : std::vector{ { BayesianVariable::fusion_buffer_threshold_mb, std::pair(0, 64) }, { BayesianVariable::cycle_time_ms, std::pair(1, 100) } - }, std::vector{ + }, + std::vector{ CreateVector(4, 5), CreateVector(32, 50), CreateVector(16, 25), CreateVector(8, 10) - })), + }, + GetIntEnvOrDefault(HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES, DEFAULT_BAYES_OPT_MAX_SAMPLES), + GetDoubleEnvOrDefault(HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE, DEFAULT_GAUSSIAN_PROCESS_NOISE))), parameter_chain_(std::vector{&joint_params_, &hierarchical_allreduce_, &hierarchical_allgather_, &cache_enabled_}), active_(false), - warmup_remaining_(WARMUPS), + warmup_remaining_(warmups_), sample_(0), rank_(-1), root_rank_(0), @@ -80,7 +86,7 @@ void ParameterManager::Initialize(int32_t rank, int32_t root_rank, void ParameterManager::SetAutoTuning(bool active) { if (active != active_) { - warmup_remaining_ = WARMUPS; + warmup_remaining_ = warmups_; } active_ = active; }; @@ -140,8 +146,8 @@ bool ParameterManager::Update(const std::vector& tensor_names, } for (const std::string& tensor_name : tensor_names) { - int32_t cycle = tensor_counts_[tensor_name]++; - if (cycle >= (sample_ + 1) * CYCLES_PER_SAMPLE) { + int32_t step = tensor_counts_[tensor_name]++; + if (step >= (sample_ + 1) * steps_per_sample_) { auto now = std::chrono::steady_clock::now(); double duration = std::chrono::duration_cast(now - last_sample_start_).count(); scores_[sample_] = total_bytes_ / duration; @@ -391,10 +397,14 @@ void 
ParameterManager::CategoricalParameter::ResetState() { // BayesianParameter ParameterManager::BayesianParameter::BayesianParameter( std::vector variables, - std::vector test_points) : + std::vector test_points, + int max_samples, + double gaussian_process_noise) : TunableParameter(test_points[0]), variables_(variables), test_points_(test_points), + max_samples_(max_samples), + gaussian_process_noise_(gaussian_process_noise), iteration_(0) { ResetBayes(); Reinitialize(FilterTestPoint(0)); @@ -453,7 +463,7 @@ void ParameterManager::BayesianParameter::OnTune(double score, Eigen::VectorXd& } bool ParameterManager::BayesianParameter::IsDoneTuning() const { - return iteration_ > BAYES_OPT_MAX_SAMPLES; + return iteration_ > max_samples_; } void ParameterManager::BayesianParameter::ResetState() { @@ -474,7 +484,7 @@ void ParameterManager::BayesianParameter::ResetBayes() { } } - bayes_.reset(new BayesianOptimization(bounds, GAUSSIAN_PROCESS_NOISE)); + bayes_.reset(new BayesianOptimization(bounds, gaussian_process_noise_)); } Eigen::VectorXd ParameterManager::BayesianParameter::FilterTestPoint(int i) { diff --git a/horovod/common/parameter_manager.h b/horovod/common/parameter_manager.h index cc2bd73b8c..1bb75ecc2e 100644 --- a/horovod/common/parameter_manager.h +++ b/horovod/common/parameter_manager.h @@ -185,7 +185,8 @@ class ParameterManager { // A set of numerical parameters optimized jointly using Bayesian Optimization. class BayesianParameter : public TunableParameter { public: - BayesianParameter(std::vector variables, std::vector test_points); + BayesianParameter(std::vector variables, std::vector test_points, + int max_samples, double gaussian_process_noise); void SetValue(BayesianVariable variable, double value, bool fixed); double Value(BayesianVariable variable) const; @@ -201,6 +202,9 @@ class ParameterManager { std::vector variables_; std::vector test_points_; + int max_samples_; + double gaussian_process_noise_; + uint32_t iteration_; struct EnumClassHash { @@ -215,6 +219,9 @@ class ParameterManager { std::unordered_map index_; }; + int warmups_; + int steps_per_sample_; + CategoricalParameter hierarchical_allreduce_; CategoricalParameter hierarchical_allgather_; CategoricalParameter cache_enabled_; @@ -236,7 +243,6 @@ class ParameterManager { int32_t root_rank_; std::ofstream file_; bool writing_; - }; } // namespace common diff --git a/horovod/common/utils/env_parser.cc b/horovod/common/utils/env_parser.cc index ef4919c5cb..66a3d03465 100644 --- a/horovod/common/utils/env_parser.cc +++ b/horovod/common/utils/env_parser.cc @@ -154,5 +154,10 @@ int GetIntEnvOrDefault(const char* env_variable, int default_value) { return env_value != nullptr ? std::strtol(env_value, nullptr, 10) : default_value; } +double GetDoubleEnvOrDefault(const char* env_variable, double default_value) { + auto env_value = std::getenv(env_variable); + return env_value != nullptr ? 
std::strtod(env_value, nullptr) : default_value; +} + } // namespace common } diff --git a/horovod/common/utils/env_parser.h b/horovod/common/utils/env_parser.h index 10f363f538..fd5447cc37 100644 --- a/horovod/common/utils/env_parser.h +++ b/horovod/common/utils/env_parser.h @@ -41,6 +41,8 @@ void SetIntFromEnv(const char* env, int& val); int GetIntEnvOrDefault(const char* env_variable, int default_value); +double GetDoubleEnvOrDefault(const char* env_variable, double default_value); + } // namespace common } // namespace horovod diff --git a/horovod/run/common/util/config_parser.py b/horovod/run/common/util/config_parser.py new file mode 100644 index 0000000000..65c38c0cbf --- /dev/null +++ b/horovod/run/common/util/config_parser.py @@ -0,0 +1,160 @@ +# Parameter knobs +HOROVOD_FUSION_THRESHOLD = 'HOROVOD_FUSION_THRESHOLD' +HOROVOD_CYCLE_TIME = 'HOROVOD_CYCLE_TIME' +HOROVOD_CACHE_CAPACITY = 'HOROVOD_CACHE_CAPACITY' +HOROVOD_HIERARCHICAL_ALLREDUCE = 'HOROVOD_HIERARCHICAL_ALLREDUCE' +HOROVOD_HIERARCHICAL_ALLGATHER = 'HOROVOD_HIERARCHICAL_ALLGATHER' + +# Autotune knobs +HOROVOD_AUTOTUNE = 'HOROVOD_AUTOTUNE' +HOROVOD_AUTOTUNE_LOG = 'HOROVOD_AUTOTUNE_LOG' +HOROVOD_AUTOTUNE_WARMUP_SAMPLES = 'HOROVOD_AUTOTUNE_WARMUP_SAMPLES' +HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE = 'HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE' +HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES = 'HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES' +HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE = 'HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE' + +# Timeline knobs +HOROVOD_TIMELINE = 'HOROVOD_TIMELINE' +HOROVOD_TIMELINE_MARK_CYCLES = 'HOROVOD_TIMELINE_MARK_CYCLES' + +# Stall check knobs +HOROVOD_STALL_CHECK_DISABLE = 'HOROVOD_STALL_CHECK_DISABLE' +HOROVOD_STALL_CHECK_TIME_SECONDS = 'HOROVOD_STALL_CHECK_TIME_SECONDS' +HOROVOD_STALL_SHUTDOWN_TIME_SECONDS = 'HOROVOD_STALL_SHUTDOWN_TIME_SECONDS' + +# Library options knobs +HOROVOD_MPI_THREADS_DISABLE = 'HOROVOD_MPI_THREADS_DISABLE' +HOROVOD_NUM_NCCL_STREAMS = 'HOROVOD_NUM_NCCL_STREAMS' +HOROVOD_MLSL_BGT_AFFINITY = 'HOROVOD_MLSL_BGT_AFFINITY' + + +def _set_arg_from_config(args, arg_base_name, override_args, config, arg_prefix=''): + arg_name = arg_prefix + arg_base_name + if arg_name in override_args: + return + + value = config.get(arg_base_name) + if value is not None: + setattr(args, arg_name, value) + + +def set_args_from_config(args, config, override_args): + # Controller + controller = config.get('controller') + if controller and not args.use_gloo and not args.use_mpi: + if controller.lower() == 'gloo': + args.use_gloo = True + elif controller.lower() == 'mpi': + args.use_mpi = True + else: + raise ValueError('No such controller supported: {}'.format(controller)) + + # Params + params = config.get('params') + if params: + _set_arg_from_config(args, 'fusion_threshold_mb', override_args, params) + _set_arg_from_config(args, 'cycle_time_ms', override_args, params) + _set_arg_from_config(args, 'cache_capacity', override_args, params) + _set_arg_from_config(args, 'hierarchical_allreduce', override_args, params) + _set_arg_from_config(args, 'hierarchical_allgather', override_args, params) + + # Autotune + autotune = config.get('autotune') + if autotune: + args.autotune = autotune.get('enabled', False) if 'autotune' not in override_args else args.autotune + _set_arg_from_config(args, 'log_file', override_args, autotune, arg_prefix='autotune_') + _set_arg_from_config(args, 'warmup_samples', override_args, autotune, arg_prefix='autotune_') + _set_arg_from_config(args, 'steps_per_sample', override_args, autotune, arg_prefix='autotune_') 
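The precedence rule implemented by ``_set_arg_from_config`` above is easiest to see in isolation: a value read from the YAML config is applied only when the matching argument was not given explicitly on the command line (its dest is absent from ``override_args``). A minimal standalone sketch of that behaviour, not part of this patch, using a local copy of the helper:

.. code-block:: python

    from argparse import Namespace

    def set_arg_from_config(args, arg_base_name, override_args, config, arg_prefix=''):
        # Mirror of the helper above: skip anything the user set explicitly.
        arg_name = arg_prefix + arg_base_name
        if arg_name in override_args:
            return
        value = config.get(arg_base_name)
        if value is not None:
            setattr(args, arg_name, value)

    args = Namespace(fusion_threshold_mb=64, cycle_time_ms=20.0)
    override_args = {'cycle_time_ms'}                 # user passed --cycle-time-ms
    params = {'fusion_threshold_mb': 32, 'cycle_time_ms': 10}

    set_arg_from_config(args, 'fusion_threshold_mb', override_args, params)
    set_arg_from_config(args, 'cycle_time_ms', override_args, params)

    assert args.fusion_threshold_mb == 32             # filled from the config file
    assert args.cycle_time_ms == 20.0                 # explicit flag is preserved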
+ _set_arg_from_config(args, 'bayes_opt_max_samples', override_args, autotune, arg_prefix='autotune_') + _set_arg_from_config(args, 'gaussian_process_noise', override_args, autotune, arg_prefix='autotune_') + + # Timeline + timeline = config.get('timeline') + if timeline: + _set_arg_from_config(args, 'filename', override_args, timeline, arg_prefix='timeline_') + _set_arg_from_config(args, 'mark_cycles', override_args, timeline, arg_prefix='timeline_') + + # Stall Check + stall_check = config.get('stall_check') + if stall_check: + args.no_stall_check = not stall_check.get('enabled', True) \ + if 'no_stall_check' not in override_args else args.no_stall_check + _set_arg_from_config(args, 'warning_time_seconds', override_args, stall_check, arg_prefix='stall_check_') + _set_arg_from_config(args, 'shutdown_time_seconds', override_args, stall_check, arg_prefix='stall_check_') + + # Library Options + library_options = config.get('library_options') + if library_options: + _set_arg_from_config(args, 'mpi_threads_disable', override_args, library_options) + _set_arg_from_config(args, 'num_nccl_streams', override_args, library_options) + _set_arg_from_config(args, 'mlsl_bgt_affinity', override_args, library_options) + + +def _validate_arg_nonnegative(args, arg_name): + value = getattr(args, arg_name) + if value < 0: + raise ValueError('{}={} must be >= 0'.format(arg_name, value)) + + +def validate_config_args(args): + _validate_arg_nonnegative(args, 'fusion_threshold_mb') + _validate_arg_nonnegative(args, 'cycle_time_ms') + _validate_arg_nonnegative(args, 'cache_capacity') + _validate_arg_nonnegative(args, 'autotune_warmup_samples') + _validate_arg_nonnegative(args, 'autotune_steps_per_sample') + _validate_arg_nonnegative(args, 'autotune_bayes_opt_max_samples') + + if args.autotune_gaussian_process_noise < 0 or args.autotune_gaussian_process_noise > 1: + raise ValueError('{}={} must be in [0, 1]'.format('autotune_gaussian_process_noise', + args.autotune_gaussian_process_noise)) + + _validate_arg_nonnegative(args, 'stall_check_warning_time_seconds') + _validate_arg_nonnegative(args, 'stall_check_shutdown_time_seconds') + _validate_arg_nonnegative(args, 'num_nccl_streams') + _validate_arg_nonnegative(args, 'mlsl_bgt_affinity') + + +def _add_arg_to_env(env, env_key, arg_value, transform_fn=None): + if arg_value is not None: + value = arg_value + if transform_fn: + value = transform_fn(value) + env[env_key] = str(value) + + +def set_env_from_args(env, args): + def identity(value): + return 1 if value else 0 + + # Params + _add_arg_to_env(env, HOROVOD_FUSION_THRESHOLD, args.fusion_threshold_mb, lambda v: v * 1024 * 1024) + _add_arg_to_env(env, HOROVOD_CYCLE_TIME, args.cycle_time_ms) + _add_arg_to_env(env, HOROVOD_CACHE_CAPACITY, args.cache_capacity) + _add_arg_to_env(env, HOROVOD_HIERARCHICAL_ALLREDUCE, args.hierarchical_allreduce, identity) + _add_arg_to_env(env, HOROVOD_HIERARCHICAL_ALLGATHER, args.hierarchical_allgather, identity) + + # Autotune + if args.autotune: + _add_arg_to_env(env, HOROVOD_AUTOTUNE, args.autotune, identity) + _add_arg_to_env(env, HOROVOD_AUTOTUNE_LOG, args.autotune_log_file) + _add_arg_to_env(env, HOROVOD_AUTOTUNE_WARMUP_SAMPLES, args.autotune_warmup_samples) + _add_arg_to_env(env, HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE, args.autotune_steps_per_sample) + _add_arg_to_env(env, HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES, args.autotune_bayes_opt_max_samples) + _add_arg_to_env(env, HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE, args.autotune_gaussian_process_noise) + + # Timeline + if 
args.timeline_filename: + _add_arg_to_env(env, HOROVOD_TIMELINE, args.timeline_filename) + _add_arg_to_env(env, HOROVOD_TIMELINE_MARK_CYCLES, args.timeline_mark_cycles, identity) + + # Stall Check + _add_arg_to_env(env, HOROVOD_STALL_CHECK_DISABLE, args.no_stall_check, identity) + _add_arg_to_env(env, HOROVOD_STALL_CHECK_TIME_SECONDS, args.stall_check_warning_time_seconds) + _add_arg_to_env(env, HOROVOD_STALL_SHUTDOWN_TIME_SECONDS, args.stall_check_shutdown_time_seconds) + + # Library Options + _add_arg_to_env(env, HOROVOD_MPI_THREADS_DISABLE, args.mpi_threads_disable, identity) + _add_arg_to_env(env, HOROVOD_NUM_NCCL_STREAMS, args.num_nccl_streams) + _add_arg_to_env(env, HOROVOD_MLSL_BGT_AFFINITY, args.mlsl_bgt_affinity) + + return env diff --git a/horovod/run/gloo_run.py b/horovod/run/gloo_run.py index 28f9a56e55..4ba119f362 100644 --- a/horovod/run/gloo_run.py +++ b/horovod/run/gloo_run.py @@ -109,7 +109,7 @@ def _allocate(hosts, np): return alloc_list -def _launch_jobs(settings, host_alloc_plan, remote_host_names, _run_command): +def _launch_jobs(settings, env, host_alloc_plan, remote_host_names, _run_command): """ executes the jobs defined by run command on hosts. :param hosts_alloc: list of dict indicating the allocating info. @@ -164,7 +164,6 @@ def set_event_on_sigterm(signum, frame): host_name = alloc_info.hostname - env = os.environ.copy() # TODO: Workaround for over-buffered outputs. Investigate how mpirun avoids this problem. env['PYTHONUNBUFFERED'] = '1' local_command = '{horovod_env} {env} {run_command}' .format( @@ -196,7 +195,7 @@ def set_event_on_sigterm(signum, frame): block_until_all_done=True) -def gloo_run(settings, remote_host_names, common_intfs): +def gloo_run(settings, remote_host_names, common_intfs, env): # allocate processes into slots host_alloc_plan = _allocate(settings.hosts, settings.num_proc) @@ -230,5 +229,5 @@ def gloo_run(settings, remote_host_names, common_intfs): common_intfs=','.join(common_intfs), command=' '.join(quote(par) for par in settings.command))) - _launch_jobs(settings, host_alloc_plan, remote_host_names, run_command) + _launch_jobs(settings, env, host_alloc_plan, remote_host_names, run_command) return diff --git a/horovod/run/mpi_run.py b/horovod/run/mpi_run.py index 846633d46c..943b26676b 100644 --- a/horovod/run/mpi_run.py +++ b/horovod/run/mpi_run.py @@ -52,7 +52,7 @@ def _is_open_mpi_installed(): return False -def mpi_run(settings, common_intfs): +def mpi_run(settings, common_intfs, env): if not _is_open_mpi_installed(): raise Exception( 'horovodrun convenience script does not find an installed OpenMPI.\n\n' @@ -64,9 +64,6 @@ def mpi_run(settings, common_intfs): ' MPI distribution (usually mpirun, srun, or jsrun).\n' '3. Use built-in gloo option (horovodrun --gloo ...).') - # Pass all the env variables to the mpirun command. - env = os.environ.copy() - ssh_port_arg = '-mca plm_rsh_args \"-p {ssh_port}\"'.format( ssh_port=settings.ssh_port) if settings.ssh_port else '' @@ -79,6 +76,7 @@ def mpi_run(settings, common_intfs): nccl_socket_intf_arg = '-x NCCL_SOCKET_IFNAME={common_intfs}'.format( common_intfs=','.join(common_intfs)) if common_intfs else '' + # Pass all the env variables to the mpirun command. 
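Before the launcher command is assembled, ``set_env_from_args`` (defined earlier in this patch) converts the parsed arguments into the environment that ``gloo_run`` and ``mpi_run`` now receive: boolean knobs are serialized as ``1``/``0`` and ``--fusion-threshold-mb`` is converted from megabytes to bytes. A minimal usage sketch, assuming the ``horovod.run.common.util.config_parser`` module path added by this patch:

.. code-block:: python

    import os
    from argparse import Namespace

    from horovod.run.common.util import config_parser

    # Only the attributes read by set_env_from_args are filled in here.
    args = Namespace(
        fusion_threshold_mb=32, cycle_time_ms=3.5, cache_capacity=1024,
        hierarchical_allreduce=True, hierarchical_allgather=False,
        autotune=False, timeline_filename=None, timeline_mark_cycles=False,
        no_stall_check=False, stall_check_warning_time_seconds=60,
        stall_check_shutdown_time_seconds=0, mpi_threads_disable=False,
        num_nccl_streams=1, mlsl_bgt_affinity=0)

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)

    assert env[config_parser.HOROVOD_FUSION_THRESHOLD] == str(32 * 1024 * 1024)
    assert env[config_parser.HOROVOD_CYCLE_TIME] == '3.5'
    assert env[config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE] == '1'
    assert env[config_parser.HOROVOD_HIERARCHICAL_ALLGATHER] == '0'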
mpirun_command = ( 'mpirun --allow-run-as-root --tag-output ' '-np {num_proc} {hosts_arg} ' diff --git a/horovod/run/run.py b/horovod/run/run.py index c2c1c04ad3..9bfda1b970 100644 --- a/horovod/run/run.py +++ b/horovod/run/run.py @@ -18,7 +18,6 @@ import hashlib import os import sys -import six import re import textwrap try: @@ -27,10 +26,13 @@ from pipes import quote import horovod +import six +import yaml + from horovod.common.util import (extension_available, gloo_built, mpi_built, nccl_built, ddl_built, mlsl_built) -from horovod.run.common.util import codec, safe_shell_exec, timeout, secret +from horovod.run.common.util import codec, config_parser, safe_shell_exec, timeout, secret from horovod.run.common.util import settings as hvd_settings from horovod.run.driver import driver_service from horovod.run.task import task_service @@ -258,6 +260,7 @@ def _driver_fn(all_host_names, local_host_names, settings): finally: driver.shutdown() + def check_build(verbose): def get_check(value): return 'X' if value else ' ' @@ -306,7 +309,58 @@ def __call__(self, parser, args, values, option_string=None): return CheckBuildAction +def make_override_action(override_args): + class StoreOverrideAction(argparse.Action): + def __init__(self, + option_strings, + dest, + default=False, + type=None, + required=False, + help=None): + super(StoreOverrideAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=1, + default=default, + type=type, + required=required, + help=help) + + def __call__(self, parser, args, values, option_string=None): + override_args.add(self.dest) + setattr(args, self.dest, values[0]) + + return StoreOverrideAction + + +def make_override_true_action(override_args): + class StoreOverrideTrueAction(argparse.Action): + def __init__(self, + option_strings, + dest, + default=False, + required=False, + help=None): + super(StoreOverrideTrueAction, self).__init__( + option_strings=option_strings, + dest=dest, + const=True, + nargs=0, + default=default, + required=required, + help=help) + + def __call__(self, parser, args, values, option_string=None): + override_args.add(self.dest) + setattr(args, self.dest, self.const) + + return StoreOverrideTrueAction + + def parse_args(): + override_args = set() + parser = argparse.ArgumentParser(description='Horovod Runner') parser.add_argument('-v', '--version', action='version', version=horovod.__version__, @@ -347,6 +401,99 @@ def parse_args(): parser.add_argument('command', nargs=argparse.REMAINDER, help='Command to be executed.') + parser.add_argument('--config-file', action='store', dest='config_file', + help='Path to YAML file containing runtime parameter configuration for Horovod. ' + 'Note that this will override any command line arguments provided before ' + 'this argument, and will be overridden by any arguments that come after it.') + + group_params = parser.add_argument_group('tuneable parameter arguments') + group_params.add_argument('--fusion-threshold-mb', action=make_override_action(override_args), type=int, default=64, + help='Fusion buffer threshold in MB. This is the maximum amount of ' + 'tensor data that can be fused together into a single batch ' + 'during allreduce / allgather. Setting 0 disables tensor fusion. ' + '(default: %(default)s)') + group_params.add_argument('--cycle-time-ms', action=make_override_action(override_args), type=float, default=5, + help='Cycle time in ms. This is the delay between each tensor fusion ' + 'cycle. 
The larger the cycle time, the more batching, but the ' + 'greater latency between each allreduce / allgather operations. ' + '(default: %(default)s)') + group_params.add_argument('--cache-capacity', action=make_override_action(override_args), type=int, default=1024, + help='Maximum number of tensor names that will be cached to reduce amount ' + 'of coordination required between workers before performing allreduce / ' + 'allgather. (default: %(default)s)') + group_params.add_argument('--hierarchical-allreduce', action=make_override_true_action(override_args), + help='Perform hierarchical allreduce between workers instead of ring allreduce. ' + 'Hierarchical allreduce performs a local allreduce / gather within a host, then ' + 'a parallel cross allreduce between equal local ranks across workers, and ' + 'finally a local gather.') + group_params.add_argument('--hierarchical-allgather', action=make_override_true_action(override_args), + help='Perform hierarchical allgather between workers instead of ring allgather. See ' + 'hierarchical allreduce for algorithm details.') + + group_autotune = parser.add_argument_group('autotune arguments') + group_autotune.add_argument('--autotune', action=make_override_true_action(override_args), + help='Perform autotuning to select parameter argument values that maximimize ' + 'throughput for allreduce / allgather. Any parameter explicitly set will ' + 'be held constant during tuning.') + group_autotune.add_argument('--autotune-log-file', action=make_override_action(override_args), + help='Comma-separated log of trials containing each hyperparameter and the ' + 'score of the trial. The last row will always contain the best value ' + 'found.') + group_autotune.add_argument('--autotune-warmup-samples', action=make_override_action(override_args), + type=int, default=3, + help='Number of samples to discard before beginning the optimization process ' + 'during autotuning. Performance during the first few batches can be ' + 'affected by initialization and cache warmups. (default: %(default)s)') + group_autotune.add_argument('--autotune-steps-per-sample', action=make_override_action(override_args), + type=int, default=10, + help='Number of steps (approximate) to record before observing a sample. The sample ' + 'score is defined to be the median score over all batches within the sample. The ' + 'more batches per sample, the less variance in sample scores, but the longer ' + 'autotuning will take. (default: %(default)s)') + group_autotune.add_argument('--autotune-bayes-opt-max-samples', action=make_override_action(override_args), + type=int, default=20, + help='Maximum number of samples to collect for each Bayesian optimization process. ' + '(default: %(default)s)') + group_autotune.add_argument('--autotune-gaussian-process-noise', action=make_override_action(override_args), + type=float, default=0.8, + help='Regularization value [0, 1] applied to account for noise in samples. ' + '(default: %(default)s)') + + group_timeline = parser.add_argument_group('timeline arguments') + group_timeline.add_argument('--timeline-filename', action=make_override_action(override_args), + help='JSON file containing timeline of Horovod events used for debugging ' + 'performance. If this is provided, timeline events will be recorded, ' + 'which can have a negative impact on training performance.') + group_timeline.add_argument('--timeline-mark-cycles', action=make_override_true_action(override_args), + help='Mark cycles on the timeline. 
Only enabled if the timeline filename ' + 'is provided.') + + group_stall_check = parser.add_argument_group('stall check arguments') + group_stall_check.add_argument('--no-stall-check', action=make_override_true_action(override_args), + help='Disable the stall check. The stall check will log a warning when workers ' + 'have stalled waiting for other ranks to submit tensors.') + group_stall_check.add_argument('--stall-check-warning-time-seconds', action=make_override_action(override_args), + type=int, default=60, + help='Seconds until the stall warning is logged to stderr. (default: %(default)s)') + group_stall_check.add_argument('--stall-check-shutdown-time-seconds', action=make_override_action(override_args), + type=int, default=0, + help='Seconds until Horovod is shutdown due to stall. Shutdown will only take ' + 'place if this value is greater than the warning time. (default: %(default)s)') + + group_library_options = parser.add_argument_group('library arguments') + group_library_options.add_argument('--mpi-threads-disable', action=make_override_true_action(override_args), + help='Disable MPI threading support. Only applies when running in MPI ' + 'mode. In some cases, multi-threaded MPI can slow down other components, ' + 'but is necessary if you wish to run mpi4py on top of Horovod.') + group_library_options.add_argument('--num-nccl-streams', action=make_override_action(override_args), + type=int, default=1, + help='Number of NCCL streams. Only applies when running with NCCL support. ' + '(default: %(default)s)') + group_library_options.add_argument('--mlsl-bgt-affinity', action=make_override_action(override_args), + type=int, default=0, + help='MLSL background thread affinity. Only applies when running with MLSL ' + 'support. (default: %(default)s)') + group_hosts_parent = parser.add_argument_group('host arguments') group_hosts = group_hosts_parent.add_mutually_exclusive_group() group_hosts.add_argument('-H', '--hosts', action='store', dest='hosts', @@ -369,7 +516,15 @@ def parse_args(): help='Run Horovod using the MPI controller. This will ' 'be the default if Horovod was built with MPI support.') - return parser.parse_args() + args = parser.parse_args() + + if args.config_file: + with open(args.config_file, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + config_parser.set_args_from_config(args, config, override_args) + config_parser.validate_config_args(args) + + return args def parse_host_files(filename): @@ -493,21 +648,24 @@ def run(): if settings.verbose >= 2: print('Local interface found ' + ' '.join(common_intfs)) + env = os.environ.copy() + config_parser.set_env_from_args(env, args) + if args.use_gloo: if not gloo_built(verbose=(settings.verbose >= 2)): raise ValueError('Gloo support has not been built. If this is not expected, ensure CMake is installed ' 'and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.') - gloo_run(settings, remote_host_names, common_intfs) + gloo_run(settings, remote_host_names, common_intfs, env) elif args.use_mpi: if not mpi_built(verbose=(settings.verbose >= 2)): raise ValueError('MPI support has not been built. 
If this is not expected, ensure MPI is installed ' 'and reinstall Horovod with HOROVOD_WITH_MPI=1 to debug the build error.') - mpi_run(settings, common_intfs) + mpi_run(settings, common_intfs, env) else: if mpi_built(verbose=(settings.verbose >= 2)): - mpi_run(settings, common_intfs) + mpi_run(settings, common_intfs, env) elif gloo_built(verbose=(settings.verbose >= 2)): - gloo_run(settings, remote_host_names, common_intfs) + gloo_run(settings, remote_host_names, common_intfs, env) else: raise ValueError('Neither MPI nor Gloo support has been built. Try reinstalling Horovod ensuring that ' 'either MPI is installed (MPI) or CMake is installed (Gloo).') diff --git a/setup.py b/setup.py index 28580d830c..edc5053ece 100644 --- a/setup.py +++ b/setup.py @@ -1418,7 +1418,7 @@ def build_extensions(self): 'None of TensorFlow, PyTorch, or MXNet plugins were built. See errors above.') -require_list = ['cloudpickle', 'psutil', 'six'] +require_list = ['cloudpickle', 'psutil', 'pyyaml', 'six'] # Skip cffi if pytorch extension explicitly disabled if not os.environ.get('HOROVOD_WITHOUT_PYTORCH'): diff --git a/test/data/config.test.yaml b/test/data/config.test.yaml new file mode 100644 index 0000000000..afe2277ab2 --- /dev/null +++ b/test/data/config.test.yaml @@ -0,0 +1,30 @@ +controller: gloo + +params: + fusion_threshold_mb: 32 + cycle_time_ms: 10 + cache_capacity: 2048 + hierarchical_allreduce: true + hierarchical_allgather: true + +autotune: + enabled: true + log_file: 'horovod_autotune_log.txt' + warmup_samples: 5 + steps_per_sample: 20 + bayes_opt_max_samples: 50 + gaussian_process_noise: 0.9 + +timeline: + filename: 'horovod_timeline.json' + mark_cycles: true + +stall_check: + enabled: true + warning_time_seconds: 120 + shutdown_time_seconds: 240 + +library_options: + mpi_threads_disable: true + num_nccl_streams: 2 + mlsl_bgt_affinity: 1 diff --git a/test/test_run.py b/test/test_run.py new file mode 100644 index 0000000000..28c376b984 --- /dev/null +++ b/test/test_run.py @@ -0,0 +1,184 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import os +import sys +import unittest +import warnings + +import pytest + +from horovod.run import run +from horovod.run.common.util import config_parser + + +@contextlib.contextmanager +def override_args(tool=None, *args): + old = sys.argv[:] + try: + if tool: + sys.argv[0] = tool + sys.argv[1:] = args + yield + finally: + sys.argv = old + + +class RunTests(unittest.TestCase): + """ + Tests for horovod.run. 
+ """ + + def __init__(self, *args, **kwargs): + super(RunTests, self).__init__(*args, **kwargs) + warnings.simplefilter('module') + + def test_params_args(self): + with override_args('horovodrun', '-np', '2', + '--fusion-threshold-mb', '10', + '--cycle-time-ms', '20', + '--cache-capacity', '512', + '--hierarchical-allreduce', + '--hierarchical-allgather'): + args = run.parse_args() + env = {} + config_parser.set_env_from_args(env, args) + + self.assertEqual(env[config_parser.HOROVOD_FUSION_THRESHOLD], str(10 * 1024 * 1024)) + self.assertEqual(env[config_parser.HOROVOD_CYCLE_TIME], '20.0') + self.assertEqual(env[config_parser.HOROVOD_CACHE_CAPACITY], '512') + self.assertEqual(env[config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE], '1') + self.assertEqual(env[config_parser.HOROVOD_HIERARCHICAL_ALLGATHER], '1') + + def test_autotune_args(self): + with override_args('horovodrun', '-np', '2', + '--autotune', + '--autotune-log-file', '/tmp/autotune.txt', + '--autotune-warmup-samples', '1', + '--autotune-steps-per-sample', '5', + '--autotune-bayes-opt-max-samples', '10', + '--autotune-gaussian-process-noise', '0.2'): + args = run.parse_args() + env = {} + config_parser.set_env_from_args(env, args) + + self.assertEqual(env[config_parser.HOROVOD_AUTOTUNE], '1') + self.assertEqual(env[config_parser.HOROVOD_AUTOTUNE_LOG], '/tmp/autotune.txt') + self.assertEqual(env[config_parser.HOROVOD_AUTOTUNE_WARMUP_SAMPLES], '1') + self.assertEqual(env[config_parser.HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE], '5') + self.assertEqual(env[config_parser.HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES], '10') + self.assertEqual(env[config_parser.HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE], '0.2') + + def test_timeline_args(self): + with override_args('horovodrun', '-np', '2', + '--timeline-filename', '/tmp/timeline.json', + '--timeline-mark-cycles'): + args = run.parse_args() + env = {} + config_parser.set_env_from_args(env, args) + + self.assertEqual(env[config_parser.HOROVOD_TIMELINE], '/tmp/timeline.json') + self.assertEqual(env[config_parser.HOROVOD_TIMELINE_MARK_CYCLES], '1') + + def test_stall_check_args(self): + with override_args('horovodrun', '-np', '2', + '--no-stall-check'): + args = run.parse_args() + env = {} + config_parser.set_env_from_args(env, args) + + self.assertEqual(env[config_parser.HOROVOD_STALL_CHECK_DISABLE], '1') + + with override_args('horovodrun', '-np', '2', + '--stall-check-warning-time-seconds', '10', + '--stall-check-shutdown-time-seconds', '20'): + args = run.parse_args() + env = {} + config_parser.set_env_from_args(env, args) + + self.assertEqual(env[config_parser.HOROVOD_STALL_CHECK_DISABLE], '0') + self.assertEqual(env[config_parser.HOROVOD_STALL_CHECK_TIME_SECONDS], '10') + self.assertEqual(env[config_parser.HOROVOD_STALL_SHUTDOWN_TIME_SECONDS], '20') + + def test_library_args(self): + with override_args('horovodrun', '-np', '2', + '--mpi-threads-disable', + '--num-nccl-streams', '2', + '--mlsl-bgt-affinity', '1'): + args = run.parse_args() + env = {} + config_parser.set_env_from_args(env, args) + + self.assertEqual(env[config_parser.HOROVOD_MPI_THREADS_DISABLE], '1') + self.assertEqual(env[config_parser.HOROVOD_NUM_NCCL_STREAMS], '2') + self.assertEqual(env[config_parser.HOROVOD_MLSL_BGT_AFFINITY], '1') + + def test_config_file(self): + config_filename = os.path.join(os.path.dirname(__file__), 'data/config.test.yaml') + with override_args('horovodrun', '-np', '2', + '--config-file', config_filename): + args = run.parse_args() + + self.assertTrue(args.use_gloo) + + # Params + 
self.assertEqual(args.fusion_threshold_mb, 32) + self.assertEqual(args.cycle_time_ms, 10) + self.assertEqual(args.cache_capacity, 2048) + self.assertTrue(args.hierarchical_allreduce) + self.assertTrue(args.hierarchical_allgather) + + # Autotune + self.assertTrue(args.autotune) + self.assertEqual(args.autotune_log_file, 'horovod_autotune_log.txt') + self.assertEqual(args.autotune_warmup_samples, 5) + self.assertEqual(args.autotune_steps_per_sample, 20) + self.assertEqual(args.autotune_bayes_opt_max_samples, 50) + self.assertEqual(args.autotune_gaussian_process_noise, 0.9) + + # Timeline + self.assertEqual(args.timeline_filename, 'horovod_timeline.json') + self.assertTrue(args.timeline_mark_cycles) + + # Stall Check + self.assertFalse(args.no_stall_check) + self.assertEqual(args.stall_check_warning_time_seconds, 120) + self.assertEqual(args.stall_check_shutdown_time_seconds, 240) + + # Library Options + self.assertTrue(args.mpi_threads_disable) + self.assertEqual(args.num_nccl_streams, 2) + self.assertEqual(args.mlsl_bgt_affinity, 1) + + def test_config_file_override_args(self): + config_filename = os.path.join(os.path.dirname(__file__), 'data/config.test.yaml') + with override_args('horovodrun', '-np', '2', + '--fusion-threshold-mb', '128', + '--config-file', config_filename, + '--cycle-time-ms', '20',): + args = run.parse_args() + self.assertEqual(args.fusion_threshold_mb, 128) + self.assertEqual(args.cycle_time_ms, 20) + + def test_validate_config_args(self): + with override_args('horovodrun', '-np', '2', + '--fusion-threshold-mb', '-1'): + with pytest.raises(ValueError): + run.parse_args()
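The final test above exercises ``validate_config_args``, which turns an out-of-range knob into an immediate launch-time error rather than a failure inside the training job. A minimal sketch of calling it directly, not part of the patch, assuming the module path it adds:

.. code-block:: python

    from argparse import Namespace

    from horovod.run.common.util import config_parser

    # Only the attributes that validate_config_args inspects are filled in.
    args = Namespace(
        fusion_threshold_mb=64, cycle_time_ms=5, cache_capacity=1024,
        autotune_warmup_samples=3, autotune_steps_per_sample=10,
        autotune_bayes_opt_max_samples=20,
        autotune_gaussian_process_noise=1.5,   # outside the allowed [0, 1] range
        stall_check_warning_time_seconds=60, stall_check_shutdown_time_seconds=0,
        num_nccl_streams=1, mlsl_bgt_affinity=0)

    try:
        config_parser.validate_config_args(args)
    except ValueError as err:
        print(err)   # autotune_gaussian_process_noise=1.5 must be in [0, 1]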