Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid using full equality (==) to compare float, avoid assert_array_equal compare float array #4159

Open
wants to merge 60 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
588ceb8
replace some float equality check
DanielYang59 Nov 9, 2024
0b97cb0
explicit encoding
DanielYang59 Nov 9, 2024
82f3431
charge is also float
DanielYang59 Nov 9, 2024
389c59b
enhance types
DanielYang59 Nov 9, 2024
1d22fee
access gcd via math namespace as math is already imported
DanielYang59 Nov 9, 2024
84e3b70
put dunder method to top
DanielYang59 Nov 9, 2024
ea6089e
fix typo
DanielYang59 Nov 9, 2024
e264890
tweak _proj implementation
DanielYang59 Nov 9, 2024
95a6192
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 9, 2024
e431882
support array like
DanielYang59 Nov 9, 2024
8f30f13
Merge branch '4158-fix-eq-check' of https://github.com/DanielYang59/p…
DanielYang59 Nov 9, 2024
e6ea809
add arg and return type
DanielYang59 Nov 9, 2024
bf0ff16
tweak type
DanielYang59 Nov 9, 2024
5c9992e
avoid more == for float comparison
DanielYang59 Nov 10, 2024
4920eb7
replace some == in test, more left to do
DanielYang59 Nov 10, 2024
f343503
replace more in core test
DanielYang59 Nov 10, 2024
808c495
replace more in test
DanielYang59 Nov 10, 2024
c0692dd
replace even more
DanielYang59 Nov 10, 2024
48e0ead
replace last batch
DanielYang59 Nov 10, 2024
cdff78d
clean up assert approx
DanielYang59 Nov 10, 2024
7eb7caa
replace pytest.approx with approx
DanielYang59 Nov 10, 2024
7745458
also fix membership check
DanielYang59 Nov 10, 2024
24edca5
replace some equality check of list
DanielYang59 Nov 10, 2024
089e3d2
replace some sequences
DanielYang59 Nov 10, 2024
20abc2a
fix test
DanielYang59 Nov 10, 2024
4871d1b
replace float comparison as dict
DanielYang59 Nov 10, 2024
7a8c148
fix test
DanielYang59 Nov 10, 2024
1137a72
replace more float compare, mostly for VASP
DanielYang59 Nov 10, 2024
88aad8b
fix test
DanielYang59 Nov 10, 2024
d12a07b
fix approx in condition block
DanielYang59 Nov 10, 2024
4552881
replace sci notation
DanielYang59 Nov 10, 2024
30e0f66
suppress buggy ruff sim300
DanielYang59 Nov 10, 2024
4f0ff82
number_of_permutations to int
DanielYang59 Nov 10, 2024
24e81d2
revert change for formula_double_format, in favor of another PR
DanielYang59 Nov 10, 2024
8ef27dc
c_indices seems to be int
DanielYang59 Nov 10, 2024
5ac947d
use sci notation for crazily large int
DanielYang59 Nov 10, 2024
2bde949
simplify numpy.testing usage
DanielYang59 Nov 10, 2024
16fa94d
set tol as pos arg
DanielYang59 Nov 10, 2024
300dc30
avoid array equal for list of str
DanielYang59 Nov 10, 2024
d4309b7
assert_array_equal should not be used on float array
DanielYang59 Nov 10, 2024
8cbfcfc
fix module level var name
DanielYang59 Nov 10, 2024
dbe8659
more assert_array_equal on complex number
DanielYang59 Nov 10, 2024
5ff4248
simplify approx on dict value
DanielYang59 Nov 10, 2024
1f01241
avoid module level var when it's used only 3 times
DanielYang59 Nov 10, 2024
32929d4
pytext.approx to approx
DanielYang59 Nov 10, 2024
16dbec3
fix approx on nested dict
DanielYang59 Nov 10, 2024
3df99ab
avoid unnecessary convert to np.array
DanielYang59 Nov 10, 2024
fd573cd
array_equal to all close for float array
DanielYang59 Nov 10, 2024
e46dbf9
assert all close for float array
DanielYang59 Nov 10, 2024
857581b
capital class attrib is treated as constant
DanielYang59 Nov 11, 2024
7787a25
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Nov 13, 2024
5700f3b
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Nov 14, 2024
79d3ffc
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 16, 2024
1724f9e
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Nov 16, 2024
e7e5209
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 18, 2024
626c3fb
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Nov 19, 2024
7c3b822
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Dec 11, 2024
95ec56a
Merge branch 'master' into 4158-fix-eq-check
DanielYang59 Jan 2, 2025
8284fdc
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Jan 10, 2025
17787db
Merge remote-tracking branch 'upstream/master' into 4158-fix-eq-check
DanielYang59 Jan 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/pymatgen/alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import abc
import math
from collections import defaultdict
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -285,7 +286,7 @@ def __init__(self):

