Skip to content

Commit c5f284d

Browse files
Split tests into components
1 parent 938a366 commit c5f284d

File tree

8 files changed

+1132
-1090
lines changed

8 files changed

+1132
-1090
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import pytest
2+
3+
import numpy as np
4+
import scipp as sc
5+
6+
from scipy.integrate import simpson
7+
8+
from easydynamics.sample_model import DampedHarmonicOscillator
9+
10+
from easyscience.variable import Parameter
11+
12+
13+
class TestDampedHarmonicOscillator:
14+
@pytest.fixture
15+
def dho(self):
16+
return DampedHarmonicOscillator(
17+
name="TestDHO", area=2.0, center=1.5, width=0.3, unit="meV"
18+
)
19+
20+
def test_initialization(self, dho: DampedHarmonicOscillator):
21+
assert dho.name == "TestDHO"
22+
assert dho.area.value == 2.0
23+
assert dho.center.value == 1.5
24+
assert dho.width.value == 0.3
25+
assert dho.unit == "meV"
26+
27+
def test_input_type_validation_raises(self):
28+
with pytest.raises(TypeError, match="area must be a number or a Parameter"):
29+
DampedHarmonicOscillator(
30+
name="TestDampedHarmonicOscillator",
31+
area="invalid",
32+
center=0.5,
33+
width=0.6,
34+
unit="meV",
35+
)
36+
37+
with pytest.raises(TypeError, match="center must be a number or a Parameter"):
38+
DampedHarmonicOscillator(
39+
name="TestDampedHarmonicOscillator",
40+
area=2.0,
41+
center="invalid",
42+
width=0.6,
43+
unit="meV",
44+
)
45+
46+
with pytest.raises(TypeError, match="width must be a number or a Parameter"):
47+
DampedHarmonicOscillator(
48+
name="TestDampedHarmonicOscillator",
49+
area=2.0,
50+
center=0.5,
51+
width="invalid",
52+
unit="meV",
53+
)
54+
55+
with pytest.raises(TypeError, match="unit must be a string"):
56+
DampedHarmonicOscillator(
57+
name="TestDampedHarmonicOscillator",
58+
area=2.0,
59+
center=0.5,
60+
width=0.6,
61+
unit=123,
62+
)
63+
64+
def test_negative_width_raises(self):
65+
with pytest.raises(
66+
ValueError,
67+
match="The width of a DampedHarmonicOscillator must be greater than zero.",
68+
):
69+
DampedHarmonicOscillator(
70+
name="TestDampedHarmonicOscillator",
71+
area=2.0,
72+
center=0.5,
73+
width=-0.6,
74+
unit="meV",
75+
)
76+
77+
def test_negative_width_raises_in_evaluate(self):
78+
test_dho = DampedHarmonicOscillator(
79+
name="TestDampedHarmonicOscillator",
80+
area=2.0,
81+
center=0.5,
82+
width=0.6,
83+
unit="meV",
84+
)
85+
test_dho.width.value = -0.6
86+
with pytest.raises(
87+
ValueError,
88+
match="The width of a DampedHarmonicOscillator must be greater than zero.",
89+
):
90+
test_dho.evaluate(np.array([0.0, 1.5, 3.0]))
91+
92+
def test_negative_area_warns(self):
93+
with pytest.warns(UserWarning, match="may not be physically meaningful"):
94+
DampedHarmonicOscillator(
95+
name="TestDampedHarmonicOscillator",
96+
area=-2.0,
97+
center=0.5,
98+
width=0.6,
99+
unit="meV",
100+
)
101+
102+
def test_negative_area_warns_in_evaluate(self):
103+
test_dho = DampedHarmonicOscillator(
104+
name="TestDampedHarmonicOscillator",
105+
area=2.0,
106+
center=0.5,
107+
width=0.6,
108+
unit="meV",
109+
)
110+
test_dho.area.value = -2.0
111+
with pytest.warns(UserWarning, match="may not be physically meaningful"):
112+
test_dho.evaluate(np.array([0.0, 1.5, 3.0]))
113+
114+
def test_evaluate(self, dho: DampedHarmonicOscillator):
115+
x = np.array([0.0, 1.5, 3.0])
116+
expected = dho.evaluate(x)
117+
expected_result = (
118+
2
119+
* 2.0
120+
* (1.5**2)
121+
* (0.3)
122+
/ np.pi
123+
/ ((x**2 - 1.5**2) ** 2 + (2 * 0.3 * x) ** 2)
124+
)
125+
np.testing.assert_allclose(expected, expected_result, rtol=1e-5)
126+
127+
def test_evaluate_scipp_array(self, dho: DampedHarmonicOscillator):
128+
x = sc.array(dims=["x"], values=[0.0, 1.5, 3.0], unit="meV")
129+
expected = dho.evaluate(x)
130+
expected_result = (
131+
2
132+
* 2.0
133+
* (1.5**2)
134+
* (0.3)
135+
/ np.pi
136+
/ ((x.values**2 - 1.5**2) ** 2 + (2 * 0.3 * x.values) ** 2)
137+
)
138+
np.testing.assert_allclose(expected, expected_result, rtol=1e-5)
139+
140+
def test_evaluate_with_different_unit(self, dho: DampedHarmonicOscillator):
141+
x = sc.array(dims=["x"], values=[0.0, 500.0, 1000.0], unit="microeV")
142+
expected = dho.evaluate(x)
143+
expected_result = (
144+
2
145+
* 2.0
146+
* 1e3
147+
* ((1.5 * 1e3) ** 2)
148+
* (0.3 * 1e3)
149+
/ np.pi
150+
/ ((x.values**2 - (1.5 * 1e3) ** 2) ** 2 + (2 * 0.3 * 1e3 * x.values) ** 2)
151+
)
152+
np.testing.assert_allclose(expected, expected_result, rtol=1e-5)
153+
154+
def test_input_as_parameter(self):
155+
param_area = Parameter(name="area_param", value=2.0, unit="meV")
156+
param_center = Parameter(name="center_param", value=0.5, unit="meV")
157+
param_width = Parameter(name="width_param", value=0.6, unit="meV")
158+
test_dho = DampedHarmonicOscillator(
159+
name="TestDHO",
160+
area=param_area,
161+
center=param_center,
162+
width=param_width,
163+
unit="meV",
164+
)
165+
assert test_dho.area == param_area
166+
assert test_dho.center == param_center
167+
assert test_dho.width == param_width
168+
169+
def test_get_parameters(self, dho: DampedHarmonicOscillator):
170+
params = dho.get_parameters()
171+
assert len(params) == 3
172+
assert params[0].name == "TestDHO area"
173+
assert params[1].name == "TestDHO center"
174+
assert params[2].name == "TestDHO width"
175+
assert all(isinstance(param, Parameter) for param in params)
176+
177+
def test_area_matches_parameter(self, dho: DampedHarmonicOscillator):
178+
# WHEN
179+
x = np.linspace(
180+
-dho.center.value - 20 * dho.width.value,
181+
dho.center.value + 20 * dho.width.value,
182+
5000,
183+
)
184+
y = dho.evaluate(x)
185+
numerical_area = simpson(y, x)
186+
187+
# THEN EXPECT
188+
assert numerical_area == pytest.approx(dho.area.value, rel=2e-3)
189+
190+
def test_convert_unit(self, dho: DampedHarmonicOscillator):
191+
dho.convert_unit("microeV")
192+
193+
assert dho.unit == "microeV"
194+
assert dho.area.value == 2 * 1e3
195+
assert dho.center.value == 1.5 * 1e3
196+
assert dho.width.value == 0.3 * 1e3
197+
198+
def test_copy(self, dho: DampedHarmonicOscillator):
199+
dho_copy = dho.copy()
200+
assert dho_copy is not dho
201+
assert dho_copy.name == dho.name
202+
203+
assert dho_copy.area.value == dho.area.value
204+
assert dho_copy.area.fixed == dho.area.fixed
205+
206+
assert dho_copy.center.value == dho.center.value
207+
assert dho_copy.center.fixed == dho.center.fixed
208+
209+
assert dho_copy.width.value == dho.width.value
210+
assert dho_copy.width.fixed == dho.width.fixed
211+
212+
assert dho_copy.unit == dho.unit
213+
214+
def test_repr(self, dho: DampedHarmonicOscillator):
215+
repr_str = repr(dho)
216+
assert "DampedHarmonicOscillator" in repr_str
217+
assert "name = TestDHO" in repr_str
218+
assert "unit = meV" in repr_str
219+
assert "area =" in repr_str
220+
assert "center =" in repr_str
221+
assert "width =" in repr_str
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import pytest
2+
3+
import numpy as np
4+
5+
from easydynamics.sample_model import DeltaFunction
6+
from easyscience.variable import Parameter
7+
8+
9+
class TestDeltaFunction:
10+
@pytest.fixture
11+
def delta_function(self):
12+
return DeltaFunction(name="TestDeltaFunction", area=2.0, center=0.5, unit="meV")
13+
14+
def test_initialization(self, delta_function: DeltaFunction):
15+
assert delta_function.name == "TestDeltaFunction"
16+
assert delta_function.area.value == 2.0
17+
assert delta_function.center.value == 0.5
18+
assert delta_function.unit == "meV"
19+
20+
def test_input_type_validation_raises(self):
21+
with pytest.raises(TypeError, match="area must be a number or a Parameter"):
22+
DeltaFunction(
23+
name="TestDeltaFunction",
24+
area="invalid",
25+
center=0.5,
26+
unit="meV",
27+
)
28+
with pytest.raises(
29+
TypeError, match="center must be None, a number or a Parameter"
30+
):
31+
DeltaFunction(
32+
name="TestDeltaFunction",
33+
area=2.0,
34+
center="invalid",
35+
unit="meV",
36+
)
37+
with pytest.raises(TypeError, match="unit must be a string"):
38+
DeltaFunction(name="TestDeltaFunction", area=2.0, center=0.5, unit=123)
39+
40+
def test_negative_area_warns(self):
41+
with pytest.warns(UserWarning, match="may not be physically meaningful"):
42+
DeltaFunction(name="TestDeltaFunction", area=-2.0, center=0.5, unit="meV")
43+
44+
@pytest.mark.xfail(
45+
reason="DeltaFunction.evaluate is not implemented yet without resolution convolution"
46+
)
47+
def test_evaluate(self, delta_function: DeltaFunction):
48+
x = np.array([0.0, 0.5, 1.0])
49+
expected = delta_function.evaluate(x)
50+
expected_result = np.zeros_like(x)
51+
# expected_result[x == 0.5] = 2.0
52+
np.testing.assert_allclose(expected, expected_result, rtol=1e-5)
53+
54+
def test_center_is_fixed_if_set_to_None(self):
55+
test_delta = DeltaFunction(
56+
name="TestDeltaFunction", area=2.0, center=None, unit="meV"
57+
)
58+
assert test_delta.center.value == 0.0
59+
assert test_delta.center.fixed is True
60+
61+
def test_input_as_parameter(self):
62+
param_area = Parameter(name="area_param", value=2.0, unit="meV")
63+
param_center = Parameter(name="center_param", value=0.5, unit="meV")
64+
test_delta = DeltaFunction(
65+
name="TestDeltaFunction", area=param_area, center=param_center, unit="meV"
66+
)
67+
assert test_delta.area == param_area
68+
assert test_delta.center == param_center
69+
70+
def test_get_parameters(self, delta_function: DeltaFunction):
71+
params = delta_function.get_parameters()
72+
assert len(params) == 2
73+
assert params[0].name == "TestDeltaFunction area"
74+
assert params[1].name == "TestDeltaFunction center"
75+
assert all(isinstance(param, Parameter) for param in params)
76+
77+
def test_convert_unit(self, delta_function: DeltaFunction):
78+
delta_function.convert_unit("microeV")
79+
80+
assert delta_function.unit == "microeV"
81+
assert delta_function.area.value == 2 * 1e3
82+
assert delta_function.center.value == 0.5 * 1e3
83+
84+
def test_copy(self, delta_function: DeltaFunction):
85+
delta_copy = delta_function.copy()
86+
assert delta_copy is not delta_function
87+
assert delta_copy.name == delta_function.name
88+
89+
assert delta_copy.area.value == delta_function.area.value
90+
assert delta_copy.area.fixed == delta_function.area.fixed
91+
92+
assert delta_copy.center.value == delta_function.center.value
93+
assert delta_copy.center.fixed == delta_function.center.fixed
94+
95+
assert delta_copy.unit == delta_function.unit
96+
97+
def test_repr(self, delta_function: DeltaFunction):
98+
repr_str = repr(delta_function)
99+
assert "DeltaFunction" in repr_str
100+
assert "name = TestDeltaFunction" in repr_str
101+
assert "unit = meV" in repr_str
102+
assert "area =" in repr_str
103+
assert "center =" in repr_str

0 commit comments

Comments
 (0)