diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index fbf2dcc32..e45fe5eff 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -118,6 +118,15 @@ jobs: - name: Run Pylint run: make check-lint + # Run isort/black style check + - name: Run isort + run: isort --check-only --profile black ./smartsim ./tests + + # Run isort/black style check + - name: Run black + run: | + black --exclude smartsim/version.py --check ./smartsim ./tests + # Run pytest (backends subdirectory) - name: Run Pytest if: (matrix.subset == 'backends') diff --git a/smartsim/_core/_install/builder.py b/smartsim/_core/_install/builder.py index dba55bed5..27c11c791 100644 --- a/smartsim/_core/_install/builder.py +++ b/smartsim/_core/_install/builder.py @@ -336,12 +336,10 @@ class _RAIBuildDependency(ABC): @property @abstractmethod - def __rai_dependency_name__(self) -> str: - ... + def __rai_dependency_name__(self) -> str: ... @abstractmethod - def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: - ... + def __place_for_rai__(self, target: t.Union[str, "os.PathLike[str]"]) -> Path: ... def _place_rai_dep_at( @@ -680,8 +678,7 @@ def _threaded_map(fn: t.Callable[[_T], _U], items: t.Iterable[_T]) -> t.Sequence class _WebLocation(ABC): @property @abstractmethod - def url(self) -> str: - ... + def url(self) -> str: ... class _WebGitRepository(_WebLocation): @@ -735,8 +732,7 @@ class _ExtractableWebArchive(_WebArchive, ABC): @abstractmethod def _extract_download( self, download_path: Path, target: t.Union[str, "os.PathLike[str]"] - ) -> None: - ... + ) -> None: ... def extract(self, target: t.Union[str, "os.PathLike[str]"]) -> None: with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index e4ad012a8..3204600d6 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -875,8 +875,7 @@ def __init__(self, model: Model) -> None: self.entities = [model] self.batch_settings = model.batch_settings - def _initialize_entities(self, **kwargs: t.Any) -> None: - ... + def _initialize_entities(self, **kwargs: t.Any) -> None: ... def _look_up_launched_data( diff --git a/smartsim/_core/control/manifest.py b/smartsim/_core/control/manifest.py index 9962d61b4..452a52f09 100644 --- a/smartsim/_core/control/manifest.py +++ b/smartsim/_core/control/manifest.py @@ -296,9 +296,9 @@ class LaunchedManifestBuilder(t.Generic[_T]): _ensembles: t.List[t.Tuple[Ensemble, t.Tuple[t.Tuple[Model, _T], ...]]] = field( default_factory=list, init=False ) - _databases: t.List[ - t.Tuple[Orchestrator, t.Tuple[t.Tuple[DBNode, _T], ...]] - ] = field(default_factory=list, init=False) + _databases: t.List[t.Tuple[Orchestrator, t.Tuple[t.Tuple[DBNode, _T], ...]]] = ( + field(default_factory=list, init=False) + ) @property def exp_telemetry_subdirectory(self) -> pathlib.Path: diff --git a/smartsim/_core/entrypoints/redis.py b/smartsim/_core/entrypoints/redis.py index 7262a5996..ef9911829 100644 --- a/smartsim/_core/entrypoints/redis.py +++ b/smartsim/_core/entrypoints/redis.py @@ -79,8 +79,7 @@ def print_summary( cmd: t.List[str], network_interface: str, shard_data: LaunchedShardData ) -> None: print( - textwrap.dedent( - f"""\ + textwrap.dedent(f"""\ ----------- Running Command ---------- COMMAND: {' '.join(cmd)} IPADDRESS: {shard_data.hostname} @@ -90,8 +89,7 @@ def print_summary( --------------- Output --------------- - """ - ), + """), flush=True, ) diff --git a/smartsim/entity/dbobject.py b/smartsim/entity/dbobject.py index 5b9fb4ec6..bebedb12c 100644 --- a/smartsim/entity/dbobject.py +++ b/smartsim/entity/dbobject.py @@ -49,9 +49,9 @@ def __init__( ) -> None: self.name = name self.func = func - self.file: t.Optional[ - Path - ] = None # Need to have this explicitly to check on it + self.file: t.Optional[Path] = ( + None # Need to have this explicitly to check on it + ) if file_path: self.file = self._check_filepath(file_path) self.device = self._check_device(device) diff --git a/smartsim/log.py b/smartsim/log.py index 9d01a57c5..44e26339e 100644 --- a/smartsim/log.py +++ b/smartsim/log.py @@ -39,9 +39,9 @@ # constants DEFAULT_DATE_FORMAT: t.Final[str] = "%H:%M:%S" -DEFAULT_LOG_FORMAT: t.Final[ - str -] = "%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s" +DEFAULT_LOG_FORMAT: t.Final[str] = ( + "%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s" +) EXPERIMENT_LOG_FORMAT = DEFAULT_LOG_FORMAT.replace("s[%", "s {%(exp_path)s} [%") # configure colored loggs diff --git a/tests/test_controller.py b/tests/test_controller.py index 85c90b1c3..65687ec59 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -48,8 +48,7 @@ class MockStep(Step): def _create_unique_name(name): return name - def add_to_batch(self, step): - ... + def add_to_batch(self, step): ... def get_launch_cmd(self): return [] diff --git a/tests/test_dbnode.py b/tests/test_dbnode.py index 0629fb60d..ec0ed23ea 100644 --- a/tests/test_dbnode.py +++ b/tests/test_dbnode.py @@ -93,9 +93,7 @@ def test_launched_shard_info_can_be_serialized(): @pytest.mark.parametrize("limit", [None, 1]) def test_db_node_can_parse_launched_shard_info(limit): rand_shards = [_random_shard_info() for _ in range(3)] - with io.StringIO( - textwrap.dedent( - """\ + with io.StringIO(textwrap.dedent("""\ This is some file like str -------------------------- @@ -110,9 +108,7 @@ def test_db_node_can_parse_launched_shard_info(limit): SMARTSIM_ORC_SHARD_INFO: {} All other lines should be ignored. - """ - ).format(*(json.dumps(s.to_dict()) for s in rand_shards)) - ) as stream: + """).format(*(json.dumps(s.to_dict()) for s in rand_shards))) as stream: parsed_shards = DBNode._parse_launched_shard_info_from_iterable(stream, limit) if limit is not None: rand_shards = rand_shards[:limit]