def test(self, structure: Structure):
"""True if structure is neutral."""
return structure.charge == 0.0
return math.isclose(structure.charge, 0.0)


class SpeciesMaxDistFilter(AbstractStructureFilter):
Expand Down
5 changes: 4 additions & 1 deletion src/pymatgen/core/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def obtain_all_bond_lengths(
If None, a ValueError will be thrown.

Returns:
A dict mapping bond order to bond length in angstrom
dict[float, float]: mapping bond order to bond length in Angstrom.

Todo:
it's better to avoid using float as dict keys.
"""
if isinstance(sp1, Element):
sp1 = sp1.symbol
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def get_elt_projected_plots_color(
proj[b][str(spin)][band_idx][j][str(el)][o]
for o in proj[b][str(spin)][band_idx][j][str(el)]
)
if sum_e == 0.0:
if math.isclose(sum_e, 0.0):
color = [0.0] * len(elt_ordered)
else:
color = [
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/io/aims/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def get_content(
magmom = structure.site_properties.get("magmom", spins)
if (
parameters.get("spin", "") == "collinear"
and np.all(magmom == 0.0)
and np.allclose(magmom, 0.0)
and ("default_initial_moment" not in parameters)
):
warn(
Expand Down
10 changes: 8 additions & 2 deletions src/pymatgen/io/cp2k/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,10 +1320,16 @@ def parse_bandstructure(self, bandstructure_filename=None) -> None:
else:
eigenvals = {Spin.up: bands_data.reshape((nbands, nkpts))}

occ = bands_data[:, 1][bands_data[:, -1] != 0.0]
# Filter out occupied and unoccupied states
occupied_mask = ~np.isclose(bands_data[:, -1], 0.0)
unoccupied_mask = np.isclose(bands_data[:, -1], 0.0)

occ = bands_data[:, 1][occupied_mask]
homo = np.max(occ)
unocc = bands_data[:, 1][bands_data[:, -1] == 0.0]

unocc = bands_data[:, 1][unoccupied_mask]
lumo = np.min(unocc)

efermi = (lumo + homo) / 2
self.efermi = efermi

Expand Down
12 changes: 6 additions & 6 deletions src/pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,13 +783,13 @@ def run_type(self) -> str:
4: "dDsC",
}

if self.parameters.get("AEXX", 1.00) == 1.00:
if math.isclose(self.parameters.get("AEXX", 1.00), 1.00):
run_type = "HF"
elif self.parameters.get("HFSCREEN", 0.30) == 0.30:
elif math.isclose(self.parameters.get("HFSCREEN", 0.30), 0.30):
run_type = "HSE03"
elif self.parameters.get("HFSCREEN", 0.20) == 0.20:
elif math.isclose(self.parameters.get("HFSCREEN", 0.20), 0.20):
run_type = "HSE06"
elif self.parameters.get("AEXX", 0.20) == 0.20:
elif math.isclose(self.parameters.get("AEXX", 0.20), 0.20):
run_type = "B3LYP"
elif self.parameters.get("LHFCALC", True):
run_type = "PBEO or other Hybrid Functional"
Expand Down Expand Up @@ -1028,7 +1028,7 @@ def get_band_structure(
if (hybrid_band or force_hybrid_mode) and not use_kpoints_opt:
start_bs_index = 0
for i in range(len(self.actual_kpoints)):
if self.actual_kpoints_weights[i] == 0.0:
if math.isclose(self.actual_kpoints_weights[i], 0.0):
start_bs_index = i
break
for i in range(start_bs_index, len(kpoint_file.kpts)):
Expand Down Expand Up @@ -5299,7 +5299,7 @@ def get_parchg(
Returns:
A Chgcar object.
"""
if phase and not np.all(self.kpoints[kpoint] == 0.0):
if phase and not np.allclose(self.kpoints[kpoint], 0.0):
warnings.warn(
"phase is True should only be used for the Gamma kpoint! I hope you know what you're doing!",
stacklevel=2,
Expand Down
89 changes: 52 additions & 37 deletions src/pymatgen/transformations/advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Literal

from numpy.typing import NDArray


__author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose"

Expand All @@ -67,6 +69,9 @@ def __init__(self, charge_balance_sp):
"""
self.charge_balance_sp = str(charge_balance_sp)

def __repr__(self):
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"

def apply_transformation(self, structure: Structure):
"""Apply the transformation.

Expand All @@ -86,9 +91,6 @@ def apply_transformation(self, structure: Structure):
trans = SubstitutionTransformation({self.charge_balance_sp: {self.charge_balance_sp: 1 - removal_fraction}})
return trans.apply_transformation(structure)

def __repr__(self):
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"


class SuperTransformation(AbstractTransformation):
"""This is a transformation that is inherently one-to-many. It is constructed
Expand All @@ -110,6 +112,9 @@ def __init__(self, transformations, nstructures_per_trans=1):
self._transformations = transformations
self.nstructures_per_trans = nstructures_per_trans

def __repr__(self):
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.

Expand Down Expand Up @@ -139,11 +144,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
)
return structures

def __repr__(self):
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -191,6 +193,9 @@ def __init__(
self.charge_balance_species = charge_balance_species
self.order = order

def __repr__(self):
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.

Expand Down Expand Up @@ -233,11 +238,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
outputs.append({"structure": new_structure})
return outputs

def __repr__(self):
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -322,6 +324,9 @@ def __init__(
if max_cell_size and max_disordered_sites:
raise ValueError("Cannot set both max_cell_size and max_disordered_sites!")

def __repr__(self):
return "EnumerateStructureTransformation"

def apply_transformation(
self, structure: Structure, return_ranked_list: bool | int = False
) -> Structure | list[dict]:
Expand Down Expand Up @@ -468,11 +473,8 @@ def sort_func(struct):
return self._all_structures[:num_to_return]
return self._all_structures[0]["structure"]

def __repr__(self):
return "EnumerateStructureTransformation"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand All @@ -494,6 +496,9 @@ def __init__(self, threshold=1e-2, scale_volumes=True, **kwargs):
self.scale_volumes = scale_volumes
self._substitutor = SubstitutionPredictor(threshold=threshold, **kwargs)

def __repr__(self):
return "SubstitutionPredictorTransformation"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.

Expand Down Expand Up @@ -528,11 +533,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
outputs.append(output)
return outputs

def __repr__(self):
return "SubstitutionPredictorTransformation"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -895,7 +897,7 @@ def key(struct: Structure) -> int:
return self._all_structures[:num_to_return] # type: ignore[return-value]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -984,15 +986,19 @@ def __init__(
self.allowed_doping_species = allowed_doping_species
self.kwargs = kwargs

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
def apply_transformation(
self,
structure: Structure,
return_ranked_list: bool | int = False,
) -> list[dict[Literal["structure", "energy"], Structure | float]] | Structure:
"""
Args:
structure (Structure): Input structure to dope
return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures.
is returned. If False, only the single lowest energy structure is returned. Defaults to False.
structure (Structure): Input structure to dope.
return_ranked_list (bool | int, optional): If is int, that number of structures is returned.
If False, only the single lowest energy structure is returned. Defaults to False.

Returns:
list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}.
list[dict] | Structure: each dict as {"structure": Structure, "energy": float}.
"""
comp = structure.composition
logger.info(f"Composition: {comp}")
Expand Down Expand Up @@ -1125,7 +1131,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
return all_structures[0]["structure"]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -1253,7 +1259,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
return disordered_structures

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -1714,7 +1720,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
return [{"structure": structure} for structure in structures[:return_ranked_list]]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -1868,16 +1874,25 @@ def apply_transformation(
return [{"structure": structure} for structure in structures[:return_ranked_list]]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True


def _proj(b, a):
"""Get vector projection (np.ndarray) of vector b (np.ndarray)
onto vector a (np.ndarray).
def _proj(b: NDArray, a: NDArray) -> NDArray:
"""Get vector projection of vector b onto vector a.

Args:
b (NDArray): Vector to be projected.
a (NDArray): Vector onto which `b` is projected.

Returns:
NDArray: Projection of `b` onto `a`.
"""
return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))
a = np.asarray(a)
b = np.asarray(b)

