diff --git a/doc/source/optimization.rst b/doc/source/optimization.rst
index 59219ad51..2b8dd8987 100644
--- a/doc/source/optimization.rst
+++ b/doc/source/optimization.rst
@@ -25,6 +25,7 @@ the ``strategy=`` optional argument of ``tune_kernel()``. Kernel Tuner currently
  * "pso" particle swarm optimization
  * "random_sample" takes a random sample of the search space
  * "simulated_annealing" simulated annealing strategy
+ * "ensemble" ensemble strategy
 
 Most strategies have some mechanism built in to detect when to stop tuning, which may be controlled through specific parameters that can be passed to the strategies using the ``strategy_options=`` optional argument of ``tune_kernel()``. You
diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py
index 97ae22848..0be907737 100644
--- a/kernel_tuner/interface.py
+++ b/kernel_tuner/interface.py
@@ -34,6 +34,7 @@
 from kernel_tuner.integration import get_objective_defaults
 from kernel_tuner.runners.sequential import SequentialRunner
 from kernel_tuner.runners.simulation import SimulationRunner
+from kernel_tuner.runners.parallel import ParallelRunner
 from kernel_tuner.searchspace import Searchspace
 
 try:
@@ -57,6 +58,7 @@
     pso,
     random_sample,
     simulated_annealing,
+    ensemble,
 )
 
 strategy_map = {
@@ -75,6 +77,7 @@
     "simulated_annealing": simulated_annealing,
     "firefly_algorithm": firefly_algorithm,
     "bayes_opt": bayes_opt,
+    "ensemble": ensemble,
 }
@@ -384,6 +387,7 @@ def __deepcopy__(self, _):
      * "pso" particle swarm optimization
      * "random_sample" takes a random sample of the search space
      * "simulated_annealing" simulated annealing strategy
+     * "ensemble" ensemble strategy
 
     Strategy-specific parameters and options are explained under strategy_options.
@@ -463,6 +467,7 @@ def __deepcopy__(self, _):
         ),
         ("metrics", ("specifies user-defined metrics, please see :ref:`metrics`.", "dict")),
         ("simulation_mode", ("Simulate an auto-tuning search from an existing cachefile", "bool")),
+        ("parallel_mode", ("Run the auto-tuning on multiple devices (brute-force execution)", "bool")),
         ("observers", ("""A list of Observers to use during tuning, please see :ref:`observers`.""", "list")),
     ]
 )
@@ -574,6 +579,7 @@ def tune_kernel(
     cache=None,
     metrics=None,
     simulation_mode=False,
+    parallel_mode=False,
     observers=None,
     objective=None,
     objective_higher_is_better=None,
@@ -611,6 +617,8 @@
         tuning_options["max_fevals"] = strategy_options["max_fevals"]
     if strategy_options and "time_limit" in strategy_options:
         tuning_options["time_limit"] = strategy_options["time_limit"]
+    if strategy_options and "num_gpus" in strategy_options:
+        tuning_options["num_gpus"] = strategy_options["num_gpus"]
 
     logging.debug("tune_kernel called")
     logging.debug("kernel_options: %s", util.get_config_string(kernel_options))
@@ -650,9 +658,13 @@ def tune_kernel(
         strategy = brute_force
 
     # select the runner for this job based on input
-    selected_runner = SimulationRunner if simulation_mode else SequentialRunner
+    selected_runner = SimulationRunner if simulation_mode else (ParallelRunner if parallel_mode else SequentialRunner)
     tuning_options.simulated_time = 0
-    runner = selected_runner(kernelsource, kernel_options, device_options, iterations, observers)
+    if parallel_mode:
+        num_gpus = tuning_options.get("num_gpus", None)
+        runner = selected_runner(kernelsource, kernel_options, device_options, iterations, observers, num_gpus=num_gpus)
+    else:
+        runner = selected_runner(kernelsource, kernel_options, device_options, iterations, observers)
 
     # the user-specified function may or may not have an optional atol argument;
    # we normalize it so that it always accepts atol
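For context, this is roughly how the new flag is driven from user code. The kernel, problem size, and option values below are made-up examples rather than part of the patch:

    import numpy as np
    from kernel_tuner import tune_kernel

    kernel_string = """
    __global__ void vector_add(float *c, float *a, float *b, int n) {
        int i = blockIdx.x * block_size_x + threadIdx.x;
        if (i < n) c[i] = a[i] + b[i];
    }
    """

    size = 1_000_000
    a = np.random.randn(size).astype(np.float32)
    b = np.random.randn(size).astype(np.float32)
    c = np.zeros_like(a)
    args = [c, a, b, np.int32(size)]
    tune_params = {"block_size_x": [64, 128, 256, 512, 1024]}

    # parallel_mode=True selects the ParallelRunner; num_gpus is optional and
    # defaults to the number of GPUs Ray can see in the cluster
    results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params,
                               lang="CUDA", strategy="brute_force", parallel_mode=True,
                               strategy_options={"num_gpus": 2})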
diff --git a/kernel_tuner/observers/nvml.py b/kernel_tuner/observers/nvml.py
index 0fd812a34..bc93a275b 100644
--- a/kernel_tuner/observers/nvml.py
+++ b/kernel_tuner/observers/nvml.py
@@ -315,6 +315,15 @@ def __init__(
         continous_duration=1,
     ):
         """Create an NVMLObserver."""
+        # needed for re-initializing the observer on a Ray actor
+        self.init_arguments = {
+            "observables": observables,
+            "device": device,
+            "save_all": save_all,
+            "nvidia_smi_fallback": nvidia_smi_fallback,
+            "use_locked_clocks": use_locked_clocks,
+            "continous_duration": continous_duration,
+        }
         if nvidia_smi_fallback:
             self.nvml = nvml(
                 device,
@@ -424,6 +433,14 @@ def __init__(self, observables, parent, nvml_instance, continous_duration=1):
         self.parent = parent
         self.nvml = nvml_instance
 
+        # needed for re-initializing the observer on a Ray actor
+        self.init_arguments = {
+            "observables": observables,
+            "parent": parent,
+            "nvml_instance": nvml_instance,
+            "continous_duration": continous_duration,
+        }
+
         supported = ["power_readings", "nvml_power", "nvml_energy"]
         for obs in observables:
             if obs not in supported:
diff --git a/kernel_tuner/observers/pmt.py b/kernel_tuner/observers/pmt.py
index bb1d76bd1..75c924b30 100644
--- a/kernel_tuner/observers/pmt.py
+++ b/kernel_tuner/observers/pmt.py
@@ -33,6 +33,11 @@ class PMTObserver(BenchmarkObserver):
     def __init__(self, observable=None):
         if not pmt:
             raise ImportError("could not import pmt")
+
+        # needed for re-initializing the observer on a Ray actor
+        self.init_arguments = {"observable": observable}
 
         # User specifices a dictonary of platforms and corresponding device
         if type(observable) is dict:
diff --git a/kernel_tuner/observers/powersensor.py b/kernel_tuner/observers/powersensor.py
index 6d07e8977..c946f9d44 100644
--- a/kernel_tuner/observers/powersensor.py
+++ b/kernel_tuner/observers/powersensor.py
@@ -27,6 +27,12 @@ class PowerSensorObserver(BenchmarkObserver):
     def __init__(self, observables=None, device=None):
         if not powersensor:
             raise ImportError("could not import powersensor")
+
+        # needed for re-initializing the observer on a Ray actor
+        self.init_arguments = {"observables": observables, "device": device}
 
         supported = ["ps_energy", "ps_power"]
         for obs in observables:
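The init_arguments duplication above exists because live observer objects hold device handles and cannot be pickled into a Ray actor. What gets shipped instead is a picklable (class, kwargs) pair per observer, from which the actor rebuilds a fresh instance. A minimal sketch of the round-trip, with observers standing in for any list of the observers patched above:

    # on the driver: capture picklable (class, kwargs) pairs instead of live objects
    observer_specs = [(type(obs), obs.init_arguments) for obs in observers]

    # inside the actor: rebuild fresh observers from the captured specs
    rebuilt = [cls(**kwargs) for cls, kwargs in observer_specs]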
diff --git a/kernel_tuner/runners/parallel.py b/kernel_tuner/runners/parallel.py
new file mode 100644
index 000000000..a7f2d95fc
--- /dev/null
+++ b/kernel_tuner/runners/parallel.py
@@ -0,0 +1,306 @@
+import ray
+import sys
+from ray.util.actor_pool import ActorPool
+from time import perf_counter
+from collections import deque
+import copy
+
+from kernel_tuner.core import DeviceInterface
+from kernel_tuner.runners.runner import Runner
+from kernel_tuner.util import get_num_devices, GPUTypeMismatchError
+from kernel_tuner.runners.ray.cache_manager import CacheManager
+from kernel_tuner.strategies.common import create_actor_on_device, initialize_ray
+
+class ParallelRunner(Runner):
+    """ParallelRunner is used for tuning with multiple processes/threads, using Ray for distributed computing."""
+
+    def __init__(self, kernel_source, kernel_options, device_options, iterations, observers,
+                 num_gpus=None, cache_manager=None, actors=None, simulation_mode=False):
+        """Instantiate the ParallelRunner.
+
+        :param kernel_source: The kernel source
+        :type kernel_source: kernel_tuner.core.KernelSource
+
+        :param kernel_options: A dictionary with all options for the kernel.
+        :type kernel_options: kernel_tuner.interface.Options
+
+        :param device_options: A dictionary with all options for the device
+            on which the kernel should be tuned.
+        :type device_options: kernel_tuner.interface.Options
+
+        :param iterations: The number of iterations used for benchmarking
+            each kernel instance.
+        :type iterations: int
+
+        :param observers: List of observers.
+        :type observers: list
+
+        :param num_gpus: Number of GPUs to use. Defaults to None.
+        :type num_gpus: int, optional
+
+        :param cache_manager: Cache manager instance. Defaults to None.
+        :type cache_manager: kernel_tuner.runners.ray.cache_manager.CacheManager, optional
+
+        :param actors: List of pre-initialized actors. Defaults to None.
+        :type actors: list, optional
+
+        :param simulation_mode: Flag to indicate simulation mode. Defaults to False.
+        :type simulation_mode: bool, optional
+        """
+        self.dev = DeviceInterface(kernel_source, iterations=iterations, observers=observers, **device_options) if not simulation_mode else None
+        self.kernel_source = kernel_source
+        self.simulation_mode = simulation_mode
+        self.kernel_options = kernel_options
+        self.start_time = perf_counter()
+        self.last_strategy_start_time = self.start_time
+        self.observers = observers
+        self.iterations = iterations
+        self.device_options = device_options
+        self.cache_manager = cache_manager
+        self.num_gpus = num_gpus
+        self.actors = actors
+
+        initialize_ray()
+
+        if num_gpus is None:
+            self.num_gpus = get_num_devices(simulation_mode)
+
+        # so that the number of GPUs ends up in the cache file
+        if not simulation_mode:
+            self.dev.name = [self.dev.name] * self.num_gpus
+
+    def get_environment(self, tuning_options):
+        return self.dev.get_environment()
+
+    def run(self, parameter_space=None, tuning_options=None, ensemble=None, searchspace=None, cache_manager=None):
+        """Run the tuning process with parallel execution.
+
+        :param parameter_space: The parameter space to explore.
+        :type parameter_space: iterable
+
+        :param tuning_options: Tuning options. Defaults to None.
+        :type tuning_options: dict, optional
+
+        :param ensemble: List of strategies for the ensemble. Defaults to None.
+        :type ensemble: list, optional
+
+        :param searchspace: The search space to explore. Defaults to None.
+        :type searchspace: kernel_tuner.searchspace.Searchspace, optional
+
+        :param cache_manager: Cache manager instance. Defaults to None.
+        :type cache_manager: kernel_tuner.runners.ray.cache_manager.CacheManager, optional
+
+        :returns: Results of the tuning process.
+        :rtype: list of dict
+        """
+        if tuning_options is None:
+            # HACK: tuning_options can't be the first argument and parameter_space needs a default value
+            raise ValueError("tuning_options cannot be None")
+
+        # create RemoteActor instances
+        if self.actors is None:
+            runner_attributes = [self.kernel_source, self.kernel_options, self.device_options, self.iterations, self.observers]
+            self.actors = [create_actor_on_device(*runner_attributes, id=_id, cache_manager=self.cache_manager, simulation_mode=self.simulation_mode) for _id in range(self.num_gpus)]
+
+        # check that all GPUs are of the same type
+        if not self.simulation_mode and not self._check_gpus_equals():
+            raise GPUTypeMismatchError("Different GPU types found")
+
+        if self.cache_manager is None:
+            if cache_manager is None:
+                cache_manager = CacheManager.remote(tuning_options.cache, tuning_options.cachefile)
+            self.cache_manager = cache_manager
+
+        # set the cache manager for each actor; this can't be done in the
+        # constructor because tuning_options is not always available there
+        for actor in self.actors:
+            actor.set_cache_manager.remote(self.cache_manager)
+
+        # some observers can't be pickled
+        run_tuning_options = copy.deepcopy(tuning_options)
+        run_tuning_options['observers'] = None
+
+        # determine the type of parallelism and run appropriately
+        if parameter_space and not ensemble and not searchspace:
+            results, tuning_options_list = self.parallel_function_evaluation(run_tuning_options, parameter_space)
+        elif ensemble and searchspace and not parameter_space:
+            results, tuning_options_list = self.multi_strategy_parallel_execution(ensemble, run_tuning_options, searchspace)
+        else:
+            raise ValueError("Invalid arguments to parallel runner run method")
+
+        # update tuning options
+        # NOTE: tuning_options will not contain the state of the observers created in the actors, as they can't be pickled
+        cache, cachefile = ray.get(self.cache_manager.get_cache.remote())
+        tuning_options.cache = cache
+        tuning_options.cachefile = cachefile
+        if self.simulation_mode:
+            tuning_options.simulated_time += self._calculate_simulated_time(tuning_options_list)
+
+        return results
+
+    def multi_strategy_parallel_execution(self, ensemble, tuning_options, searchspace):
+        """Run the strategies in the ensemble in parallel on distributed actors,
+        manage dynamic task allocation, and collect the results.
+
+        :param ensemble: List of strategies to execute.
+        :type ensemble: list
+
+        :param tuning_options: Tuning options.
+        :type tuning_options: dict
+
+        :param searchspace: Search space to explore.
+        :type searchspace: kernel_tuner.searchspace.Searchspace
+
+        :returns: Processed results and tuning options list.
+        :rtype: tuple
+        """
+        ensemble_queue = deque(ensemble)
+        pending_tasks = {}
+        all_results = []
+        options = tuning_options.strategy_options
+        max_feval = options["max_fevals"]
+        num_strategies = len(ensemble)
+
+        # distribute the function evaluations evenly over the strategies
+        base_eval_per_strategy = max_feval // num_strategies
+        remainder = max_feval % num_strategies
+        evaluations_per_strategy = [base_eval_per_strategy] * num_strategies
+        for i in range(remainder):
+            evaluations_per_strategy[i] += 1
+
+        # ensure we always have a list of search spaces
+        searchspaces = deque([searchspace] * num_strategies)
+
+        # start an initial task on each actor
+        for actor in self.actors:
+            strategy = ensemble_queue.popleft()
+            searchspace = searchspaces.popleft()
+            remote_tuning_options = self._setup_tuning_options(tuning_options, evaluations_per_strategy)
+            task = actor.execute.remote(strategy=strategy, searchspace=searchspace, tuning_options=remote_tuning_options)
+            pending_tasks[task] = actor
+
+        # manage task completion and redistribution
+        while pending_tasks:
+            done_ids, _ = ray.wait(list(pending_tasks.keys()), num_returns=1)
+            for done_id in done_ids:
+                result = ray.get(done_id)
+                all_results.append(result)
+                actor = pending_tasks.pop(done_id)
+
+                # reassign the actor if strategies remain
+                if ensemble_queue:
+                    strategy = ensemble_queue.popleft()
+                    searchspace = searchspaces.popleft()
+                    remote_tuning_options = self._setup_tuning_options(tuning_options, evaluations_per_strategy)
+                    task = actor.execute.remote(strategy=strategy, searchspace=searchspace, tuning_options=remote_tuning_options)
+                    pending_tasks[task] = actor
+
+        # process the results
+        results, tuning_options_list = self._process_results_ensemble(all_results)
+
+        return results, tuning_options_list
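The loop above implements a simple work-stealing scheduler: every actor that finishes is immediately refilled from the strategy queue until the queue drains. Stripped of the Kernel Tuner specifics, the pattern reduces to this runnable toy (the squaring payload is a stand-in for running one strategy):

    import ray

    ray.init(ignore_reinit_error=True)

    @ray.remote
    class Worker:
        def work(self, item):
            return item * item  # stand-in for one strategy run

    workers = [Worker.remote() for _ in range(4)]
    queue = list(range(10))
    pending = {w.work.remote(queue.pop()): w for w in workers}

    results = []
    while pending:
        # wait for any one task to finish, then refill that worker
        done, _ = ray.wait(list(pending.keys()), num_returns=1)
        for ref in done:
            results.append(ray.get(ref))
            worker = pending.pop(ref)
            if queue:
                pending[worker.work.remote(queue.pop())] = worker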
+    def _setup_tuning_options(self, tuning_options, evaluations_per_strategy):
+        """Set up the tuning options for each strategy in the ensemble.
+
+        :param tuning_options: Original tuning options.
+        :type tuning_options: dict
+
+        :param evaluations_per_strategy: Number of evaluations per strategy.
+        :type evaluations_per_strategy: list
+
+        :returns: Modified tuning options.
+        :rtype: dict
+        """
+        new_tuning_options = copy.deepcopy(tuning_options)
+        new_tuning_options.strategy_options["max_fevals"] = evaluations_per_strategy.pop(0)
+        # the stop criterion reads max_fevals from the top level of tuning_options
+        new_tuning_options["max_fevals"] = new_tuning_options.strategy_options["max_fevals"]
+        return new_tuning_options
+
+    def _process_results_ensemble(self, all_results):
+        """Process the results from the ensemble execution.
+
+        :param all_results: List of (results, tuning_options) pairs from all strategies.
+        :type all_results: list
+
+        :returns: Processed results and tuning options list.
+        :rtype: tuple
+        """
+        results = []
+        tuning_options_list = []
+
+        for (strategy_results, tuning_options) in all_results:
+            results.extend(strategy_results)
+            tuning_options_list.append(tuning_options)
+
+        return results, tuning_options_list
+
+    def parallel_function_evaluation(self, tuning_options, parameter_space):
+        """Perform parallel function evaluation.
+
+        :param tuning_options: Tuning options.
+        :type tuning_options: dict
+
+        :param parameter_space: Parameter space to explore.
+        :type parameter_space: list
+
+        :returns: Results and tuning options list.
+        :rtype: tuple
+        """
+        # create a pool of RemoteActor actors
+        self.actor_pool = ActorPool(self.actors)
+        # distribute the configurations over the actor pool, collecting the results asynchronously
+        all_results = list(self.actor_pool.map_unordered(lambda a, v: a.execute.remote(tuning_options, element=v), parameter_space))
+        results = [x[0] for x in all_results]
+        tuning_options_list = [x[1] for x in all_results]
+        return results, tuning_options_list
+
+    def _process_results(self, all_results, searchspace):
+        """Process the results and remove duplicates based on the search space."""
+        unique_configs = set()
+        final_results = []
+
+        for (strategy_results, tuning_options) in all_results:
+            for new_result in strategy_results:
+                config_signature = tuple(new_result[key] for key in searchspace.tune_params)
+                if config_signature not in unique_configs:
+                    final_results.append(new_result)
+                    unique_configs.add(config_signature)
+        return final_results
+
+    def _calculate_simulated_time(self, tuning_options_list):
+        """Return the maximum simulated time across the tuning options."""
+        simulated_times = [tuning_options.simulated_time for tuning_options in tuning_options_list]
+        return max(simulated_times)
+
+    def _check_gpus_equals(self):
+        """Check whether all GPUs are of the same type."""
+        gpu_types = []
+        env_refs = [actor.get_environment.remote() for actor in self.actors]
+        environments = ray.get(env_refs)
+        for env in environments:
+            gpu_types.append(env["device_name"])
+        if len(set(gpu_types)) == 1:
+            print(f"Running on {len(gpu_types)} x {gpu_types[0]}", file=sys.stderr)
+            return True
+        else:
+            return False
+
+    def clean_up_ray(self):
+        """Clean up the Ray actors and the cache manager."""
+ """ + if self.actors is not None: + for actor in self.actors: + ray.kill(actor) + if self.cache_manager is not None: + ray.kill(self.cache_manager) \ No newline at end of file diff --git a/kernel_tuner/runners/ray/cache_manager.py b/kernel_tuner/runners/ray/cache_manager.py new file mode 100644 index 000000000..9aeb56855 --- /dev/null +++ b/kernel_tuner/runners/ray/cache_manager.py @@ -0,0 +1,23 @@ +import ray + +from kernel_tuner.util import store_cache + +@ray.remote(num_cpus=1) +class CacheManager: + def __init__(self, cache, cachefile): + from kernel_tuner.interface import Options # importing here due to circular import + self.tuning_options = Options({'cache': cache, 'cachefile': cachefile}) + + def store(self, key, params): + store_cache(key, params, self.tuning_options) + + def check_and_retrieve(self, key): + """Checks if a result exists for the given key and returns it if found.""" + if self.tuning_options['cache']: + return self.tuning_options['cache'].get(key, None) + else: + return None + + def get_cache(self): + """Returns the current tuning options.""" + return self.tuning_options['cache'], self.tuning_options['cachefile'] diff --git a/kernel_tuner/runners/ray/remote_actor.py b/kernel_tuner/runners/ray/remote_actor.py new file mode 100644 index 000000000..c0743ad22 --- /dev/null +++ b/kernel_tuner/runners/ray/remote_actor.py @@ -0,0 +1,82 @@ +import ray + +from kernel_tuner.runners.sequential import SequentialRunner +from kernel_tuner.runners.simulation import SimulationRunner +from kernel_tuner.core import DeviceInterface +from kernel_tuner.observers.register import RegisterObserver +from kernel_tuner.util import get_gpu_id, get_gpu_type + +@ray.remote +class RemoteActor(): + def __init__(self, + kernel_source, + kernel_options, + device_options, + iterations, + observers_type_and_arguments, + id, + cache_manager=None, + simulation_mode=False): + self.kernel_source = kernel_source + self.kernel_options = kernel_options + self.device_options = device_options + self.iterations = iterations + self.cache_manager = cache_manager + self.simulation_mode = simulation_mode + self.runner = None + self.id = None + self._reinitialize_observers(observers_type_and_arguments) + self.dev = DeviceInterface(kernel_source, iterations=iterations, observers=self.observers, **device_options) if not simulation_mode else None + + def get_environment(self): + return self.dev.get_environment() + + def execute(self, tuning_options, strategy=None, searchspace=None, element=None): + tuning_options['observers'] = self.observers + if self.runner is None: + self.init_runner() + if strategy and searchspace: + results = strategy.tune(searchspace, self.runner, tuning_options) + # observers can't be pickled + tuning_options['observers'] = None + return results, tuning_options + elif element: + results = self.runner.run([element], tuning_options)[0] + # observers can't be pickled + tuning_options['observers'] = None + return results, tuning_options + else: + raise ValueError("Invalid arguments for ray actor's execute method.") + + def set_cache_manager(self, cache_manager): + if self.cache_manager is None: + self.cache_manager = cache_manager + + def get_cache_magaer(self): + return self.cache_manager + + def init_runner(self): + if self.cache_manager is None: + raise ValueError("Cache manager is not set.") + if self.simulation_mode: + self.runner = SimulationRunner(self.kernel_source, self.kernel_options, self.device_options, + self.iterations, self.observers) + else: + self.runner = 
diff --git a/kernel_tuner/runners/ray/remote_actor.py b/kernel_tuner/runners/ray/remote_actor.py
new file mode 100644
index 000000000..c0743ad22
--- /dev/null
+++ b/kernel_tuner/runners/ray/remote_actor.py
@@ -0,0 +1,82 @@
+import ray
+
+from kernel_tuner.runners.sequential import SequentialRunner
+from kernel_tuner.runners.simulation import SimulationRunner
+from kernel_tuner.core import DeviceInterface
+from kernel_tuner.observers.register import RegisterObserver
+from kernel_tuner.util import get_gpu_id, get_gpu_type
+
+@ray.remote
+class RemoteActor():
+    def __init__(self,
+                 kernel_source,
+                 kernel_options,
+                 device_options,
+                 iterations,
+                 observers_type_and_arguments,
+                 id,
+                 cache_manager=None,
+                 simulation_mode=False):
+        self.kernel_source = kernel_source
+        self.kernel_options = kernel_options
+        self.device_options = device_options
+        self.iterations = iterations
+        self.cache_manager = cache_manager
+        self.simulation_mode = simulation_mode
+        self.runner = None
+        self.id = None
+        self._reinitialize_observers(observers_type_and_arguments)
+        self.dev = DeviceInterface(kernel_source, iterations=iterations, observers=self.observers, **device_options) if not simulation_mode else None
+
+    def get_environment(self):
+        return self.dev.get_environment()
+
+    def execute(self, tuning_options, strategy=None, searchspace=None, element=None):
+        tuning_options['observers'] = self.observers
+        if self.runner is None:
+            self.init_runner()
+        if strategy and searchspace:
+            results = strategy.tune(searchspace, self.runner, tuning_options)
+            # observers can't be pickled
+            tuning_options['observers'] = None
+            return results, tuning_options
+        elif element:
+            results = self.runner.run([element], tuning_options)[0]
+            # observers can't be pickled
+            tuning_options['observers'] = None
+            return results, tuning_options
+        else:
+            raise ValueError("Invalid arguments for the Ray actor's execute method.")
+
+    def set_cache_manager(self, cache_manager):
+        if self.cache_manager is None:
+            self.cache_manager = cache_manager
+
+    def get_cache_manager(self):
+        return self.cache_manager
+
+    def init_runner(self):
+        if self.cache_manager is None:
+            raise ValueError("Cache manager is not set.")
+        if self.simulation_mode:
+            self.runner = SimulationRunner(self.kernel_source, self.kernel_options, self.device_options,
+                                           self.iterations, self.observers)
+        else:
+            self.runner = SequentialRunner(self.kernel_source, self.kernel_options, self.device_options,
+                                           self.iterations, self.observers, cache_manager=self.cache_manager, dev=self.dev)
+
+    def _reinitialize_observers(self, observers_type_and_arguments):
+        # observers can't be pickled to the actor, so we re-initialize them here
+        self.observers = []
+        for (observer, arguments) in observers_type_and_arguments:
+            if "device" in arguments:
+                self.id = get_gpu_id(self.kernel_source.lang) if self.id is None else self.id
+                arguments["device"] = self.id
+            if observer is RegisterObserver:
+                self.observers.append(RegisterObserver())
+            else:
+                self.observers.append(observer(**arguments))
+
+    def get_gpu_type(self, lang):
+        return get_gpu_type(lang)
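Putting the pieces together, the driver-side lifecycle of a single actor looks roughly like this. The names kernel_source, kernel_options, device_options, and tuning_options are assumed to come from an earlier tune_kernel setup, and the element is a made-up configuration:

    import ray
    from kernel_tuner.strategies.common import create_actor_on_device, initialize_ray
    from kernel_tuner.runners.ray.cache_manager import CacheManager

    initialize_ray()

    # reserves one GPU for the actor via Ray's resource options
    actor = create_actor_on_device(kernel_source, kernel_options, device_options,
                                   iterations=7, observers=None,
                                   cache_manager=None, simulation_mode=False, id=0)

    # the cache manager is attached after construction, once tuning_options exist
    cm = CacheManager.remote(tuning_options.cache, tuning_options.cachefile)
    ray.get(actor.set_cache_manager.remote(cm))

    # evaluate a single configuration on the actor's GPU
    result, _ = ray.get(actor.execute.remote(tuning_options, element=(128, 2)))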
diff --git a/kernel_tuner/runners/sequential.py b/kernel_tuner/runners/sequential.py
index aeebd5116..46ba17e0a 100644
--- a/kernel_tuner/runners/sequential.py
+++ b/kernel_tuner/runners/sequential.py
@@ -2,6 +2,7 @@
 import logging
 from datetime import datetime, timezone
 from time import perf_counter
+import ray
 
 from kernel_tuner.core import DeviceInterface
 from kernel_tuner.runners.runner import Runner
@@ -11,7 +12,7 @@
 class SequentialRunner(Runner):
     """SequentialRunner is used for tuning with a single process/thread."""
 
-    def __init__(self, kernel_source, kernel_options, device_options, iterations, observers):
+    def __init__(self, kernel_source, kernel_options, device_options, iterations, observers, cache_manager=None, dev=None):
         """Instantiate the SequentialRunner.
 
         :param kernel_source: The kernel source
@@ -27,9 +28,15 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, observers):
         :param iterations: The number of iterations used for benchmarking
             each kernel instance.
         :type iterations: int
+
+        :param observers: List of observers.
+        :type observers: list
+
+        :param cache_manager: Cache manager instance. Defaults to None.
+        :type cache_manager: kernel_tuner.runners.ray.cache_manager.CacheManager, optional
         """
         # detect language and create high-level device interface
-        self.dev = DeviceInterface(kernel_source, iterations=iterations, observers=observers, **device_options)
+        self.dev = DeviceInterface(kernel_source, iterations=iterations, observers=observers, **device_options) if dev is None else dev
 
         self.units = self.dev.units
         self.quiet = device_options.quiet
@@ -40,6 +47,10 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, observers):
         self.last_strategy_start_time = self.start_time
         self.last_strategy_time = 0
         self.kernel_options = kernel_options
+        self.device_options = device_options  # needed for the ensemble strategy down the line
+        self.iterations = iterations  # needed for the ensemble strategy down the line
+        self.observers = observers  # needed for the ensemble strategy down the line
+        self.cache_manager = cache_manager
 
         # move data to the GPU
         self.gpu_args = self.dev.ready_argument_list(kernel_options.arguments)
@@ -75,8 +86,9 @@ def run(self, parameter_space, tuning_options):
 
             # check if configuration is in the cache
             x_int = ",".join([str(i) for i in element])
-            if tuning_options.cache and x_int in tuning_options.cache:
-                params.update(tuning_options.cache[x_int])
+            cache_result = self.config_in_cache(x_int, tuning_options)
+            if cache_result:
+                params.update(cache_result)
                 params['compile_time'] = 0
                 params['verification_time'] = 0
                 params['benchmark_time'] = 0
@@ -111,9 +123,23 @@ def run(self, parameter_space, tuning_options):
             print_config_output(tuning_options.tune_params, params, self.quiet, tuning_options.metrics, self.units)
 
             # add configuration to cache
-            store_cache(x_int, params, tuning_options)
+            self.store_in_cache(x_int, params, tuning_options)
 
             # all visited configurations are added to results to provide a trace for optimization strategies
             results.append(params)
 
         return results
+
+    def config_in_cache(self, x_int, tuning_options):
+        if self.cache_manager and tuning_options.strategy_options.get('check_and_retrieve'):
+            return ray.get(self.cache_manager.check_and_retrieve.remote(x_int))
+        elif tuning_options.cache and x_int in tuning_options.cache:
+            return tuning_options.cache[x_int]
+        else:
+            return None
+
+    def store_in_cache(self, x_int, params, tuning_options):
+        if self.cache_manager:
+            self.cache_manager.store.remote(x_int, params)
+        else:
+            store_cache(x_int, params, tuning_options)
\ No newline at end of file
diff --git a/kernel_tuner/runners/simulation.py b/kernel_tuner/runners/simulation.py
index 22c7c667c..f354333b6 100644
--- a/kernel_tuner/runners/simulation.py
+++ b/kernel_tuner/runners/simulation.py
@@ -58,6 +58,10 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, observers):
         self.last_strategy_time = 0
         self.units = {}
 
+        self.device_options = device_options  # needed for the ensemble strategy down the line
+        self.iterations = iterations  # needed for the ensemble strategy down the line
+        self.observers = observers  # needed for the ensemble strategy down the line
+
     def get_environment(self, tuning_options):
         env = self.dev.get_environment()
         env["simulation"] = True
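With check_and_retrieve enabled, every sequential runner that lives inside an actor consults the shared cache before compiling anything, so a configuration benchmarked by one ensemble member is never re-benchmarked by another. The lookup precedence reduces to this sketch (the function and argument names are illustrative, not part of the patch):

    import ray

    def lookup(x_int, cache_manager, local_cache, check_and_retrieve):
        """Return a cached result for key x_int, or None to trigger benchmarking."""
        if cache_manager and check_and_retrieve:
            # ask the shared, actor-backed cache first
            return ray.get(cache_manager.check_and_retrieve.remote(x_int))
        if local_cache and x_int in local_cache:
            return local_cache[x_int]
        return None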
diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py
index 5ee7f7ce2..0317ff434 100644
--- a/kernel_tuner/searchspace.py
+++ b/kernel_tuner/searchspace.py
@@ -50,6 +50,11 @@ def __init__(
         restrictions = restrictions if restrictions is not None else []
         self.tune_params = tune_params
         self.restrictions = restrictions
+        self.max_threads = max_threads
+        self.block_size_names = block_size_names
+        self.framework = framework
+        self.solver_method = solver_method
+        self.path_to_ATF_cache = path_to_ATF_cache
         # the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
         self._modified_restrictions = restrictions
         self.param_names = list(self.tune_params.keys())
diff --git a/kernel_tuner/strategies/brute_force.py b/kernel_tuner/strategies/brute_force.py
index a0e3f8ebe..cf6ba521b 100644
--- a/kernel_tuner/strategies/brute_force.py
+++ b/kernel_tuner/strategies/brute_force.py
@@ -1,13 +1,21 @@
 """ The default strategy that iterates through the whole parameter space """
 from kernel_tuner.searchspace import Searchspace
 from kernel_tuner.strategies import common
+from kernel_tuner.runners.parallel import ParallelRunner
+from kernel_tuner.runners.ray.cache_manager import CacheManager
 
-_options = {}
+_options = dict(num_gpus=("Number of GPUs to use for parallel execution", None))
 
 def tune(searchspace: Searchspace, runner, tuning_options):
-    # call the runner
-    return runner.run(searchspace.sorted_list(), tuning_options)
+    if isinstance(runner, ParallelRunner):
+        if tuning_options.strategy_options is None:
+            tuning_options.strategy_options = {}
+        tuning_options.strategy_options['check_and_retrieve'] = False
+        cache_manager = CacheManager.remote(tuning_options.cache, tuning_options.cachefile)
+        return runner.run(parameter_space=searchspace.sorted_list(), tuning_options=tuning_options, cache_manager=cache_manager)
+    else:
+        return runner.run(searchspace.sorted_list(), tuning_options)
 
 tune.__doc__ = common.get_strategy_docstring("Brute Force", _options)
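For the brute-force path the ParallelRunner simply fans the whole configuration list out over the pool of actors; Ray's ActorPool takes care of scheduling and back-pressure. The core pattern as a runnable toy, with summing standing in for compiling and benchmarking one configuration:

    import ray
    from ray.util.actor_pool import ActorPool

    ray.init(ignore_reinit_error=True)

    @ray.remote
    class Evaluator:
        def execute(self, config):
            return config, sum(config)  # stand-in for benchmarking

    pool = ActorPool([Evaluator.remote() for _ in range(2)])
    configs = [(32, 1), (64, 1), (128, 2), (256, 4)]

    # map_unordered yields results as they complete, not in submission order
    for config, score in pool.map_unordered(lambda a, v: a.execute.remote(v), configs):
        print(config, score)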
diff --git a/kernel_tuner/strategies/common.py b/kernel_tuner/strategies/common.py
index d01eae937..5e4dba354 100644
--- a/kernel_tuner/strategies/common.py
+++ b/kernel_tuner/strategies/common.py
@@ -1,11 +1,19 @@
 import logging
 import sys
 from time import perf_counter
+import warnings
+import ray
 
 import numpy as np
 
 from kernel_tuner import util
 from kernel_tuner.searchspace import Searchspace
+from kernel_tuner.util import get_num_devices
+from kernel_tuner.runners.ray.remote_actor import RemoteActor
+from kernel_tuner.observers.nvml import NVMLObserver, NVMLPowerObserver
+from kernel_tuner.observers.pmt import PMTObserver
+from kernel_tuner.observers.powersensor import PowerSensorObserver
+from kernel_tuner.observers.register import RegisterObserver
 
 _docstring_template = """ Find the best performing kernel configuration in the parameter space
@@ -44,7 +52,7 @@ def make_strategy_options_doc(strategy_options):
 
 def get_options(strategy_options, options):
     """Get the strategy-specific options or their defaults from user-supplied strategy_options."""
-    accepted = list(options.keys()) + ["max_fevals", "time_limit"]
+    accepted = list(options.keys()) + ["max_fevals", "time_limit", "ensemble", "check_and_retrieve"]
     for key in strategy_options:
         if key not in accepted:
             raise ValueError(f"Unrecognized option {key} in strategy_options")
@@ -72,7 +80,57 @@ def __call__(self, x, check_restrictions=True):
         # check if max_fevals is reached or time limit is exceeded
         util.check_stop_criterion(self.tuning_options)
 
-        # snap values in x to nearest actual value for each parameter, unscale x if needed
+        x_list = [x] if self._is_single_configuration(x) else x
+        configs = [self._prepare_config(cfg) for cfg in x_list]
+
+        legal_configs = configs
+        illegal_results = []
+        if check_restrictions and self.searchspace.restrictions:
+            legal_configs, illegal_results = self._get_legal_configs(configs)
+
+        final_results = self._evaluate_configs(legal_configs) if len(legal_configs) > 0 else []
+
+        # get the numerical return values, taking the optimization direction into account
+        all_results = final_results + illegal_results
+        return_values = []
+        for result in all_results:
+            return_value = result[self.tuning_options.objective] or sys.float_info.max
+            return_values.append(return_value if not self.tuning_options.objective_higher_is_better else -return_value)
+
+        if len(return_values) == 1:
+            return return_values[0]
+        return return_values
+
+    def _is_single_configuration(self, x):
+        """Determine whether the input is a single configuration.
+
+        :param x: The input to check; can be an int, float, numpy array, list, or tuple.
+        :returns: True if x is a single int or float, a numpy array of ints or floats,
+            or a list or tuple whose elements are all ints or floats; False otherwise.
+        :rtype: bool
+        """
+        if isinstance(x, (int, float)):
+            return True
+        if isinstance(x, np.ndarray):
+            return x.dtype.kind in 'if'  # integer ('i') or float ('f') data type
+        if isinstance(x, (list, tuple)):
+            return all(isinstance(item, (int, float)) for item in x)
+        return False
+
+    def _prepare_config(self, x):
+        """Prepare a single configuration: snap to the nearest valid values and/or unscale.
+
+        :param x: The input configuration to be prepared.
+        :returns: The prepared configuration.
+        :rtype: list
+        """
         if self.snap:
             if self.scaling:
                 params = unscale_and_snap_to_nearest(x, self.searchspace.tune_params, self.tuning_options.eps)
@@ -80,39 +138,66 @@ def __call__(self, x, check_restrictions=True):
             else:
                 params = snap_to_nearest_config(x, self.searchspace.tune_params)
         else:
             params = x
-        logging.debug('params ' + str(params))
-
-        legal = True
-        result = {}
-        x_int = ",".join([str(i) for i in params])
-
-        # else check if this is a legal (non-restricted) configuration
-        if check_restrictions and self.searchspace.restrictions:
-            params_dict = dict(zip(self.searchspace.tune_params.keys(), params))
+        return params
+
+    def _get_legal_configs(self, configs):
+        """Filter configurations into legal and illegal ones based on the restrictions.
+
+        Illegal configurations are marked with an InvalidConfig objective and included
+        in the results; legal configurations are returned for evaluation.
+
+        :param configs: Configurations to check, each a tuple of parameter values.
+        :returns: A list of legal configurations and a list of results with the
+            illegal configurations marked.
+        :rtype: tuple
+        """
+        results = []
+        legal_configs = []
+        for config in configs:
+            params_dict = dict(zip(self.searchspace.tune_params.keys(), config))
             legal = util.check_restrictions(self.searchspace.restrictions, params_dict, self.tuning_options.verbose)
             if not legal:
-                result = params_dict
-                result[self.tuning_options.objective] = util.InvalidConfig()
-
-        if legal:
-            # compile and benchmark this instance
-            res = self.runner.run([params], self.tuning_options)
-            result = res[0]
-
+                params_dict[self.tuning_options.objective] = util.InvalidConfig()
+                results.append(params_dict)
+            else:
+                legal_configs.append(config)
+        return legal_configs, results
+
+    def _evaluate_configs(self, configs):
+        """Evaluate and manage configurations based on the tuning options.
+
+        Results are sorted by timestamp, because when populations are tuned in
+        parallel the stop criterion has to be checked retrospectively over the
+        collected results. Duplicate configurations are not added to the unique
+        results.
+
+        :param configs: Configurations to be evaluated.
+        :returns: Processed results of the evaluations.
+        :rtype: list of dict
+        """
+        results = self.runner.run(configs, self.tuning_options)
+
+        # sort based on timestamp, needed because of the parallel tuning of
+        # populations and the retrospective stop criterion check
+        if "timestamp" in results[0]:
+            results.sort(key=lambda x: x['timestamp'])
+
+        final_results = []
+        for result in results:
+            config = tuple(result[key] for key in self.tuning_options.tune_params if key in result)
+            x_int = ",".join([str(i) for i in config])
+            # append to the unique tuning results
+            if x_int not in self.tuning_options.unique_results:
+                self.tuning_options.unique_results[x_int] = result
+                # check retrospectively if max_fevals is reached or the time limit is exceeded within the results
+                util.check_stop_criterion(self.tuning_options)
+            final_results.append(result)
+            # in case the stop criterion is reached, the results collected so far are kept
             self.results.append(result)
 
-        # upon returning from this function control will be given back to the strategy, so reset the start time
-        self.runner.last_strategy_start_time = perf_counter()
-
-        # get numerical return value, taking optimization direction into account
-        return_value = result[self.tuning_options.objective] or sys.float_info.max
-        return_value = return_value if not self.tuning_options.objective_higher_is_better else -return_value
+        # upon returning from this function control will be given back to the strategy, so reset the start time
+        self.runner.last_strategy_start_time = perf_counter()
 
-        return return_value
+        return final_results
 
     def get_bounds_x0_eps(self):
         """Compute bounds, x0 (the initial guess), and eps."""
@@ -243,3 +328,45 @@ def scale_from_params(params, tune_params, eps):
     for i, v in enumerate(tune_params.values()):
         x[i] = 0.5 * eps + v.index(params[i])*eps
     return x
+
+def check_num_devices(ensemble_size: int, simulation_mode: bool, runner):
+    num_devices = get_num_devices(simulation_mode=simulation_mode)
+    if num_devices < ensemble_size:
+        warnings.warn("Number of devices is less than the number of strategies in the ensemble. Some strategies will wait until devices are available.", UserWarning)
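The net effect of this refactor is that CostFunc now accepts either a single configuration or a batch, which is what allows population-based strategies to evaluate a whole generation in one parallel round-trip. Conceptually, with cost_func a CostFunc instance from an earlier setup and the configurations made up:

    # a single configuration evaluates to a single float
    score = cost_func([128, 2])

    # a population evaluates to a list of floats, computed in one runner.run() call
    scores = cost_func([[128, 2], [256, 1], [64, 4]])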
+def create_actor_on_device(kernel_source, kernel_options, device_options, iterations, observers, cache_manager, simulation_mode, id):
+    # check that Ray is initialized, raise an error if not
+    if not ray.is_initialized():
+        raise RuntimeError("Ray is not initialized. Initialize Ray before creating an actor (remember to include resources).")
+
+    if simulation_mode:
+        resource_options = {"num_cpus": 1}
+    else:
+        resource_options = {"num_gpus": 1}
+
+    observers_type_and_arguments = []
+    if observers is not None:
+        # observers can't be pickled, so they are re-initialized in the actors;
+        # backend-related observers are initialized when the device interface is
+        # created inside the actor, which is why they are skipped here
+        for observer in observers:
+            if isinstance(observer, (NVMLObserver, NVMLPowerObserver, PMTObserver, PowerSensorObserver)):
+                observers_type_and_arguments.append((observer.__class__, observer.init_arguments))
+            if isinstance(observer, RegisterObserver):
+                observers_type_and_arguments.append((observer.__class__, {}))
+
+    # create the actor with the specified options and resources
+    return RemoteActor.options(**resource_options).remote(kernel_source,
+                                                          kernel_options,
+                                                          device_options,
+                                                          iterations,
+                                                          observers_type_and_arguments=observers_type_and_arguments,
+                                                          cache_manager=cache_manager,
+                                                          simulation_mode=simulation_mode,
+                                                          id=id)
+
+def initialize_ray():
+    # initialize Ray, ignoring repeated initialization
+    if not ray.is_initialized():
+        ray.init(include_dashboard=True, ignore_reinit_error=True)
diff --git a/kernel_tuner/strategies/ensemble.py b/kernel_tuner/strategies/ensemble.py
new file mode 100644
index 000000000..2dab125f4
--- /dev/null
+++ b/kernel_tuner/strategies/ensemble.py
@@ -0,0 +1,88 @@
+"""The ensemble strategy optimizes the search through the parameter space by running multiple strategies in parallel."""
+
+import warnings
+
+from kernel_tuner.searchspace import Searchspace
+from kernel_tuner.strategies import common
+from kernel_tuner.strategies.common import initialize_ray
+from kernel_tuner.runners.simulation import SimulationRunner
+from kernel_tuner.util import get_num_devices
+from kernel_tuner.runners.parallel import ParallelRunner
+
+from kernel_tuner.strategies import (
+    basinhopping,
+    bayes_opt,
+    diff_evo,
+    dual_annealing,
+    firefly_algorithm,
+    genetic_algorithm,
+    greedy_ils,
+    greedy_mls,
+    minimize,
+    mls,
+    ordered_greedy_mls,
+    pso,
+    random_sample,
+    simulated_annealing,
+)
+
+strategy_map = {
+    "random_sample": random_sample,
+    "minimize": minimize,
+    "basinhopping": basinhopping,
+    "diff_evo": diff_evo,
+    "genetic_algorithm": genetic_algorithm,
+    "greedy_mls": greedy_mls,
+    "ordered_greedy_mls": ordered_greedy_mls,
+    "greedy_ils": greedy_ils,
+    "dual_annealing": dual_annealing,
+    "mls": mls,
+    "pso": pso,
+    "simulated_annealing": simulated_annealing,
+    "firefly_algorithm": firefly_algorithm,
+    "bayes_opt": bayes_opt,
+}
+
+_options = dict(
+    ensemble=("List of strategies to be used in the ensemble", ["random_sample", "random_sample"]),
+    max_fevals=("Maximum number of function evaluations", None),
+    num_gpus=("Number of GPUs to run the parallel ensemble on", None),
+)
+
+def tune(searchspace: Searchspace, runner, tuning_options, cache_manager=None, actors=None):
+    clean_up = actors is None and cache_manager is None
+    options = tuning_options.strategy_options
+    simulation_mode = isinstance(runner, SimulationRunner)
+    initialize_ray()
+
+    ensemble, max_fevals, num_gpus = common.get_options(options, _options)
+    num_devices = num_gpus if num_gpus is not None else get_num_devices(simulation_mode=simulation_mode)
+    ensemble_size = len(ensemble)
+
+    # set up the strategy options
+    if 'bayes_opt' in ensemble:  # all strategies start from a random sample, except for BO
+        tuning_options.strategy_options["samplingmethod"] = 'random'
+    tuning_options.strategy_options["max_fevals"] = 100 * ensemble_size if max_fevals is None else max_fevals
+    tuning_options.strategy_options['check_and_retrieve'] = True
+
+    # define the number of Ray actors needed
+    if num_devices < ensemble_size:
+        warnings.warn("Number of devices is less than the number of strategies in the ensemble. Some strategies will wait until devices are available.", UserWarning)
+    num_actors = min(num_devices, ensemble_size)
+
+    ensemble = [strategy_map[strategy] for strategy in ensemble]
+
+    parallel_runner = ParallelRunner(runner.kernel_source, runner.kernel_options, runner.device_options,
+                                     runner.iterations, runner.observers, num_gpus=num_actors,
+                                     cache_manager=cache_manager, simulation_mode=simulation_mode, actors=actors)
+
+    final_results = parallel_runner.run(tuning_options=tuning_options, ensemble=ensemble, searchspace=searchspace)
+
+    if clean_up:
+        parallel_runner.clean_up_ray()
+
+    return final_results
+
+tune.__doc__ = common.get_strategy_docstring("Ensemble", _options)
\ No newline at end of file
diff --git a/kernel_tuner/strategies/genetic_algorithm.py b/kernel_tuner/strategies/genetic_algorithm.py
index c29c150b5..52361a744 100644
--- a/kernel_tuner/strategies/genetic_algorithm.py
+++ b/kernel_tuner/strategies/genetic_algorithm.py
@@ -176,4 +176,4 @@ def disruptive_uniform_crossover(dna1, dna2):
     "two_point": two_point_crossover,
     "uniform": uniform_crossover,
     "disruptive_uniform": disruptive_uniform_crossover,
-}
+}
\ No newline at end of file
diff --git a/kernel_tuner/strategies/greedy_ils.py b/kernel_tuner/strategies/greedy_ils.py
index a4c521746..26d15f591 100644
--- a/kernel_tuner/strategies/greedy_ils.py
+++ b/kernel_tuner/strategies/greedy_ils.py
@@ -63,4 +63,4 @@ def random_walk(indiv, permutation_size, no_improve, last_improve, searchspace:
         return searchspace.get_random_sample(1)[0]
     for _ in range(permutation_size):
         indiv = mutate(indiv, 0, searchspace, cache=False)
-    return indiv
+    return indiv
\ No newline at end of file
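End to end, the new strategy is selected like any other. Reusing the setup from the earlier tune_kernel sketch, and with an arbitrary ensemble composition and evaluation budget:

    results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params,
                               lang="CUDA", strategy="ensemble", parallel_mode=True,
                               strategy_options={"ensemble": ["greedy_ils", "genetic_algorithm"],
                                                 "max_fevals": 200})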
os.environ.get("CUDA_VISIBLE_DEVICES") or os.environ.get("NVIDIA_VISIBLE_DEVICES") or "No GPU assigned" + else: + raise NotImplementedError("TODO: implement other languages") + return int(gpu_id) + +def get_gpu_type(lang): + gpu_id = get_gpu_id(lang) + if lang == "CUDA" or lang == "CUPY" or lang == "NVCUDA": + result = subprocess.run(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader', '-i', str(gpu_id)], capture_output=True, text=True) + return result.stdout.strip() + else: + raise NotImplementedError("TODO: implement other languages") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 13d1cb647..721c60e7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ python-constraint2 = "^2.0.0b5" xmltodict = "*" pandas = ">=2.0.0" scikit-learn = ">=1.0.2" +ray = { version = ">=2.9.1", extras = ["default"] } # Torch can be used with Kernel Tuner, but is not a dependency, should be up to the user to use it # List of optional dependencies for user installation, e.g. `pip install kernel_tuner[cuda]`, used in the below `extras`. diff --git a/test/strategies/test_strategies.py b/test/strategies/test_strategies.py index 096be38b0..1001aabec 100644 --- a/test/strategies/test_strategies.py +++ b/test/strategies/test_strategies.py @@ -36,7 +36,7 @@ def vector_add(): @pytest.mark.parametrize('strategy', strategy_map) def test_strategies(vector_add, strategy): - options = dict(popsize=5, neighbor='adjacent') + options = dict(popsize=5) print(f"testing {strategy}") diff --git a/test/test_ensemble_tuning.py b/test/test_ensemble_tuning.py new file mode 100644 index 000000000..69efb5a68 --- /dev/null +++ b/test/test_ensemble_tuning.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest +import logging +import sys + +from kernel_tuner import tune_kernel +from kernel_tuner.backends import nvcuda +from kernel_tuner.core import KernelInstance, KernelSource +from .context import skip_if_no_pycuda + +try: + import pycuda.driver +except Exception: + pass + +@pytest.fixture +def env(): + kernel_string = """ + extern "C" __global__ void vector_add(float *c, float *a, float *b, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + int index = i + j * gridDim.x * blockDim.x; + if (index < n) { + c[index] = a[index] + b[index]; + } + } + """ + + size = 100 + a = np.random.randn(size).astype(np.float32) + b = np.random.randn(size).astype(np.float32) + c = np.zeros_like(b) + n = np.int32(size) + + args = [c, a, b, n] + tune_params = dict() + + # Extend the range of block sizes for a bigger search space + tune_params["block_size_x"] = [128 + 64 * i for i in range(30)] + tune_params["block_size_y"] = [1 + i for i in range(1, 16)] + + return ["vector_add", kernel_string, size, args, tune_params] + +@skip_if_no_pycuda +def test_parallel_tune_kernel(env): + strategy_options = {"ensemble": ["greedy_ils", "greedy_ils"]} + result, _ = tune_kernel(*env, lang="CUDA", verbose=True, strategy="ensemble", + parallel_mode=True, strategy_options=strategy_options) + assert len(result) > 0 \ No newline at end of file diff --git a/test/test_parallel_tuning.py b/test/test_parallel_tuning.py new file mode 100644 index 000000000..bbe4d96b7 --- /dev/null +++ b/test/test_parallel_tuning.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest +import logging +import sys + +from kernel_tuner import tune_kernel +from kernel_tuner.backends import nvcuda +from kernel_tuner.core import KernelInstance, KernelSource +from .context import 
diff --git a/test/test_parallel_tuning.py b/test/test_parallel_tuning.py
new file mode 100644
index 000000000..bbe4d96b7
--- /dev/null
+++ b/test/test_parallel_tuning.py
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+import logging
+import sys
+
+from kernel_tuner import tune_kernel
+from kernel_tuner.backends import nvcuda
+from kernel_tuner.core import KernelInstance, KernelSource
+from .context import skip_if_no_pycuda
+
+try:
+    import pycuda.driver
+except Exception:
+    pass
+
+@pytest.fixture
+def env():
+    kernel_string = """
+    extern "C" __global__ void vector_add(float *c, float *a, float *b, int n) {
+        int i = blockIdx.x * block_size_x + threadIdx.x;
+        if (i < n) {
+            c[i] = a[i] + b[i];
+        }
+    }
+    """
+
+    size = 100
+    a = np.random.randn(size).astype(np.float32)
+    b = np.random.randn(size).astype(np.float32)
+    c = np.zeros_like(b)
+    n = np.int32(size)
+
+    args = [c, a, b, n]
+    tune_params = {"block_size_x": [128 + 64 * i for i in range(15)]}
+
+    return ["vector_add", kernel_string, size, args, tune_params]
+
+@skip_if_no_pycuda
+def test_parallel_tune_kernel(env):
+    result, _ = tune_kernel(*env, lang="CUDA", verbose=True, parallel_mode=True)
+    assert len(result) > 0
\ No newline at end of file