Skip to content

Commit d1d28d2

Browse files
Use CollectionBase
1 parent a4ba794 commit d1d28d2

File tree

3 files changed

+214
-158
lines changed

3 files changed

+214
-158
lines changed

examples/sample_model.ipynb

Lines changed: 141 additions & 7 deletions
Large diffs are not rendered by default.

src/easydynamics/sample_model/sample_model.py

Lines changed: 66 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import warnings
2-
from collections.abc import MutableMapping
32
from copy import copy
43
from itertools import chain
5-
from typing import Dict, List, Optional, Union
4+
from typing import List, Optional, Union
65

76
import numpy as np
87
import scipp as sc
9-
10-
# from easyscience.base_classes import ObjBase
8+
from easyscience.base_classes import CollectionBase
9+
from easyscience.global_object.undo_redo import NotarizedDict
1110
from easyscience.job.theoreticalmodel import TheoreticalModelBase
1211
from easyscience.variable import Parameter
1312
from scipp import UnitError
@@ -19,15 +18,13 @@
1918
Numeric = Union[float, int]
2019

2120

22-
class SampleModel(TheoreticalModelBase, MutableMapping):
21+
class SampleModel(CollectionBase, TheoreticalModelBase):
2322
"""
2423
A model of the scattering from a sample, combining multiple model components.
2524
Optionally applies detailed balancing.
2625
2726
Attributes
2827
----------
29-
components : dict
30-
Dictionary of model components keyed by name.
3128
temperature : Parameter
3229
Temperature parameter for detailed balance.
3330
use_detailed_balance : bool
@@ -36,6 +33,11 @@ class SampleModel(TheoreticalModelBase, MutableMapping):
3633
Whether to normalize the detailed balance by temperature.
3734
name : str
3835
Name of the SampleModel.
36+
unit : str or sc.Unit
37+
Unit of the SampleModel.
38+
components : List[ModelComponent]
39+
List of model components in the SampleModel.
40+
3941
"""
4042

