-
Notifications
You must be signed in to change notification settings - Fork 875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Avoid using full equality (==
) to compare float, avoid assert_array_equal
compare float array
#4159
base: master
Are you sure you want to change the base?
Avoid using full equality (==
) to compare float, avoid assert_array_equal
compare float array
#4159
Conversation
"""Get vector projection (np.ndarray) of vector b (np.ndarray) | ||
onto vector a (np.ndarray). | ||
""" | ||
return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a)) | ||
return (np.dot(b, a) / np.dot(a, a)) * a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This new implementation is slightly more readable (personal taste) and gives ~4x speedup, reference (the following is a project to b):
Original Implementation Time: 420.86 ms
New Implementation Time: 101.28 ms
Test script (by GPT):
import numpy as np
from numpy.typing import NDArray
from time import perf_counter_ns
def _proj_original(b: NDArray, a: NDArray) -> NDArray:
return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))
def _proj_new(b: NDArray, a: NDArray) -> NDArray:
return (np.dot(b, a) / np.dot(a, a)) * a
def verify_projection():
a = np.random.rand(3)
b = np.random.rand(3)
proj1 = _proj_original(b, a)
proj2 = _proj_new(b, a)
assert np.allclose(proj1, proj2)
def benchmark_projections(n_iter=100000):
a = np.random.rand(3)
b = np.random.rand(3)
# Measure original implementation
start_time = perf_counter_ns()
for _ in range(n_iter):
_proj_original(b, a)
time_original = perf_counter_ns() - start_time
# Measure new implementation
start_time = perf_counter_ns()
for _ in range(n_iter):
_proj_new(b, a)
time_new = perf_counter_ns() - start_time
print(f"Original Implementation Time: {time_original / 1e6:.2f} ms")
print(f"New Implementation Time: {time_new / 1e6:.2f} ms")
verify_projection()
print("Benchmarking both implementations...")
benchmark_projections()
advanced_transformations
advanced_transformations
==
) to compare float
@@ -38,22 +38,22 @@ def setUp(self): | |||
self.site2_xanes = XAS.from_dict(site2_xanes_dict) | |||
|
|||
def test_e0(self): | |||
assert approx(self.k_xanes.e0) == 7728.565 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe a == pytest.approx(ref_val)
is the recommended usage plus it's slightly more readable. However do note approx
is asymmetric:
a == pytest.approx(b, rel=1e-6, abs=1e-12): True if the relative tolerance is met w.r.t. b or if the absolute tolerance is met. Because the relative tolerance is only calculated w.r.t. b, this test is asymmetric and you can think of b as the reference value. In the special case that you explicitly specify an absolute tolerance but not a relative tolerance, only the absolute tolerance is considered.
==
) to compare float ==
) to compare float, avoid assert_array_equal
compare float array
Summary
==
to compare float, to fix tests use equality to compare floating point numbers #4158assert_array_equal
on int array:_proj
implementation, ~3x speedup==
(list/tuple/dict ...):pymatgen/tests/core/test_bonds.py
Line 56 in bd9fba9