Skip to content

Commit 28d0b0e

Browse files
Tests :)
1 parent 8aa43c0 commit 28d0b0e

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

src/easydynamics/sample_model/sample_model.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,18 +168,18 @@ def convert_unit(self, unit: Union[str, sc.Unit]) -> None:
168168

169169
old_unit = self._unit
170170

171-
for component in list(self):
172-
try:
171+
try:
172+
for component in list(self):
173173
component.convert_unit(unit)
174-
except Exception as e:
175-
# Attempt to rollback on failure
176-
try:
174+
self._unit = unit
175+
except Exception as e:
176+
# Attempt to rollback on failure
177+
try:
178+
for component in list(self):
177179
component.convert_unit(old_unit)
178-
except Exception:
179-
pass # Best effort rollback
180-
raise e
181-
182-
self._unit = unit
180+
except Exception:
181+
pass # Best effort rollback
182+
raise e
183183

184184
def evaluate(
185185
self, x: Union[Numeric, list, np.ndarray, sc.Variable, sc.DataArray]

tests/unit_tests/sample_model/test_sample_model.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,38 @@ def test_convert_unit(self, sample_model):
120120
for component in list(sample_model):
121121
assert component.unit == "eV"
122122

123+
def test_convert_unit_failure_rolls_back(self, sample_model):
124+
# WHEN THEN
125+
# Introduce a faulty component that will fail conversion
126+
class FaultyComponent(Gaussian):
127+
def convert_unit(self, unit: str) -> None:
128+
raise RuntimeError("Conversion failed.")
129+
130+
faulty_component = FaultyComponent(
131+
name="FaultyComponent", area=1.0, center=0.0, width=1.0, unit="meV"
132+
)
133+
sample_model.add_component(faulty_component)
134+
135+
original_units = {
136+
component.name: component.unit for component in list(sample_model)
137+
}
138+
139+
# EXPECT
140+
with pytest.raises(RuntimeError, match="Conversion failed."):
141+
sample_model.convert_unit("eV")
142+
143+
# Check that all components have their original units
144+
for component in list(sample_model):
145+
assert component.unit == original_units[component.name]
146+
147+
def test_set_unit(self, sample_model):
148+
# WHEN THEN EXPECT
149+
with pytest.raises(
150+
AttributeError,
151+
match="Unit is read-only. Use convert_unit to change the unit",
152+
):
153+
sample_model.unit = "eV"
154+
123155
def test_evaluate(self, sample_model):
124156
# WHEN
125157
x = np.linspace(-5, 5, 100)
@@ -290,6 +322,16 @@ def test_fix_and_free_all_parameters(self, sample_model):
290322
for param in sample_model.get_parameters():
291323
assert param.fixed is False
292324

325+
def test_contains(self, sample_model):
326+
# WHEN THEN
327+
assert "TestGaussian1" in sample_model
328+
assert "NonExistentComponent" not in sample_model
329+
assert sample_model["TestLorentzian1"] in sample_model
330+
fake_component = Gaussian(
331+
name="FakeGaussian", area=1.0, center=0.0, width=1.0, unit="meV"
332+
)
333+
assert fake_component not in sample_model
334+
293335
def test_repr_contains_name_and_components(self, sample_model):
294336
# WHEN THEN
295337
rep = repr(sample_model)

0 commit comments

Comments
 (0)