Skip to content

Commit 8aa43c0

Browse files
Respond to reviewer comments
1 parent 3a7187c commit 8aa43c0

File tree

3 files changed

+56
-104
lines changed

3 files changed

+56
-104
lines changed

examples/sample_model.ipynb

Lines changed: 3 additions & 63 deletions
Large diffs are not rendered by default.

src/easydynamics/sample_model/sample_model.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ class SampleModel(CollectionBase, TheoreticalModelBase):
2222
Name of the SampleModel.
2323
unit : str or sc.Unit
2424
Unit of the SampleModel.
25-
components : List[ModelComponent]
26-
List of model components in the SampleModel.
2725
2826
"""
2927

@@ -46,17 +44,14 @@ def __init__(
4644
Initial list of model components to include in the sample model.
4745
"""
4846

49-
CollectionBase.__init__(self, name=name)
50-
TheoreticalModelBase.__init__(self, name=name)
47+
super().__init__(name=name)
5148
if not isinstance(self._kwargs, NotarizedDict):
5249
self._kwargs = NotarizedDict()
5350

5451
self._unit = unit
5552

5653
# Add initial components if provided. Mostly used for serialization.
5754
if data:
58-
# Just to be safe
59-
self.clear_components()
6055
for item in data:
6156
# ensure item is a ModelComponent
6257
if not isinstance(item, ModelComponent):
@@ -76,7 +71,7 @@ def add_component(
7671
Name to assign to the component. If None, uses the component's own name.
7772
"""
7873
if not isinstance(component, ModelComponent):
79-
raise TypeError("component must be an instance of ModelComponent.")
74+
raise TypeError("Component must be an instance of ModelComponent.")
8075

8176
if name is None:
8277
name = component.name
@@ -87,15 +82,15 @@ def add_component(
8782

8883
self.insert(index=len(self), value=component)
8984

90-
def remove_component(self, name: str):
85+
def remove_component(self, name: str) -> None:
9186
"""
9287
Remove a model component by name.
9388
"""
94-
# Find index where item.name == name
95-
indices = [i for i, item in enumerate(list(self)) if item.name == name]
96-
if not indices:
97-
raise KeyError(f"No component named '{name}' exists in the model.")
98-
del self[indices[0]]
89+
for i, item in enumerate(self):
90+
if item.name == name:
91+
del self[i]
92+
return
93+
raise KeyError(f"No component named '{name}' exists in the model.")
9994

10095
def list_component_names(self) -> List[str]:
10196
"""
@@ -122,7 +117,7 @@ def normalize_area(self) -> None:
122117
"""
123118
Normalize the areas of all components so they sum to 1.
124119
"""
125-
if not self.components:
120+
if not list(self):
126121
raise ValueError("No components in the model to normalize.")
127122

128123
area_params = []
@@ -146,17 +141,6 @@ def normalize_area(self) -> None:
146141
for param in area_params:
147142
param.value /= total_area
148143

149-
@property
150-
def components(self) -> List[ModelComponent]:
151-
"""
152-
Get the list of components in the SampleModel.
153-
154-
Returns
155-
-------
156-
List[ModelComponent]
157-
"""
158-
return list(self)
159-
160144
@property
161145
def unit(self) -> Optional[Union[str, sc.Unit]]:
162146
"""
@@ -181,10 +165,21 @@ def convert_unit(self, unit: Union[str, sc.Unit]) -> None:
181165
"""
182166
Convert the unit of the SampleModel and all its components.
183167
"""
184-
self._unit = unit
185-
# for component in self.components.values():
168+
169+
old_unit = self._unit
170+
186171
for component in list(self):
187-
component.convert_unit(unit)
172+
try:
173+
component.convert_unit(unit)
174+
except Exception as e:
175+
# Attempt to rollback on failure
176+
try:
177+
component.convert_unit(old_unit)
178+
except Exception:
179+
pass # Best effort rollback
180+
raise e
181+
182+
self._unit = unit
188183

189184
def evaluate(
190185
self, x: Union[Numeric, list, np.ndarray, sc.Variable, sc.DataArray]
@@ -203,14 +198,9 @@ def evaluate(
203198
Evaluated model values.
204199
"""
205200

206-
if not self.components:
201+
if not list(self):
207202
raise ValueError("No components in the model to evaluate.")
208-
result = None
209-
for component in list(self):
210-
value = component.evaluate(x)
211-
result = value if result is None else result + value
212-
213-
return result
203+
return sum(component.evaluate(x) for component in list(self))
214204

215205
def evaluate_component(
216206
self,
@@ -232,7 +222,7 @@ def evaluate_component(
232222
np.ndarray
233223
Evaluated values for the specified component.
234224
"""
235-
if not self.components:
225+
if not list(self):
236226
raise ValueError("No components in the model to evaluate.")
237227

238228
if not isinstance(name, str):
@@ -264,6 +254,28 @@ def free_all_parameters(self) -> None:
264254
for param in self.get_parameters():
265255
param.fixed = False
266256

257+
def __contains__(self, item: Union[str, ModelComponent]) -> bool:
258+
"""
259+
Check if a component with the given name or instance exists in the SampleModel.
260+
Args:
261+
----------
262+
item : str or ModelComponent
263+
The component name or instance to check for.
264+
Returns
265+
-------
266+
bool
267+
True if the component exists, False otherwise.
268+
"""
269+
270+
if isinstance(item, str):
271+
# Check by component name
272+
return any(comp.name == item for comp in self)
273+
elif isinstance(item, ModelComponent):
274+
# Check by component instance
275+
return any(comp is item for comp in self)
276+
else:
277+
return False
278+
267279
def __repr__(self) -> str:
268280
"""
269281
Return a string representation of the SampleModel.

tests/unit_tests/sample_model/test_sample_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_init(self):
2828

2929
# EXPECT
3030
assert sample_model.name == "InitModel"
31-
assert len(sample_model.components) == 0
31+
assert list(sample_model) == []
3232

3333
def test_initialization_with_components(self):
3434
# WHEN THEN
@@ -44,7 +44,7 @@ def test_initialization_with_components(self):
4444

4545
# EXPECT
4646
assert sample_model.name == "InitModelWithComponents"
47-
assert len(sample_model.components) == 2
47+
assert len(list(sample_model)) == 2
4848
assert sample_model["InitGaussian"] is component1
4949
assert sample_model["InitLorentzian"] is component2
5050

@@ -72,15 +72,15 @@ def test_add_duplicate_component_raises(self, sample_model):
7272
def test_add_invalid_component_raises(self, sample_model):
7373
# WHEN THEN EXPECT
7474
with pytest.raises(
75-
TypeError, match="component must be an instance of ModelComponent."
75+
TypeError, match="Component must be an instance of ModelComponent."
7676
):
7777
sample_model.add_component("NotAComponent")
7878

7979
def test_remove_component(self, sample_model):
8080
# WHEN THEN
8181
sample_model.remove_component("TestGaussian1")
8282
# EXPECT
83-
assert "TestGaussian1" not in sample_model.components
83+
assert "TestGaussian1" not in list(sample_model)
8484

8585
def test_remove_nonexistent_component_raises(self, sample_model):
8686
# WHEN THEN EXPECT
@@ -111,7 +111,7 @@ def test_clear_components(self, sample_model):
111111
# WHEN THEN
112112
sample_model.clear_components()
113113
# EXPECT
114-
assert len(sample_model.components) == 0
114+
assert len(list(sample_model)) == 0
115115

116116
def test_convert_unit(self, sample_model):
117117
# WHEN THEN

0 commit comments

Comments
 (0)