Skip to content

Commit

Permalink
Fix mypy issues SCMSUITE-10131 SO107
Browse files Browse the repository at this point in the history
  • Loading branch information
dormrod committed Dec 9, 2024
1 parent 521ac35 commit ff7e709
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
12 changes: 6 additions & 6 deletions core/basejob.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import threading
import time
from os.path import join as opj
from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Union, Callable, Any
from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Union
from abc import ABC, abstractmethod
import traceback

Expand All @@ -30,7 +30,7 @@
__all__ = ["SingleJob", "MultiJob"]


def _fail_on_exception(func: Callable[["Job", Any], Any]) -> Callable[["Job", Any], Any]:
def _fail_on_exception(func):
"""Decorator to wrap a job method and mark the job as failed on any exception."""

def wrapper(self: "Job", *args, **kwargs):
Expand All @@ -39,11 +39,11 @@ def wrapper(self: "Job", *args, **kwargs):
except:
# Mark job status as failed and the results as complete
self.status = JobStatus.FAILED
self.results.finished.set()
self.results.done.set()
self.results.finished.set() # type: ignore
self.results.done.set() # type: ignore
# Notify any parent multi-job of the failure
if self.parent and self in self.parent:
self.parent._notify()
if self.parent and self in self.parent: # type: ignore
self.parent._notify() # type: ignore
# Store the exception message to be accessed from get_errormsg
self._error_msg = traceback.format_exc()

Expand Down
6 changes: 3 additions & 3 deletions interfaces/adfsuite/ams.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,8 +1802,8 @@ def get_density_along_axis(

start_step, end_step, every, _ = self._get_integer_start_end_every_max(start_fs, end_fs, every_fs, None)
nEntries = self.readrkf("History", "nEntries")
coords = np.array(self.get_history_property("Coords")).reshape(nEntries, -1, 3)
coords = coords[start_step:end_step:every]
history_coords = np.array(self.get_history_property("Coords")).reshape(nEntries, -1, 3)
coords = history_coords[start_step:end_step:every]
nEntries = len(coords)

axis2index = {"x": 0, "y": 1, "z": 2}
Expand Down Expand Up @@ -2507,7 +2507,7 @@ def get_errormsg(self) -> Optional[str]:
try:
log_err_lines = self.results.grep_file("ams.log", "ERROR: ")
if log_err_lines:
self._error_msg = log_err_lines[-1].partition("ERROR: ")[2]
self._error_msg: Optional[str] = log_err_lines[-1].partition("ERROR: ")[2]
return self._error_msg
except FileError:
pass
Expand Down
4 changes: 2 additions & 2 deletions mol/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,7 @@ def get_fragment(self, indices):

return ret

def get_complete_molecules_within_threshold(self, atom_indices, threshold: float):
def get_complete_molecules_within_threshold(self, atom_indices: List[int], threshold: float):
"""
Returns a new molecule containing complete submolecules for any molecules
that are closer than ``threshold`` to any of the atoms in ``atom_indices``.
Expand All @@ -1508,7 +1508,7 @@ def get_complete_molecules_within_threshold(self, atom_indices, threshold: float
zero_based_indices = [x - 1 for x in atom_indices]
D = distance_array(solvated_coords, solvated_coords)[zero_based_indices]
less_equal = np.less_equal(D, threshold)
within_threshold = np.any(less_equal, axis=0)
within_threshold: List[bool] = np.any(less_equal, axis=0) # type: ignore
good_indices = [i for i, value in enumerate(within_threshold) if value]

complete_indices: Set[int] = set()
Expand Down
9 changes: 9 additions & 0 deletions unit_tests/test_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ def test_system_and_atomic_charge(self, mol):
with pytest.raises(MoleculeError):
assert mol.guess_atomic_charges() == [1, 1, 0, 0]

def test_get_complete_molecules_within_threshold(self, mol):
m0 = mol.get_complete_molecules_within_threshold([2], 0)
m1 = mol.get_complete_molecules_within_threshold([2], 1)
m2 = mol.get_complete_molecules_within_threshold([2], 2)

assert m0.get_formula() == "H"
assert m1.get_formula() == "HO"
assert m2.get_formula() == "H2O"


class TestNiO(MoleculeTestBase):
"""
Expand Down

0 comments on commit ff7e709

Please sign in to comment.