return (np.dot(b, a) / np.dot(a, a)) * a
Copy link
Contributor Author

@DanielYang59 DanielYang59 Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new implementation is slightly more readable (personal taste) and gives ~4x speedup, reference (the following is a project to b):

image
Original Implementation Time: 420.86 ms
New Implementation Time: 101.28 ms

Test script (by GPT):

import numpy as np
from numpy.typing import NDArray
from time import perf_counter_ns


def _proj_original(b: NDArray, a: NDArray) -> NDArray:
    return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))

def _proj_new(b: NDArray, a: NDArray) -> NDArray:
    return (np.dot(b, a) / np.dot(a, a)) * a

def verify_projection():
    a = np.random.rand(3)
    b = np.random.rand(3)
    proj1 = _proj_original(b, a)
    proj2 = _proj_new(b, a)
    assert np.allclose(proj1, proj2)

def benchmark_projections(n_iter=100000):
    a = np.random.rand(3)
    b = np.random.rand(3)

    # Measure original implementation
    start_time = perf_counter_ns()
    for _ in range(n_iter):
        _proj_original(b, a)
    time_original = perf_counter_ns() - start_time

    # Measure new implementation
    start_time = perf_counter_ns()
    for _ in range(n_iter):
        _proj_new(b, a)
    time_new = perf_counter_ns() - start_time

    print(f"Original Implementation Time: {time_original / 1e6:.2f} ms")
    print(f"New Implementation Time: {time_new / 1e6:.2f} ms")

