Skip to content

Commit 20ce0fe

Browse files
Update tests
1 parent 82778b2 commit 20ce0fe

File tree

11 files changed

+153
-336
lines changed

11 files changed

+153
-336
lines changed

examples/component_example.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
"source": [
1010
"import numpy as np\n",
1111
"\n",
12-
"from easydynamics.sample import Gaussian\n",
13-
"from easydynamics.sample import Lorentzian\n",
14-
"from easydynamics.sample import DampedHarmonicOscillator\n",
15-
"from easydynamics.sample import Polynomial\n",
12+
"from easydynamics.sample_model import Gaussian\n",
13+
"from easydynamics.sample_model import Lorentzian\n",
14+
"from easydynamics.sample_model import DampedHarmonicOscillator\n",
15+
"from easydynamics.sample_model import Polynomial\n",
1616
"\n",
1717
"import matplotlib.pyplot as plt\n",
1818
"\n",

src/easydynamics/sample_model/components/damped_harmonic_oscillator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def evaluate(self, x: Union[Numeric, sc.Variable]) -> Union[float, np.ndarray]:
9292
self.convert_unit(x.unit.name)
9393
except Exception as e:
9494
raise UnitError(
95-
f"Input x has unit {x.unit}, but DHO component has unit {self._unit}. Failed to convert DHO to {x.unit}."
95+
f"Input x has unit {x.unit}, but DampedHarmonicOscillator component has unit {self._unit}. Failed to convert DampedHarmonicOscillator to {x.unit}."
9696
) from e
9797
warnings.warn(
98-
f"Input x has unit {x.unit}, but DHO component has unit {self._unit}. Converting DHO to {x.unit}."
98+
f"Input x has unit {x.unit}, but DampedHarmonicOscillator component has unit {self._unit}. Converting DampedHarmonicOscillator to {x.unit}."
9999
)
100100
else:
101101
x_in = x

src/easydynamics/sample_model/components/gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def evaluate(self, x: Union[Numeric, sc.Variable]) -> Union[float, np.ndarray]:
153153
if any(np.isinf(x_in)):
154154
raise ValueError("Input x contains infinite values.")
155155

156-
normalization = 1 / np.sqrt(2 * np.pi) * self._width.value
156+
normalization = 1 / (np.sqrt(2 * np.pi) * self._width.value)
157157
exponent = -0.5 * ((x_in - self._center.value) / self._width.value) ** 2
158158

159159
return self._area.value * normalization * np.exp(exponent)

src/easydynamics/sample_model/components/lorentzian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
)
5050

5151
if center is not None and not isinstance(center, Numeric):
52-
raise TypeError("center must be None, a number.")
52+
raise TypeError("center must be None or a number.")
5353

5454
if isinstance(center, Numeric):
5555
center = float(center)

