Skip to content

Commit

Permalink
Fix returns and anonymous bools (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 authored Sep 10, 2024
1 parent e3d3ae5 commit bad6325
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 44 deletions.
3 changes: 1 addition & 2 deletions castep_outputs/cli/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,4 @@ def parse_args(to_parse: Sequence[str] = ()) -> argparse.Namespace:

def args_to_dict(args: argparse.Namespace) -> dict[str, list[str]]:
""" Convert args namespace to dictionary """
out_dict = {typ: getattr(args, typ) for typ in CASTEP_OUTPUT_NAMES}
return out_dict
return {typ: getattr(args, typ) for typ in CASTEP_OUTPUT_NAMES}
63 changes: 32 additions & 31 deletions castep_outputs/parsers/castep_file_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def parse_castep_file(castep_file_in: TextIO,

logger("Found k-points")

curr_run["k-points"] = _process_kpoint_blocks(block, True)
curr_run["k-points"] = _process_kpoint_blocks(block, implicit_kpoints=True)

elif block := Block.from_re(line, castep_file,
gen_table_re("Number +Fractional coordinates +Weight", r"\+"),
Expand All @@ -699,7 +699,7 @@ def parse_castep_file(castep_file_in: TextIO,

logger("Found k-points list")

curr_run["k-points"] = _process_kpoint_blocks(block, False)
curr_run["k-points"] = _process_kpoint_blocks(block, implicit_kpoints=False)

elif "Applied Electric Field" in line:

Expand Down Expand Up @@ -811,7 +811,7 @@ def parse_castep_file(castep_file_in: TextIO,

logger("Found optical permittivity")

val = _process_3_6_matrix(block, True)
val = _process_3_6_matrix(block, split=True)
curr_run["optical_permittivity"] = val[0]
if val[1]:
curr_run["dc_permittivity"] = val[1]
Expand All @@ -824,7 +824,7 @@ def parse_castep_file(castep_file_in: TextIO,

logger("Found polarisability")

val = _process_3_6_matrix(block, True)
val = _process_3_6_matrix(block, split=True)
curr_run["optical_polarisability"] = val[0]
if val[1]:
curr_run["static_polarisability"] = val[1]
Expand All @@ -838,7 +838,7 @@ def parse_castep_file(castep_file_in: TextIO,

logger("Found NLO")

curr_run["nlo"], _ = _process_3_6_matrix(block, False)
curr_run["nlo"], _ = _process_3_6_matrix(block, split=False)

# Atomic displacements
elif block := Block.from_re(line, castep_file,
Expand Down Expand Up @@ -1051,7 +1051,7 @@ def parse_castep_file(castep_file_in: TextIO,
continue

if not (match := re.search(REs.MINIMISERS_RE, line)):
raise OSError("Invalid Geom block")
raise ValueError("Invalid Geom block")

typ = match.group(0)

Expand Down Expand Up @@ -1173,7 +1173,7 @@ def parse_castep_file(castep_file_in: TextIO,
if "elastic" not in curr_run:
curr_run["elastic"] = {}

val, _ = _process_3_6_matrix(block, False)
val, _ = _process_3_6_matrix(block, split=False)
curr_run["elastic"]["elastic_constants"] = val

elif block := Block.from_re(line, castep_file,
Expand All @@ -1188,7 +1188,7 @@ def parse_castep_file(castep_file_in: TextIO,
if "elastic" not in curr_run:
curr_run["elastic"] = {}

val, _ = _process_3_6_matrix(block, False)
val, _ = _process_3_6_matrix(block, split=False)
curr_run["elastic"]["compliance_matrix"] = val

elif block := Block.from_re(line, castep_file, "Contribution ::", REs.EMPTY):
Expand All @@ -1207,7 +1207,7 @@ def parse_castep_file(castep_file_in: TextIO,
if "elastic" not in curr_run:
curr_run["elastic"] = {}

val, _ = _process_3_6_matrix(block, False)
val, _ = _process_3_6_matrix(block, split=False)
curr_run["elastic"][typ] = val

elif block := Block.from_re(line, castep_file,
Expand Down Expand Up @@ -1389,20 +1389,22 @@ def _process_ps_energy(block: Block) -> tuple[str, PSPotEnergy]:


def _process_tddft(block: Block) -> list[TDDFTData]:
tddata = [{"energy": float(match["energy"]),
"error": float(match["error"]),
"type": match["type"]}
for line in block
if (match := REs.TDDFT_RE.match(line))]
return tddata
return [
{"energy": float(match["energy"]),
"error": float(match["error"]),
"type": match["type"]}
for line in block
if (match := REs.TDDFT_RE.match(line))
]


def _process_atreg_block(block: Block) -> AtomPropBlock:
accum: AtomPropBlock = {
return {
atreg_to_index(match): to_type(match.group("x", "y", "z"), float)
for line in block
if (match := REs.ATDAT3VEC.search(line))}
return accum
if (match := REs.ATDAT3VEC.search(line))
}



def _process_spec_prop(block: Block) -> list[list[str]]:
Expand All @@ -1429,16 +1431,18 @@ def _process_md_block(block: Block) -> MDInfo:


def _process_elf(block: Block) -> list[float]:
curr_data = [float(match.group(1)) for line in block
if (match := re.match(rf"\s+ELF\s+\d+\s+({REs.FNUMBER_RE})", line))]
return curr_data
return [
float(match.group(1)) for line in block
if (match := re.match(rf"\s+ELF\s+\d+\s+({REs.FNUMBER_RE})", line))
]


def _process_hirshfeld(block: Block) -> dict[AtomIndex, float]:
""" Process Hirshfeld block to dict of charges """
accum = {atreg_to_index(match): float(match["charge"]) for line in block
if (match := re.match(rf"\s+{REs.ATREG}\s+(?P<charge>{REs.FNUMBER_RE})", line))}
return accum
return {
atreg_to_index(match): float(match["charge"]) for line in block
if (match := re.match(rf"\s+{REs.ATREG}\s+(?P<charge>{REs.FNUMBER_RE})", line))
}


def _process_thermodynamics(block: Block) -> Thermodynamics:
Expand Down Expand Up @@ -1473,8 +1477,7 @@ def _process_atom_disp(block: Block) -> dict[str, dict[AtomIndex, SixVector]]:


def _process_3_6_matrix(
block: Block,
split: bool,
block: Block, *, split: bool,
) -> tuple[ThreeByThreeMatrix, ThreeByThreeMatrix | None]:
""" Process a single or pair of 3x3 matrices or 3x6 matrix """
parsed = tuple(to_type(vals, float) for line in block
Expand Down Expand Up @@ -1870,7 +1873,7 @@ def _process_phonon_sym_analysis(block: Block) -> PhononSymmetryReport:
return accum


def _process_kpoint_blocks(block: Block,
def _process_kpoint_blocks(block: Block, *,
implicit_kpoints: bool) -> KPointsList | KPointsSpec:

if implicit_kpoints:
Expand Down Expand Up @@ -1970,15 +1973,13 @@ def _process_dynamical_matrix(block: Block) -> tuple[tuple[complex, ...], ...]:
# Get remainder
imag_part = [numbers[2:] for line in block if (numbers := get_numbers(line))]

accum = tuple(
return tuple(
tuple(complex(float(real), float(imag)) for real, imag in zip(real_row, imag_row))
for real_row, imag_row in zip(real_part, imag_part)
)

return accum


def _process_pspot_string(string: str, debug=False) -> PSPotStrInfo:
def _process_pspot_string(string: str, *, debug=False) -> PSPotStrInfo:
if not (match := REs.PSPOT_RE.search(string)):
raise ValueError(f"Attempt to parse {string} as PSPot failed")

Expand Down
9 changes: 3 additions & 6 deletions castep_outputs/parsers/cell_param_file_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,13 @@ def _parse_sedc(block: Block) -> dict[str, dict[str, float]]:

def _parse_symops(block: Block) -> list[dict[str, ThreeByThreeMatrix | ThreeVector]]:

accum = []
tmp = [to_type(numbers, float)
for line in block
if (numbers := REs.FLOAT_RAT_RE.findall(line))]

accum = [{"r": tmp[i:i+3],
"t": tmp[i+3]}
for i in range(0, len(tmp), 4)]

return accum
return [{"r": tmp[i:i+3],
"t": tmp[i+3]}
for i in range(0, len(tmp), 4)]


def _parse_general(block: Block) -> dict[str, str | float]:
Expand Down
6 changes: 3 additions & 3 deletions castep_outputs/utilities/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def normalise_key(string: str) -> str:
return re.sub(r"[_\W]+", "_", string).strip("_").lower()


def atreg_to_index(dict_in: dict[str, str] | re.Match, clear: bool = True) -> tuple[str, int]:
def atreg_to_index(dict_in: dict[str, str] | re.Match, *, clear: bool = True) -> tuple[str, int]:
"""
Transform a matched atreg value to species index tuple
Also clear value from dictionary for easier processing
Expand Down Expand Up @@ -86,7 +86,7 @@ def json_safe(obj: Any) -> Any:
return obj


def flatten_dict(dictionary: MutableMapping[Any, Any],
def flatten_dict(dictionary: MutableMapping[Any, Any], *,
parent_key: bool = False,
separator: str = "_") -> dict[str, Any]:
"""
Expand Down Expand Up @@ -121,7 +121,7 @@ def stack_dict(out_dict: dict[Any, list], in_dict: dict[Any, list]) -> dict[Any,


def add_aliases(in_dict: dict[str, Any],
alias_dict: dict[str, str],
alias_dict: dict[str, str], *,
replace: bool = False,
inplace: bool = True) -> dict[str, Any]:
""" Adds aliases of known names into dictionary, if replace is true, remove original """
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ select = [
"ISC", # Flake8 implicit string concat
"RSE", # Flake8 raise
"FA", # Flake8 future
# "FBT", # Flake8 boolean trap
"FBT", # Flake8 boolean trap
"C4", # Flake8 comprehensions
"Q", # Flake8 Quotes
# "RET", # Flake8 return
"RET", # Flake8 return
"ARG", # Flake8 unused args
"PTH", # Flake8 use pathlib
"I", # Isort
Expand Down

0 comments on commit bad6325

Please sign in to comment.