verify_projection()

print("Benchmarking both implementations...")
benchmark_projections()



class SQSTransformation(AbstractTransformation):
Expand Down Expand Up @@ -2146,7 +2161,7 @@ def _get_unique_best_sqs_structs(sqs, best_only, return_ranked_list, remove_dupl
return to_return

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -2195,6 +2210,9 @@ def __init__(self, rattle_std: float, min_distance: float, seed: int | None = No
self.random_state = np.random.RandomState(seed)
self.kwargs = kwargs

def __repr__(self):
return f"{__name__} : rattle_std = {self.rattle_std}"

def apply_transformation(self, structure: Structure) -> Structure:
"""Apply the transformation.

Expand All @@ -2216,6 +2234,3 @@ def apply_transformation(self, structure: Structure) -> Structure:
structure.cart_coords + displacements,
coords_are_cartesian=True,
)

def __repr__(self):
return f"{__name__} : rattle_std = {self.rattle_std}"
8 changes: 5 additions & 3 deletions src/pymatgen/transformations/transformation_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from monty.json import MSONable

if TYPE_CHECKING:
from typing import Any, Literal

from pymatgen.core import Structure

__author__ = "Shyue Ping Ong"
Expand Down Expand Up @@ -55,7 +57,7 @@ def inverse(self) -> AbstractTransformation | None:
"""

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[False]:
"""Determine if a Transformation is a one-to-many transformation. In that case, the
apply_transformation method should have a keyword arg "return_ranked_list" which
allows for the transformed structures to be returned as a ranked list.
Expand All @@ -64,7 +66,7 @@ def is_one_to_many(self) -> bool:
return False

@property
def use_multiprocessing(self) -> bool:
def use_multiprocessing(self) -> Literal[False]:
"""Indicates whether the transformation can be applied by a
subprocessing pool. This should be overridden to return True for
transformations that the transmuter can parallelize.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_coordination_geometry(self):
assert cg_oct.IUCr_symbol_str == "[6o]"

cg_oct.permutations_safe_override = True
assert cg_oct.number_of_permutations == 720.0
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
assert cg_oct.number_of_permutations == 720
assert cg_oct.ref_permutation([0, 3, 2, 4, 5, 1]) == (0, 3, 1, 5, 2, 4)

sites = [FakeSite(coords=pp) for pp in cg_oct.points]
Expand Down
Loading
Loading