src/easydynamics/sample_model/components/voigt.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def __init__(
3636
name: str = "Voigt",
3737
area: Numeric = 1.0,
3838
center: Union[Numeric, None] = None,
39-
_gaussian_width: Numeric = 1.0,
40-
_lorentzian_width: Numeric = 1.0,
39+
gaussian_width: Numeric = 1.0,
40+
lorentzian_width: Numeric = 1.0,
4141
unit: Union[str, sc.Unit] = "meV",
4242
):
4343
# Validate inputs
@@ -58,22 +58,20 @@ def __init__(
5858
if isinstance(center, Numeric):
5959
center = float(center)
6060

61-
if not isinstance(_gaussian_width, Numeric):
62-
raise TypeError("_gaussian_width must be a number.")
61+
if not isinstance(gaussian_width, Numeric):
62+
raise TypeError("gaussian_width must be a number.")
6363

64-
_gaussian_width = float(_gaussian_width)
65-
if _gaussian_width <= 0:
66-
raise ValueError(
67-
"The _gaussian_width of a Voigt must be greater than zero."
68-
)
64+
gaussian_width = float(gaussian_width)
65+
if gaussian_width <= 0:
66+
raise ValueError("The gaussian_width of a Voigt must be greater than zero.")
6967

70-
if not isinstance(_lorentzian_width, Numeric):
71-
raise TypeError("_lorentzian_width must be a number.")
68+
if not isinstance(lorentzian_width, Numeric):
69+
raise TypeError("lorentzian_width must be a number.")
7270

73-
_lorentzian_width = float(_lorentzian_width)
74-
if _lorentzian_width <= 0:
71+
lorentzian_width = float(lorentzian_width)
72+
if lorentzian_width <= 0:
7573
raise ValueError(
76-
"The _lorentzian_width of a Voigt must be greater than zero."
74+
"The lorentzian_width of a Voigt must be greater than zero."
7775
)
7876

7977
if not isinstance(unit, (str, sc.Unit)):
@@ -96,15 +94,15 @@ def __init__(
9694
self._center = Parameter(name=name + " center", value=center, unit=unit)
9795

9896
self._gaussian_width = Parameter(
99-
name=name + " _gaussian_width",
100-
value=_gaussian_width,
97+
name=name + " gaussian_width",
98+
value=gaussian_width,
10199
unit=unit,
102100
min=MINIMUM_WIDTH,
103101
)
104102

105103
self._lorentzian_width = Parameter(
106-
name=name + " _lorentzian_width",
107-
value=_lorentzian_width,
104+
name=name + " lorentzian_width",
105+
value=lorentzian_width,
108106
unit=unit,
109107
min=MINIMUM_WIDTH,
110108
)
@@ -174,14 +172,14 @@ def copy(self, name: Optional[str] = None) -> Voigt:
174172
name=name,
175173
area=self._area.value,
176174
center=self._center.value,
177-
_gaussian_width=self._gaussian_width.value,
178-
_lorentzian_width=self._lorentzian_width.value,
175+
gaussian_width=self._gaussian_width.value,
176+
lorentzian_width=self._lorentzian_width.value,
179177
unit=self._unit,
180178
)
181179
model_copy._area.fixed = self._area.fixed
182180
model_copy._center.fixed = self._center.fixed
183-
model_copy.__gaussian_width.fixed = self._gaussian_width.fixed
184-
model_copy.__lorentzian_width.fixed = self._lorentzian_width.fixed
181+
model_copy._gaussian_width.fixed = self._gaussian_width.fixed
182+
model_copy._lorentzian_width.fixed = self._lorentzian_width.fixed
185183

186184
return model_copy
187185

tests/unit_tests/sample_model/components/test_damped_harmonic_oscillator.py

Lines changed: 31 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import scipp as sc
5+
from scipp import UnitError
56

67
from scipy.integrate import simpson
78

@@ -18,14 +19,14 @@ def dho(self):
1819
)
1920

2021
def test_initialization(self, dho: DampedHarmonicOscillator):
21-
assert dho.name == "TestDHO"
22-
assert dho.area.value == 2.0
23-
assert dho.center.value == 1.5
24-
assert dho.width.value == 0.3
25-
assert dho.unit == "meV"
22+
assert dho._name == "TestDHO"
23+
assert dho._area.value == 2.0
24+
assert dho._center.value == 1.5
25+
assert dho._width.value == 0.3
26+
assert dho._unit == "meV"
2627

2728
def test_input_type_validation_raises(self):
28-
with pytest.raises(TypeError, match="area must be a number or a Parameter"):
29+
with pytest.raises(TypeError, match="area must be a number"):
2930
DampedHarmonicOscillator(
3031
name="TestDampedHarmonicOscillator",
3132
area="invalid",
@@ -34,7 +35,7 @@ def test_input_type_validation_raises(self):
3435
unit="meV",
3536
)
3637

37-
with pytest.raises(TypeError, match="center must be a number or a Parameter"):
38+
with pytest.raises(TypeError, match="center must be a number"):
3839
DampedHarmonicOscillator(
3940
name="TestDampedHarmonicOscillator",
4041
area=2.0,
@@ -43,7 +44,7 @@ def test_input_type_validation_raises(self):
4344
unit="meV",
4445
)
4546

46-
with pytest.raises(TypeError, match="width must be a number or a Parameter"):
47+
with pytest.raises(TypeError, match="width must be a number"):
4748
DampedHarmonicOscillator(
4849
name="TestDampedHarmonicOscillator",
4950
area=2.0,
@@ -74,21 +75,6 @@ def test_negative_width_raises(self):
7475
unit="meV",
7576
)
7677

