Skip to content

Commit 3a7187c

Browse files
Cleanup and a few tests
1 parent 07c070a commit 3a7187c

File tree

3 files changed

+102
-152
lines changed

3 files changed

+102
-152
lines changed

examples/sample_model.ipynb

Lines changed: 63 additions & 2 deletions
Large diffs are not rendered by default.

src/easydynamics/sample_model/sample_model.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,16 @@ def __init__(
5353

5454
self._unit = unit
5555

56+
# Add initial components if provided. Mostly used for serialization.
5657
if data:
57-
# clear any accidental pre-populated items (defensive)
58+
# Just to be safe
5859
self.clear_components()
5960
for item in data:
6061
# ensure item is a ModelComponent
6162
if not isinstance(item, ModelComponent):
6263
raise TypeError("Data items must be instances of ModelComponent.")
6364
self.insert(index=len(self), value=item)
6465

65-
##############################################
66-
# Methods for managing components #
67-
##############################################
68-
6966
def add_component(
7067
self, component: ModelComponent, name: Optional[str] = None
7168
) -> None:
@@ -221,7 +218,7 @@ def evaluate_component(
221218
name: str,
222219
) -> np.ndarray:
223220
"""
224-
Evaluate a single component by name, optionally applying detailed balance.
221+
Evaluate a single component by name.
225222
226223
Parameters
227224
----------

tests/unit_tests/sample_model/test_sample_model.py

Lines changed: 36 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,23 @@ def test_init(self):
3030
assert sample_model.name == "InitModel"
3131
assert len(sample_model.components) == 0
3232

33-
# def test_init_no_temperature(self, sample_model):
34-
# # WHEN THEN EXPECT
35-
# assert sample_model.name == "TestSampleModel"
36-
# assert len(sample_model.components) == 2
37-
# assert not sample_model.use_detailed_balance
38-
39-
# def test_init_with_temperature(self):
40-
# # WHEN THEN
41-
# sample_model = SampleModel(name="TempModel", temperature=100)
42-
43-
# # EXPECT
44-
# assert sample_model.name == "TempModel"
45-
# assert len(sample_model.components) == 0
46-
# assert sample_model.use_detailed_balance
47-
# assert isinstance(sample_model.temperature, Parameter)
48-
# assert sample_model.temperature.value == 100
33+
def test_initialization_with_components(self):
34+
# WHEN THEN
35+
component1 = Gaussian(
36+
name="InitGaussian", area=1.0, center=0.0, width=1.0, unit="meV"
37+
)
38+
component2 = Lorentzian(
39+
name="InitLorentzian", area=2.0, center=1.0, width=0.5, unit="meV"
40+
)
41+
sample_model = SampleModel(
42+
name="InitModelWithComponents", data=[component1, component2]
43+
)
44+
45+
# EXPECT
46+
assert sample_model.name == "InitModelWithComponents"
47+
assert len(sample_model.components) == 2
48+
assert sample_model["InitGaussian"] is component1
49+
assert sample_model["InitLorentzian"] is component2
4950

5051
# ───── Component Management ─────
5152

@@ -119,79 +120,6 @@ def test_convert_unit(self, sample_model):
119120
for component in list(sample_model):
120121
assert component.unit == "eV"
121122

122-
# # ───── Temperature and Detailed Balance ─────
123-
124-
# def test_set_temperature(self, sample_model):
125-
# # Set valid temperature
126-
# # WHEN THEN
127-
# sample_model.temperature = 300
128-
# # EXPECT
129-
# assert sample_model.temperature.value == 300
130-
# assert sample_model.temperature.unit == "K"
131-
132-
# # WHEN THEN
133-
# sample_model.temperature = 150.0
134-
# # EXPECT
135-
# assert sample_model.temperature.value == 150.0
136-
# assert sample_model.temperature.unit == "K"
137-
138-
# # Set temperature to None
139-
# # WHEN THEN
140-
# sample_model.temperature = None
141-
# # EXPECT
142-
# assert sample_model.temperature is None
143-
# assert not sample_model.use_detailed_balance
144-
145-
# def test_invalid_temperature_raises(self, sample_model):
146-
# # WHEN THEN EXPECT
147-
# with pytest.raises(TypeError, match="Temperature must be a number or None."):
148-
# sample_model.temperature = "invalid"
149-
150-
# def test_negative_temperature_raises(self, sample_model):
151-
# # WHEN THEN EXPECT
152-
# with pytest.raises(ValueError, match="Temperature must be non-negative"):
153-
# sample_model.temperature = -50
154-
155-
# def test_convert_temperature_unit(self, sample_model):
156-
# # WHEN
157-
# sample_model.temperature = 300 # Kelvin
158-
# # THEN
159-
# sample_model.convert_temperature_unit("mK")
160-
# # EXPECT
161-
# assert np.isclose(sample_model.temperature.value, 300000.0)
162-
# assert sample_model.temperature.unit == "mK"
163-
164-
# def test_convert_temperature_unit_incompatible_unit_raises(self, sample_model):
165-
# # WHEN
166-
# sample_model.temperature = 300 # Kelvin
167-
# # THEN EXPECT
168-
# with pytest.raises(UnitError, match="Failed to convert temperature"):
169-
# sample_model.convert_temperature_unit("m")
170-
171-
# def test_convert_temperature_unit_no_temperature_raises(self, sample_model):
172-
# # WHEN THEN EXPECT
173-
# with pytest.raises(ValueError, match="cannot convert units"):
174-
# sample_model.convert_temperature_unit("mK")
175-
176-
# def test_use_detailed_balance(self, sample_model):
177-
# sample_model.temperature = 300
178-
# # WHEN THEN EXPECT
179-
# assert sample_model.use_detailed_balance is False
180-
# sample_model.use_detailed_balance = True
181-
# assert sample_model.use_detailed_balance is True
182-
# sample_model.use_detailed_balance = False
183-
# assert sample_model.use_detailed_balance is False
184-
185-
# def test_use_detailed_balance_no_temperature_raises(self, sample_model):
186-
# # WHEN THEN EXPECT
187-
# with pytest.raises(
188-
# ValueError,
189-
# match="Temperature must be set to use detailed balance.",
190-
# ):
191-
# sample_model.use_detailed_balance = True
192-
193-
# ───── Evaluation ─────
194-
195123
def test_evaluate(self, sample_model):
196124
# WHEN
197125
x = np.linspace(-5, 5, 100)
@@ -202,31 +130,6 @@ def test_evaluate(self, sample_model):
202130
].evaluate(x)
203131
np.testing.assert_allclose(result, expected_result, rtol=1e-5)
204132

205-
# @pytest.mark.parametrize(
206-
# "normalize_db", [True, False], ids=["normalize DB", "Don't normalize DB"]
207-
# )
208-
# def test_evaluate_with_detailed_balance(self, sample_model, normalize_db):
209-
# # WHEN
210-
# sample_model.temperature = 300
211-
# sample_model.use_detailed_balance = True
212-
# sample_model.normalize_detailed_balance = normalize_db
213-
214-
# x = np.linspace(-5, 5, 100)
215-
216-
# # THEN
217-
# result = sample_model.evaluate(x)
218-
219-
# # EXPECT
220-
# expected_result = sample_model["TestGaussian1"].evaluate(x) + sample_model[
221-
# "TestLorentzian1"
222-
# ].evaluate(x)
223-
# expected_result *= detailed_balance_factor(
224-
# energy=x,
225-
# temperature=sample_model.temperature,
226-
# divide_by_temperature=normalize_db,
227-
# )
228-
# np.testing.assert_allclose(result, expected_result, rtol=1e-5)
229-
230133
def test_evaluate_no_components_raises(self):
231134
# WHEN THEN
232135
sample_model = SampleModel(name="EmptyModel")
@@ -247,36 +150,6 @@ def test_evaluate_component(self, sample_model):
247150
np.testing.assert_allclose(result1, expected_result1, rtol=1e-5)
248151
np.testing.assert_allclose(result2, expected_result2, rtol=1e-5)
249152

250-
# @pytest.mark.parametrize(
251-
# "normalize_db", [True, False], ids=["normalize DB", "Don't normalize DB"]
252-
# )
253-
# def test_evaluate_component_with_detailed_balance(self, sample_model, normalize_db):
254-
# # WHEN
255-
# sample_model.temperature = 300
256-
# sample_model.use_detailed_balance = True
257-
# sample_model.normalize_detailed_balance = normalize_db
258-
259-
# # THEN
260-
# x = np.linspace(-5, 5, 100)
261-
# result1 = sample_model.evaluate_component(x, name="TestGaussian1")
262-
# result2 = sample_model.evaluate_component(x, name="TestLorentzian1")
263-
264-
# # EXPECT
265-
# expected_result1 = sample_model["TestGaussian1"].evaluate(x)
266-
# expected_result2 = sample_model["TestLorentzian1"].evaluate(x)
267-
# expected_result1 *= detailed_balance_factor(
268-
# energy=x,
269-
# temperature=sample_model.temperature,
270-
# divide_by_temperature=normalize_db,
271-
# )
272-
# expected_result2 *= detailed_balance_factor(
273-
# energy=x,
274-
# temperature=sample_model.temperature,
275-
# divide_by_temperature=normalize_db,
276-
# )
277-
# np.testing.assert_allclose(result1, expected_result1, rtol=1e-5)
278-
# np.testing.assert_allclose(result2, expected_result2, rtol=1e-5)
279-
280153
def test_evaluate_nonexistent_component_raises(self, sample_model):
281154
# WHEN
282155
x = np.linspace(-5, 5, 100)
@@ -287,6 +160,25 @@ def test_evaluate_nonexistent_component_raises(self, sample_model):
287160
):
288161
sample_model.evaluate_component(x, "NonExistentComponent")
289162

163+
def test_evaluate_component_no_components_raises(self):
164+
# WHEN THEN
165+
sample_model = SampleModel(name="EmptyModel")
166+
x = np.linspace(-5, 5, 100)
167+
# EXPECT
168+
with pytest.raises(ValueError, match="No components in the model to evaluate."):
169+
sample_model.evaluate_component(x, "AnyComponent")
170+
171+
def test_evaluate_component_invalid_name_type_raises(self, sample_model):
172+
# WHEN
173+
x = np.linspace(-5, 5, 100)
174+
175+
# THEN EXPECT
176+
with pytest.raises(
177+
TypeError,
178+
match="Component name must be a string, got <class 'int'> instead.",
179+
):
180+
sample_model.evaluate_component(x, 123)
181+
290182
# ───── Utilities ─────
291183

292184
def test_normalize_area(self, sample_model):

0 commit comments

Comments
 (0)