Skip to content

Commit fd3eff6

Browse files
make tests nicer
1 parent 6c59a05 commit fd3eff6

File tree

2 files changed

+91
-111
lines changed

2 files changed

+91
-111
lines changed

src/easydynamics/utils/convolution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def convolution(
5757
normalize_detailed_balance : bool, optional
5858
Whether to normalize the detailed balance factor. Default is True.
5959
"""
60+
61+
# Input validation
6062
if not isinstance(x, np.ndarray):
6163
raise TypeError(
6264
f"`x` is an instance of {type(x).__name__}, but must be a numpy array."

tests/unit_tests/utils/test_convolution.py

Lines changed: 89 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ def resolution_model(self):
3434
test_resolution_model.add_component(Gaussian(center=0.2, width=0.3, area=3.0))
3535
return test_resolution_model
3636

37+
@pytest.fixture
38+
def gaussian_component(self):
39+
return Gaussian(center=0.1, width=0.3, area=2.0)
40+
41+
@pytest.fixture
42+
def other_gaussian_component(self):
43+
return Gaussian(center=0.2, width=0.4, area=3.0)
44+
45+
@pytest.fixture
46+
def lorentzian_component(self):
47+
return Lorentzian(center=0.1, width=0.3, area=2.0)
48+
49+
@pytest.fixture
50+
def other_lorentzian_component(self):
51+
return Lorentzian(center=0.2, width=0.4, area=3.0)
52+
3753
@pytest.fixture
3854
def x(self):
3955
return np.linspace(-50, 50, 50001)
@@ -51,12 +67,20 @@ def x(self):
5167
@pytest.mark.parametrize(
5268
"method", ["analytical", "numerical"], ids=["analytical", "numerical"]
5369
)
54-
def test_components_gauss_gauss(self, x, offset_obj, expected_shift, method):
70+
def test_components_gauss_gauss(
71+
self,
72+
x,
73+
gaussian_component,
74+
other_gaussian_component,
75+
offset_obj,
76+
expected_shift,
77+
method,
78+
):
5579
"Test convolution of Gaussian sample and Gaussian resolution components without SampleModel."
5680
"Test with different offset types and methods."
5781
# WHEN
58-
sample_gauss = Gaussian(center=0.1, width=0.3, area=2)
59-
resolution_gauss = Gaussian(center=0.2, width=0.4, area=3)
82+
sample_gauss = gaussian_component
83+
resolution_gauss = other_gaussian_component
6084

6185
# THEN
6286
calculated_convolution = convolution(
@@ -68,6 +92,7 @@ def test_components_gauss_gauss(self, x, offset_obj, expected_shift, method):
6892
)
6993

7094
# EXPECT
95+
# Convolution of two Gaussians is another Gaussian with width = sqrt(w1^2 + w2^2)
7196
expected_width = np.sqrt(
7297
sample_gauss.width.value**2 + resolution_gauss.width.value**2
7398
)
@@ -95,12 +120,14 @@ def test_components_gauss_gauss(self, x, offset_obj, expected_shift, method):
95120
@pytest.mark.parametrize(
96121
"method", ["analytical", "numerical"], ids=["analytical", "numerical"]
97122
)
98-
def test_components_DHO_gauss(self, x, offset_obj, expected_shift, method):
123+
def test_components_DHO_gauss(
124+
self, x, gaussian_component, offset_obj, expected_shift, method
125+
):
99126
"Test convolution of DHO sample and Gaussian resolution components without SampleModel."
100127
"Test with different offset types and methods."
101128
# WHEN
102129
sample_dho = DampedHarmonicOscillator(center=1.5, width=0.3, area=2)
103-
resolution_gauss = Gaussian(center=0.2, width=0.4, area=3)
130+
resolution_gauss = gaussian_component
104131

105132
# THEN
106133
calculated_convolution = convolution(
@@ -112,6 +139,7 @@ def test_components_DHO_gauss(self, x, offset_obj, expected_shift, method):
112139
)
113140

114141
# EXPECT
142+
# no simple analytical form, so compute expected result via direct convolution
115143
sample_values = sample_dho.evaluate(x - expected_shift)
116144
resolution_values = resolution_gauss.evaluate(x)
117145
expected_result = fftconvolve(sample_values, resolution_values, mode="same")
@@ -132,13 +160,19 @@ def test_components_DHO_gauss(self, x, offset_obj, expected_shift, method):
132160
"method", ["analytical", "numerical"], ids=["analytical", "numerical"]
133161
)
134162
def test_components_lorentzian_lorentzian(
135-
self, x, offset_obj, expected_shift, method
163+
self,
164+
x,
165+
lorentzian_component,
166+
other_lorentzian_component,
167+
offset_obj,
168+
expected_shift,
169+
method,
136170
):
137171
"Test convolution of Lorentzian sample and Lorentzian resolution components without SampleModel."
138172
"Test with different offset types and methods."
139173
# WHEN
140-
sample_lorentzian = Lorentzian(center=0.1, width=0.3, area=2)
141-
resolution_lorentzian = Lorentzian(center=0.2, width=0.4, area=3)
174+
sample_lorentzian = lorentzian_component
175+
resolution_lorentzian = other_lorentzian_component
142176

143177
# THEN
144178
calculated_convolution = convolution(
@@ -151,6 +185,7 @@ def test_components_lorentzian_lorentzian(
151185
)
152186

153187
# EXPECT
188+
# Convolution of two Lorentzians is another Lorentzian with width = w1 + w2
154189
expected_width = (
155190
sample_lorentzian.width.value + resolution_lorentzian.width.value
156191
)
@@ -186,83 +221,56 @@ def test_components_lorentzian_lorentzian(
186221
@pytest.mark.parametrize(
187222
"method", ["analytical", "numerical"], ids=["analytical", "numerical"]
188223
)
189-
def test_components_gauss_lorentzian(self, x, offset_obj, expected_shift, method):
190-
"Test convolution of Gaussian sample and Lorentzian resolution components without SampleModel."
224+
@pytest.mark.parametrize(
225+
"sample_is_gauss",
226+
[True, False],
227+
ids=["gauss_sample__lorentz_resolution", "lorentz_sample__gauss_resolution"],
228+
)
229+
def test_components_gauss_lorentzian(
230+
self,
231+
x,
232+
gaussian_component,
233+
lorentzian_component,
234+
offset_obj,
235+
expected_shift,
236+
method,
237+
sample_is_gauss,
238+
):
239+
"Test convolution of Gaussian and Lorentzian components without SampleModel."
191240
"Test with different offset types and methods."
192241
# WHEN
193-
sample_gauss = Gaussian(center=0.1, width=0.3, area=2)
194-
resolution_lorentzian = Lorentzian(center=0.2, width=0.4, area=3)
242+
if sample_is_gauss:
243+
sample = gaussian_component
244+
resolution = lorentzian_component
245+
else:
246+
sample = lorentzian_component
247+
resolution = gaussian_component
195248

196249
# THEN
197250
calculated_convolution = convolution(
198251
x=x,
199-
sample_model=sample_gauss,
200-
resolution_model=resolution_lorentzian,
252+
sample_model=sample,
253+
resolution_model=resolution,
201254
offset=offset_obj,
202255
method=method,
203256
upsample_factor=5,
204257
)
205258

206259
# EXPECT
207-
expected_center = (
208-
sample_gauss.center.value
209-
+ resolution_lorentzian.center.value
210-
+ expected_shift
211-
)
212-
expected_area = sample_gauss.area.value * resolution_lorentzian.area.value
213-
expected_result = expected_area * voigt_profile(
214-
x - expected_center,
215-
sample_gauss.width.value,
216-
resolution_lorentzian.width.value,
217-
)
260+
expected_center = sample.center.value + resolution.center.value + expected_shift
261+
expected_area = sample.area.value * resolution.area.value
218262

219-
np.testing.assert_allclose(
220-
calculated_convolution,
221-
expected_result,
222-
atol=NUMERICAL_CONVOLUTION_ABSOLUTE_TOLERANCE,
223-
rtol=NUMERICAL_CONVOLUTION_RELATIVE_TOLERANCE,
263+
gaussian_width = (
264+
sample.width.value if sample_is_gauss else resolution.width.value
224265
)
225-
226-
@pytest.mark.parametrize(
227-
"offset_obj, expected_shift",
228-
[
229-
(None, 0.0),
230-
(0.4, 0.4),
231-
(Parameter("off", 0.4), 0.4),
232-
],
233-
ids=["none", "float", "parameter"],
234-
)
235-
@pytest.mark.parametrize(
236-
"method", ["analytical", "numerical"], ids=["analytical", "numerical"]
237-
)
238-
def test_components_lorentzian_gauss(self, x, offset_obj, expected_shift, method):
239-
"Test convolution of Lorentzian sample and Gaussian resolution components without SampleModel."
240-
"Test with different offset types and methods."
241-
# WHEN
242-
resolution_gauss = Gaussian(center=0.1, width=0.3, area=2)
243-
sample_lorentzian = Lorentzian(center=0.2, width=0.4, area=3)
244-
245-
# THEN
246-
calculated_convolution = convolution(
247-
x=x,
248-
sample_model=sample_lorentzian,
249-
resolution_model=resolution_gauss,
250-
offset=offset_obj,
251-
method=method,
252-
upsample_factor=5,
266+
lorentzian_width = (
267+
resolution.width.value if sample_is_gauss else sample.width.value
253268
)
254269

255-
# EXPECT
256-
expected_center = (
257-
sample_lorentzian.center.value
258-
+ resolution_gauss.center.value
259-
+ expected_shift
260-
)
261-
expected_area = sample_lorentzian.area.value * resolution_gauss.area.value
262270
expected_result = expected_area * voigt_profile(
263271
x - expected_center,
264-
resolution_gauss.width.value,
265-
sample_lorentzian.width.value,
272+
gaussian_width,
273+
lorentzian_width,
266274
)
267275

268276
np.testing.assert_allclose(
@@ -284,12 +292,23 @@ def test_components_lorentzian_gauss(self, x, offset_obj, expected_shift, method
284292
@pytest.mark.parametrize(
285293
"method", ["analytical", "numerical"], ids=["analytical", "numerical"]
286294
)
287-
def test_components_delta_gauss(self, x, offset_obj, expected_shift, method):
295+
@pytest.mark.parametrize(
296+
"sample_is_gauss",
297+
[True, False],
298+
ids=["gauss_sample__delta_resolution", "delta_sample__gauss_resolution"],
299+
)
300+
def test_components_delta_gauss(
301+
self, x, gaussian_component, offset_obj, expected_shift, method, sample_is_gauss
302+
):
288303
"Test convolution of Delta function sample and Gaussian resolution components without SampleModel."
289304
"Test with different offset types and methods."
290305
# WHEN
291-
sample_delta = DeltaFunction(name="Delta", center=0.1, area=2)
292-
resolution_gauss = Gaussian(center=0.2, width=0.3, area=3)
306+
if sample_is_gauss:
307+
sample_delta = DeltaFunction(name="Delta", center=0.1, area=2)
308+
resolution_gauss = gaussian_component
309+
else:
310+
sample_delta = DeltaFunction(name="Delta", center=0.1, area=2)
311+
resolution_gauss = gaussian_component
293312

294313
# THEN
295314
calculated_convolution = convolution(
@@ -313,47 +332,6 @@ def test_components_delta_gauss(self, x, offset_obj, expected_shift, method):
313332

314333
np.testing.assert_allclose(calculated_convolution, expected_result, atol=1e-10)
315334

316-
@pytest.mark.parametrize(
317-
"offset_obj, expected_shift",
318-
[
319-
(None, 0.0),
320-
(0.4, 0.4),
321-
(Parameter("off", 0.4), 0.4),
322-
],
323-
ids=["none", "float", "parameter"],
324-
)
325-
@pytest.mark.parametrize(
326-
"method", ["analytical", "numerical"], ids=["analytical", "numerical"]
327-
)
328-
def test_components_gauss_delta(self, x, offset_obj, expected_shift, method):
329-
"Test convolution of Gaussian sample and Delta function resolution components without SampleModel."
330-
"Test with different offset types and methods."
331-
# WHEN
332-
sample_gauss = Gaussian(center=0.1, width=0.2, area=2)
333-
resolution_delta = DeltaFunction(name="Delta", center=0.2, area=3)
334-
335-
# THEN
336-
calculated_convolution = convolution(
337-
x=x,
338-
sample_model=sample_gauss,
339-
resolution_model=resolution_delta,
340-
offset=offset_obj,
341-
method=method,
342-
)
343-
344-
# EXPECT
345-
expected_center = (
346-
sample_gauss.center.value + resolution_delta.center.value + expected_shift
347-
)
348-
expected_area = sample_gauss.area.value * resolution_delta.area.value
349-
expected_result = (
350-
expected_area
351-
* np.exp(-0.5 * ((x - expected_center) / sample_gauss.width.value) ** 2)
352-
/ (np.sqrt(2 * np.pi) * sample_gauss.width.value)
353-
)
354-
355-
np.testing.assert_allclose(calculated_convolution, expected_result, atol=1e-10)
356-
357335
# Test convolution of SampleModel
358336
@pytest.mark.parametrize(
359337
"offset_obj, expected_shift",

0 commit comments

Comments
 (0)