77-
def test_negative_width_raises_in_evaluate(self):
78-
test_dho = DampedHarmonicOscillator(
79-
name="TestDampedHarmonicOscillator",
80-
area=2.0,
81-
center=0.5,
82-
width=0.6,
83-
unit="meV",
84-
)
85-
test_dho.width.value = -0.6
86-
with pytest.raises(
87-
ValueError,
88-
match="The width of a DampedHarmonicOscillator must be greater than zero.",
89-
):
90-
test_dho.evaluate(np.array([0.0, 1.5, 3.0]))
91-
9278
def test_negative_area_warns(self):
9379
with pytest.warns(UserWarning, match="may not be physically meaningful"):
9480
DampedHarmonicOscillator(
@@ -99,18 +85,6 @@ def test_negative_area_warns(self):
9985
unit="meV",
10086
)
10187

102-
def test_negative_area_warns_in_evaluate(self):
103-
test_dho = DampedHarmonicOscillator(
104-
name="TestDampedHarmonicOscillator",
105-
area=2.0,
106-
center=0.5,
107-
width=0.6,
108-
unit="meV",
109-
)
110-
test_dho.area.value = -2.0
111-
with pytest.warns(UserWarning, match="may not be physically meaningful"):
112-
test_dho.evaluate(np.array([0.0, 1.5, 3.0]))
113-
11488
def test_evaluate(self, dho: DampedHarmonicOscillator):
11589
x = np.array([0.0, 1.5, 3.0])
11690
expected = dho.evaluate(x)
@@ -151,20 +125,13 @@ def test_evaluate_with_different_unit(self, dho: DampedHarmonicOscillator):
151125
)
152126
np.testing.assert_allclose(expected, expected_result, rtol=1e-5)
153127

154-
def test_input_as_parameter(self):
155-
param_area = Parameter(name="area_param", value=2.0, unit="meV")
156-
param_center = Parameter(name="center_param", value=0.5, unit="meV")
157-
param_width = Parameter(name="width_param", value=0.6, unit="meV")
158-
test_dho = DampedHarmonicOscillator(
159-
name="TestDHO",
160-
area=param_area,
161-
center=param_center,
162-
width=param_width,
163-
unit="meV",
164-
)
165-
assert test_dho.area == param_area
166-
assert test_dho.center == param_center
167-
assert test_dho.width == param_width
128+
def test_evaluate_with_incompatible_unit(self, dho: DampedHarmonicOscillator):
129+
x = sc.array(dims=["x"], values=[0.0, 500.0, 1000.0], unit="nm")
130+
with pytest.raises(
131+
UnitError,
132+
match="Input x has unit nm, but DampedHarmonicOscillator component has unit meV. Failed to convert DampedHarmonicOscillator to nm.",
133+
):
134+
dho.evaluate(x)
168135

169136
def test_get_parameters(self, dho: DampedHarmonicOscillator):
170137
params = dho.get_parameters()
@@ -177,39 +144,39 @@ def test_get_parameters(self, dho: DampedHarmonicOscillator):
177144
def test_area_matches_parameter(self, dho: DampedHarmonicOscillator):
178145
# WHEN
179146
x = np.linspace(
180-
-dho.center.value - 20 * dho.width.value,
181-
dho.center.value + 20 * dho.width.value,
147+
-dho._center.value - 20 * dho._width.value,
148+
dho._center.value + 20 * dho._width.value,
182149
5000,
183150
)
184151
y = dho.evaluate(x)
185152
numerical_area = simpson(y, x)
186153

187154
# THEN EXPECT
188-
assert numerical_area == pytest.approx(dho.area.value, rel=2e-3)
155+
assert numerical_area == pytest.approx(dho._area.value, rel=2e-3)
189156

190157
def test_convert_unit(self, dho: DampedHarmonicOscillator):
191158
dho.convert_unit("microeV")
192159

193-
assert dho.unit == "microeV"
194-
assert dho.area.value == 2 * 1e3
195-
assert dho.center.value == 1.5 * 1e3
196-
assert dho.width.value == 0.3 * 1e3
160+
assert dho._unit == "microeV"
161+
assert dho._area.value == 2 * 1e3
162+
assert dho._center.value == 1.5 * 1e3
163+
assert dho._width.value == 0.3 * 1e3
197164

198165
def test_copy(self, dho: DampedHarmonicOscillator):
199166
dho_copy = dho.copy()
200167
assert dho_copy is not dho
201-
assert dho_copy.name == dho.name
168+
assert dho_copy.name == "copy of " + dho._name
202169

203-
assert dho_copy.area.value == dho.area.value
204-
assert dho_copy.area.fixed == dho.area.fixed
170+
assert dho_copy._area.value == dho._area.value
171+
assert dho_copy._area.fixed == dho._area.fixed
205172

