Skip to content

Commit

Permalink
Merge pull request #336 from libAtoms/fit_error_parity_sign
Browse files Browse the repository at this point in the history
Try to fix sign of scalar diffs and parity quantities in error analysis
  • Loading branch information
bernstei authored Aug 29, 2024
2 parents c59c86d + 71424a4 commit 113f8f7
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions wfl/fit/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def calc(inputs, calc_property_prefix, ref_property_prefix,
Returns
-------
errors: dict of RMSE and MAE for each category and property
diffs: dict with list of differences for each category and property
diffs: dict with list of differences for each category and property (signed for scalar
properties, norms for vectors)
parity: dict with "ref" and "calc" keys, each containing list of property values for
each category and property, for parity plots
"""
Expand Down Expand Up @@ -188,13 +189,13 @@ def _reshape_normalize(quant, prop, atoms, per_atom):

if len(diff.shape) != 2:
raise RuntimeError(f"Should never have diff.shape={diff.shape} with dim != 2 (prop {prop + atom_split_index_label})")
# compute norm along vector components
diff = np.linalg.norm(diff, axis=1)
if not per_component:
if diff.shape[1] > 1:
# compute norm along vector components
diff = np.linalg.norm(diff, axis=1)
if not per_component and selected_ref_quant.shape[1] > 1:
selected_ref_quant = np.linalg.norm(selected_ref_quant, axis=1)
selected_calc_quant = np.linalg.norm(selected_calc_quant, axis=1)


_dict_add([all_diffs, all_weights, all_parity["ref"], all_parity["calc"]],
[diff, _promote(weight, diff), selected_ref_quant, selected_calc_quant],
at_category, prop + atom_split_index_label)
Expand Down Expand Up @@ -380,16 +381,15 @@ def select_units(prop, plt_type, units_dict=None):
"energy/atom": {"parity": ("eV/at", 1.0), "error": ("meV/at", 1.0e3)},
"forces": {"parity": ("eV/Å", 1.0), "error": ("meV/Å", 1.0e3)},
"virial": {"parity": ("eV", 1.0), "error": ("meV", 1.0e3)},
"virial/atom": {"parity": ("eV/at", 1.0), "error": ("meV/at", 1.0e3)}
"virial/atom": {"parity": ("eV/at", 1.0), "error": ("meV/at", 1.0e3)},
"stress": {"parity": ("GPa", 1.0), "error": ("MPa", 1.0e3)},
}
if units_dict is None:
units_dict = {}
use_units_dict.update(units_dict)

if "virial" in prop:
prop = re.sub(r"/comp\b", "", prop)
prop = re.sub(r"/comp\b", "", prop)
if "forces" in prop:
prop = re.sub(r"/comp\b", "", prop)
prop = re.sub(r"/Z_\d+\b", "", prop)

if "energy" in prop:
Expand Down

0 comments on commit 113f8f7

Please sign in to comment.