From 659c5aa0cd09046999ef8848d7d12b88e7b55d64 Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Thu, 30 Jan 2025 04:13:31 +0100 Subject: [PATCH] Apply black to fix CI --- optunahub/_conf.py | 4 +++- optunahub/hub.py | 4 +++- optunahub/samplers/_simple_base.py | 12 +++++++++--- tests/package_for_test_hub/sampler.py | 4 +++- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/optunahub/_conf.py b/optunahub/_conf.py index 43dced8..adb8806 100644 --- a/optunahub/_conf.py +++ b/optunahub/_conf.py @@ -26,7 +26,9 @@ def cache_home() -> str: ) else: # UNIX-like return os.path.join( - os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")), + os.getenv( + "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache") + ), "optunahub", ) diff --git a/optunahub/hub.py b/optunahub/hub.py index 7a2dfd8..5433077 100644 --- a/optunahub/hub.py +++ b/optunahub/hub.py @@ -105,7 +105,9 @@ def load_module( hostname = urlparse(base_url).hostname if hostname is None: raise ValueError(f"Invalid base URL: {base_url}") - cache_dir_prefix = os.path.join(_conf.cache_home(), hostname, repo_owner, repo_name, ref) + cache_dir_prefix = os.path.join( + _conf.cache_home(), hostname, repo_owner, repo_name, ref + ) package_cache_dir = os.path.join(cache_dir_prefix, dir_path) use_cache = not force_reload and os.path.exists(package_cache_dir) diff --git a/optunahub/samplers/_simple_base.py b/optunahub/samplers/_simple_base.py index 480d15f..02742da 100644 --- a/optunahub/samplers/_simple_base.py +++ b/optunahub/samplers/_simple_base.py @@ -15,7 +15,9 @@ class SimpleBaseSampler(BaseSampler, abc.ABC): """A simple base class to implement user-defined samplers.""" def __init__( - self, search_space: dict[str, BaseDistribution] | None = None, seed: int | None = None + self, + search_space: dict[str, BaseDistribution] | None = None, + seed: int | None = None, ) -> None: self.search_space = search_space self._seed = seed @@ -53,7 +55,9 @@ def sample_independent( ) -> Any: # This method is optional. # By default, parameter values are sampled by ``optuna.samplers.RandomSampler``. - return self._default_sample_independent(study, trial, param_name, param_distribution) + return self._default_sample_independent( + study, trial, param_name, param_distribution + ) def reseed_rng(self) -> None: self._default_reseed_rng() @@ -66,7 +70,9 @@ def _default_infer_relative_search_space( self, study: Study, trial: FrozenTrial ) -> dict[str, BaseDistribution]: search_space: dict[str, BaseDistribution] = {} - for name, distribution in self._intersection_search_space.calculate(study).items(): + for name, distribution in self._intersection_search_space.calculate( + study + ).items(): if distribution.single(): # Single value objects are not sampled with the `sample_relative` method, # but with the `sample_independent` method. diff --git a/tests/package_for_test_hub/sampler.py b/tests/package_for_test_hub/sampler.py index fe2ba2d..0063f84 100644 --- a/tests/package_for_test_hub/sampler.py +++ b/tests/package_for_test_hub/sampler.py @@ -11,7 +11,9 @@ class TestSampler(optunahub.samplers.SimpleBaseSampler): - def __init__(self, search_space: Optional[Dict[str, BaseDistribution]] = None) -> None: + def __init__( + self, search_space: Optional[Dict[str, BaseDistribution]] = None + ) -> None: super().__init__(search_space) self._rng = np.random.RandomState()