206-
assert dho_copy.center.value == dho.center.value
207-
assert dho_copy.center.fixed == dho.center.fixed
173+
assert dho_copy._center.value == dho._center.value
174+
assert dho_copy._center.fixed == dho._center.fixed
208175

209-
assert dho_copy.width.value == dho.width.value
210-
assert dho_copy.width.fixed == dho.width.fixed
176+
assert dho_copy._width.value == dho._width.value
177+
assert dho_copy._width.fixed == dho._width.fixed
211178

212-
assert dho_copy.unit == dho.unit
179+
assert dho_copy._unit == dho._unit
213180

214181
def test_repr(self, dho: DampedHarmonicOscillator):
215182
repr_str = repr(dho)

tests/unit_tests/sample_model/components/test_delta_function.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import numpy as np
4+
from scipp import UnitError
45

56
from easydynamics.sample_model import DeltaFunction
67
from easyscience.variable import Parameter
@@ -13,21 +14,19 @@ def delta_function(self):
1314

1415
def test_initialization(self, delta_function: DeltaFunction):
1516
assert delta_function.name == "TestDeltaFunction"
16-
assert delta_function.area.value == 2.0
17-
assert delta_function.center.value == 0.5
17+
assert delta_function._area.value == 2.0
18+
assert delta_function._center.value == 0.5
1819
assert delta_function.unit == "meV"
1920

2021
def test_input_type_validation_raises(self):
21-
with pytest.raises(TypeError, match="area must be a number or a Parameter"):
22+
with pytest.raises(TypeError, match="area must be a number"):
2223
DeltaFunction(
2324
name="TestDeltaFunction",
2425
area="invalid",
2526
center=0.5,
2627
unit="meV",
2728
)
28-
with pytest.raises(
29-
TypeError, match="center must be None, a number or a Parameter"
30-
):
29+
with pytest.raises(TypeError, match="center must be None or a number"):
3130
DeltaFunction(
3231
name="TestDeltaFunction",
3332
area=2.0,
@@ -55,17 +54,8 @@ def test_center_is_fixed_if_set_to_None(self):
5554
test_delta = DeltaFunction(
5655
name="TestDeltaFunction", area=2.0, center=None, unit="meV"
5756
)
58-
assert test_delta.center.value == 0.0
59-
assert test_delta.center.fixed is True
60-
61-
def test_input_as_parameter(self):
62-
param_area = Parameter(name="area_param", value=2.0, unit="meV")
63-
param_center = Parameter(name="center_param", value=0.5, unit="meV")
64-
test_delta = DeltaFunction(
65-
name="TestDeltaFunction", area=param_area, center=param_center, unit="meV"
66-
)
67-
assert test_delta.area == param_area
68-
assert test_delta.center == param_center
57+
assert test_delta._center.value == 0.0
58+
assert test_delta._center.fixed is True
6959

7060
def test_get_parameters(self, delta_function: DeltaFunction):
7161
params = delta_function.get_parameters()
@@ -78,19 +68,19 @@ def test_convert_unit(self, delta_function: DeltaFunction):
7868
delta_function.convert_unit("microeV")
7969

8070
assert delta_function.unit == "microeV"
81-
assert delta_function.area.value == 2 * 1e3
82-
assert delta_function.center.value == 0.5 * 1e3
71+
assert delta_function._area.value == 2 * 1e3
72+
assert delta_function._center.value == 0.5 * 1e3
8373

8474
def test_copy(self, delta_function: DeltaFunction):
8575
delta_copy = delta_function.copy()
8676
assert delta_copy is not delta_function
87-
assert delta_copy.name == delta_function.name
77+
assert delta_copy.name == "copy of " + delta_function.name
8878

89-
assert delta_copy.area.value == delta_function.area.value
90-
assert delta_copy.area.fixed == delta_function.area.fixed
79+
assert delta_copy._area.value == delta_function._area.value
80+
assert delta_copy._area.fixed == delta_function._area.fixed
9181

92-
assert delta_copy.center.value == delta_function.center.value
93-
assert delta_copy.center.fixed == delta_function.center.fixed
82+
assert delta_copy._center.value == delta_function._center.value
83+
assert delta_copy._center.fixed == delta_function._center.fixed
9484

9585
assert delta_copy.unit == delta_function.unit
9686

0 commit comments

Comments
 (0)