Skip to content

Commit

Permalink
Mypy: Error on Common Truthy Mistakes (#524)
Browse files Browse the repository at this point in the history
Configure mypy to raise an error when:
- An instance of an object is used for a boolean check when neither
`__bool__` or `__len__` are implemented
- A `Iterator` is used on a boolean check when the author almost
certainly wanted a `Collection`

Fix up/refactors areas of the code base where these potential errors
linger.

[ committed by @MattToast ]
[ reviewed by @ankona @al-rigazzi ]
  • Loading branch information
MattToast authored Mar 19, 2024
1 parent 6dea582 commit d1dfac8
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 62 deletions.
4 changes: 4 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ Detailed Notes
software and files that come from the default Github Ubuntu container.
(SmartSim-PR504_)
- Update the generic `t.Any` typehints in Experiment API. (SmartSim-PR501_)
- The CI will fail static analysis if common erroneous truthy checks are
detected. (SmartSim-PR524_)
- Remove previously deprecated behavior present in test suite on machines with
Slurm and Open MPI. (SmartSim-PR520_)


.. _SmartSim-PR524: https://github.com/CrayLabs/SmartSim/pull/524
.. _SmartSim-PR520: https://github.com/CrayLabs/SmartSim/pull/520
.. _SmartSim-PR518: https://github.com/CrayLabs/SmartSim/pull/518
.. _SmartSim-PR517: https://github.com/CrayLabs/SmartSim/pull/517
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ enable_error_code = [
# "unused-awaitable",
# "ignore-without-code",
# "mutable-override",
"truthy-bool",
"truthy-iterable",
]

[[tool.mypy.overrides]]
Expand Down
4 changes: 2 additions & 2 deletions smartsim/_core/_cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,10 @@ def _assess_python_env(


def _format_incompatible_python_env_message(
missing: t.Iterable[str], conflicting: t.Iterable[str]
missing: t.Collection[str], conflicting: t.Collection[str]
) -> str:
indent = "\n\t"
fmt_list: t.Callable[[str, t.Iterable[str]], str] = lambda n, l: (
fmt_list: t.Callable[[str, t.Collection[str]], str] = lambda n, l: (
f"{n}:{indent}{indent.join(l)}" if l else ""
)
missing_str = fmt_list("Missing", missing)
Expand Down
48 changes: 20 additions & 28 deletions smartsim/_core/generation/modelwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import collections
import re
import typing as t

Expand Down Expand Up @@ -125,35 +126,26 @@ def _replace_tags(
:rtype: dict[str,str]
"""
edited = []
unused_tags: t.Dict[str, t.List[int]] = {}
unused_tags: t.DefaultDict[str, t.List[int]] = collections.defaultdict(list)
used_params: t.Dict[str, str] = {}
for i, line in enumerate(self.lines):
search = re.search(self.regex, line)
if search:
while search:
tagged_line = search.group(0)
previous_value = self._get_prev_value(tagged_line)
if self._is_ensemble_spec(tagged_line, params):
new_val = str(params[previous_value])
new_line = re.sub(self.regex, new_val, line, 1)
search = re.search(self.regex, new_line)
used_params[previous_value] = new_val
if not search:
edited.append(new_line)
else:
line = new_line

# if a tag is found but is not in this model's configurations
# put in placeholder value
else:
tag = tagged_line.split(self.tag)[1]
if tag not in unused_tags:
unused_tags[tag] = []
unused_tags[tag].append(i + 1)
edited.append(re.sub(self.regex, previous_value, line))
search = None # Move on to the next tag
else:
edited.append(line)
for i, line in enumerate(self.lines, 1):
while search := re.search(self.regex, line):
tagged_line = search.group(0)
previous_value = self._get_prev_value(tagged_line)
if self._is_ensemble_spec(tagged_line, params):
new_val = str(params[previous_value])
line = re.sub(self.regex, new_val, line, 1)
used_params[previous_value] = new_val

# if a tag is found but is not in this model's configurations
# put in placeholder value
else:
tag = tagged_line.split(self.tag)[1]
unused_tags[tag].append(i)
line = re.sub(self.regex, previous_value, line)
break
edited.append(line)

for tag, value in unused_tags.items():
missing_tag_message = f"Unused tag {tag} on line(s): {str(value)}"
if make_fatal:
Expand Down
3 changes: 1 addition & 2 deletions smartsim/database/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,7 @@ def set_batch_arg(self, arg: str, value: t.Optional[str] = None) -> None:
"it is a reserved keyword in Orchestrator"
)
else:
if hasattr(self, "batch_settings") and self.batch_settings:
self.batch_settings.batch_args[arg] = value
self.batch_settings.batch_args[arg] = value

def set_run_arg(self, arg: str, value: t.Optional[str] = None) -> None:
"""Set a run argument the orchestrator should launch
Expand Down
9 changes: 3 additions & 6 deletions smartsim/entity/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,9 @@ def __init__(
super().__init__(name, getcwd(), perm_strat=perm_strat, **kwargs)

@property
def models(self) -> t.Iterable[Model]:
"""
Helper property to cast self.entities to Model type for type correctness
"""
model_entities = [node for node in self.entities if isinstance(node, Model)]
return model_entities
def models(self) -> t.Collection[Model]:
"""An alias for a shallow copy of the ``entities`` attribute"""
return list(self.entities)

def _initialize_entities(self, **kwargs: t.Any) -> None:
"""Initialize all the models within the ensemble based
Expand Down
43 changes: 19 additions & 24 deletions smartsim/entity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from __future__ import annotations

import collections.abc
import itertools
import re
import sys
import typing as t
Expand Down Expand Up @@ -414,9 +414,10 @@ def _set_colocated_db_settings(
def _create_pinning_string(
pin_ids: t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]], cpus: int
) -> t.Optional[str]:
"""Create a comma-separated string CPU ids. By default, None returns
0,1,...,cpus-1; an empty iterable will disable pinning altogether,
and an iterable constructs a comma separate string (e.g. 0,2,5)
"""Create a comma-separated string of CPU ids. By default, ``None``
returns 0,1,...,cpus-1; an empty iterable will disable pinning
altogether, and an iterable constructs a comma separated string of
integers (e.g. ``[0, 2, 5]`` -> ``"0,2,5"``)
"""

def _stringify_id(_id: int) -> str:
Expand All @@ -428,40 +429,34 @@ def _stringify_id(_id: int) -> str:

raise TypeError(f"Argument is of type '{type(_id)}' not 'int'")

_invalid_input_message = (
"Expected a cpu pinning specification of type iterable of ints or "
f"iterables of ints. Instead got type `{type(pin_ids)}`"
)
try:
pin_ids = tuple(pin_ids) if pin_ids is not None else None
except TypeError:
raise TypeError(
"Expected a cpu pinning specification of type iterable of ints or "
f"iterables of ints. Instead got type `{type(pin_ids)}`"
) from None

# Deal with MacOSX limitations first. The "None" (default) disables pinning
# and is equivalent to []. The only invalid option is an iterable
# and is equivalent to []. The only invalid option is a non-empty pinning
if sys.platform == "darwin":
if pin_ids is None or not pin_ids:
return None

if isinstance(pin_ids, collections.abc.Iterable):
if pin_ids:
warnings.warn(
"CPU pinning is not supported on MacOSX. Ignoring pinning "
"specification.",
RuntimeWarning,
)
return None
raise TypeError(_invalid_input_message)
return None

# Flatten the iterable into a list and check to make sure that the resulting
# elements are all ints
if pin_ids is None:
return ",".join(_stringify_id(i) for i in range(cpus))
if not pin_ids:
return None
if isinstance(pin_ids, collections.abc.Iterable):
pin_list = []
for pin_id in pin_ids:
if isinstance(pin_id, collections.abc.Iterable):
pin_list.extend([_stringify_id(j) for j in pin_id])
else:
pin_list.append(_stringify_id(pin_id))
return ",".join(sorted(set(pin_list)))
raise TypeError(_invalid_input_message)
pin_ids = ((x,) if isinstance(x, int) else x for x in pin_ids)
to_fmt = itertools.chain.from_iterable(pin_ids)
return ",".join(sorted({_stringify_id(x) for x in to_fmt}))

def params_to_args(self) -> None:
"""Convert parameters to command line arguments and update run settings."""
Expand Down

0 comments on commit d1dfac8

Please sign in to comment.