4143
def __init__(
@@ -58,8 +60,11 @@ def __init__(
5860
Unit of the temperature.
5961
"""
6062

61-
self.components: Dict[str, ModelComponent] = {}
62-
super().__init__(name=name)
63+
CollectionBase.__init__(self, name=name)
64+
TheoreticalModelBase.__init__(self, name=name)
65+
if not isinstance(self._kwargs, NotarizedDict):
66+
self._kwargs = NotarizedDict()
67+
6368
# If temperature is given, create a Parameter and enable detailed balance.
6469
if temperature is None:
6570
self._temperature = None
@@ -103,26 +108,24 @@ def add_component(
103108

104109
if name is None:
105110
name = component.name
106-
if name in self.components:
111+
if name in self.list_component_names():
107112
raise ValueError(f"Component with name '{name}' already exists.")
108113

109-
self.components[name] = component
114+
component.name = name
115+
116+
self.insert(index=len(self), value=component)
110117

111118
def remove_component(self, name: str):
112119
"""
113120
Remove a model component by name.
114-
115-
Parameters
116-
----------
117-
name : str
118-
Name of the component to remove.
119121
"""
120-
121-
if name not in self.components:
122+
# Find index where item.name == name
123+
indices = [i for i, item in enumerate(list(self)) if item.name == name]
124+
if not indices:
122125
raise KeyError(f"No component named '{name}' exists in the model.")
123-
del self.components[name]
126+
del self[indices[0]]
124127

125-
def list_components(self) -> List[str]:
128+
def list_component_names(self) -> List[str]:
126129
"""
127130
List the names of all components in the model.
128131
@@ -132,14 +135,15 @@ def list_components(self) -> List[str]:
132135
Component names.
133136
"""
134137

135-
return list(self.components.keys())
138+
return [item.name for item in list(self)]
136139

137140
def clear_components(self):
138141
"""
139142
Remove all components from the model.
140143
"""
141144

142-
self.components.clear()
145+
for _ in range(len(self)):
146+
del self[0]
143147

144148
def normalize_area(self) -> None:
145149
# Useful for convolutions.
@@ -152,7 +156,7 @@ def normalize_area(self) -> None:
152156
area_params = []
153157
total_area = 0.0
154158

155-
for component in self.components.values():
159+
for component in list(self):
156160
if hasattr(component, "area"):
157161
area_params.append(component.area)
158162
total_area += component.area.value
@@ -175,9 +179,21 @@ def convert_unit(self, unit: Union[str, sc.Unit]):
175179
Convert the unit of the SampleModel and all its components.
176180
"""
177181
self._unit = unit
178-
for component in self.components.values():
182+
# for component in self.components.values():
183+
for component in list(self):
179184
component.convert_unit(unit)
180185

186+
@property
187+
def components(self) -> List[ModelComponent]:
188+
"""
189+
Get the list of components in the SampleModel.
190+
191+
Returns
192+
-------
193+
List[ModelComponent]
194+
"""
195+
return list(self)
196+
181197
@property
182198
def unit(self) -> Optional[Union[str, sc.Unit]]:
183199
"""
@@ -331,7 +347,7 @@ def evaluate(
331347
if not self.components:
332348
raise ValueError("No components in the model to evaluate.")
333349
result = None
334-
for component in self.components.values():
350+
for component in list(self):
335351
value = component.evaluate(x)
336352
result = value if result is None else result + value
337353

@@ -368,10 +384,21 @@ def evaluate_component(
368384
np.ndarray
369385
Evaluated values for the specified component.
370386
"""
371-
if name not in self.components:
387+
if not self.components:
388+
raise ValueError("No components in the model to evaluate.")
389+
390+
if not isinstance(name, str):
391+
raise TypeError(
392+
(f"Component name must be a string, got {type(name)} instead.")
393+
)
394+
395+
matches = [comp for comp in list(self) if comp.name == name]
396+
if not matches:
372397
raise KeyError(f"No component named '{name}' exists.")
373398

374-
result = self.components[name].evaluate(x)
399+
component = matches[0]
400+
401+
result = component.evaluate(x)
375402
if (
376403
self.use_detailed_balance
377404
and self._temperature is not None
@@ -401,11 +428,7 @@ def get_parameters(self) -> List[Parameter]:
401428
temp_params = (self._temperature,) if self._temperature is not None else ()
402429

403430
# Create generator for component parameters
404-
comp_params = (
405-
param
406-
for comp in self.components.values()
407-
for param in comp.get_parameters()
408-
)
431+
comp_params = (param for comp in list(self) for param in comp.get_parameters())
409432

410433
# Chain them together and return as list
411434
return list(chain(temp_params, comp_params))
@@ -418,16 +441,6 @@ def get_fit_parameters(self) -> List[Parameter]:
418441
List[Parameter]: A list of fit parameters.
419442
"""
420443

421-
# parameters = self.get_parameters()
422-
# fit_parameters = []
423-
424-
# for parameter in parameters:
425-
# is_not_fixed = not getattr(parameter, "fixed", False)
426-
# is_independent = getattr(parameter, "_independent", True)
427-
428-
# if is_not_fixed and is_independent:
429-
# fit_parameters.append(parameter)
430-
431444
def is_fit_parameter(param: Parameter) -> bool:
432445
"""Check if a parameter can be used for fitting."""
433446
return not getattr(param, "fixed", False) and getattr(
@@ -477,87 +490,14 @@ def __copy__(self) -> "SampleModel":
477490
new_model.use_detailed_balance = self.use_detailed_balance
478491
new_model.normalize_detailed_balance = self.normalize_detailed_balance
479492

480-
for comp in self.components.values():
493+
for comp in list(self):
481494
new_model.add_component(component=copy(comp), name=comp.name)
482495
new_model[comp.name].name = comp.name # Remove 'copy of ' prefix
483496
for par in new_model[comp.name].get_parameters():
484497
par.name = par.name.removeprefix("copy of ")
485498

486499
return new_model
487500

488-
##############################################
489-
# dict-like behaviour #
490-
##############################################
491-
492-
def __getitem__(self, key: str) -> ModelComponent:
493-
"""
494-
Access a component by name.
495-
496-
Parameters
497-
----------
498-
key : str
499-
Name of the component.
500-
501-
Returns
502-
-------
503-
ModelComponent
504-
"""
505-
return self.components[key]
506-
507-
def __setitem__(self, key: str, value: ModelComponent) -> None:
508-
"""
509-
Set or replace a component.
510-
511-
Parameters
512-
----------
513-
key : str
514-
Name of the component.
515-
value : ModelComponent
516-
The component to assign.
517-
"""
518-
if not isinstance(value, ModelComponent):
519-
raise TypeError("Value must be an instance of ModelComponent.")
520-
self.components[key] = value
521-
522-
def __delitem__(self, key: str) -> None:
523-
"""
524-
Remove a component by name.
525-
Parameters
526-
----------
527-
key : str
528-
Name of the component to remove.
529-
"""
530-
if not isinstance(key, str):
531-
raise TypeError("Key must be a string.")
532-
533-
if key not in self.components:
534-
raise KeyError(f"No component named '{key}' exists in the model.")
535-
536-
self.remove_component(key)
537-
538-
def __contains__(self, name: str) -> bool:
539-
"""
540-
Check if a component exists in the model.
541-
542-
Parameters
543-
----------
544-
name : str
545-
Name of the component.
546-
547-
Returns
548-
-------
549-
bool
550-
"""
551-
return name in self.components
552-
553-
def __iter__(self) -> iter:
554-
"""Iterate over component names."""
555-
return iter(self.components)
556-
557-
def __len__(self) -> int:
558-
"""Return the number of components in the model."""
559-
return len(self.components)
560-
561501
def __repr__(self) -> str:
562502
"""
563503
Return a string representation of the SampleModel.
@@ -566,10 +506,14 @@ def __repr__(self) -> str:
566506
-------
567507
str
568508
"""
569-
comp_names = ", ".join(self.components.keys()) or "No components"
570-
temp_str = (
571-
f" | Temperature: {self._temperature.value} {self._temperature.unit}"
572-
if self._use_detailed_balance
573-
else ""
574-
)
509+
comp_names = ", ".join(c.name for c in self) or "No components"
510+
511+
temp_str = ""
512+
if (
513+
getattr(self, "_use_detailed_balance", False)
514+
and getattr(self, "_temperature", None) is not None
515+
):
516+
temp = self._temperature
517+
temp_str = f" | Temperature: {temp.value} {temp.unit}"
518+
575519
return f"<SampleModel name='{self.name}' | Components: {comp_names}{temp_str}>"

0 commit comments

Comments
 (0)