11import warnings
2- from collections .abc import MutableMapping
32from copy import copy
43from itertools import chain
5- from typing import Dict , List , Optional , Union
4+ from typing import List , Optional , Union
65
76import numpy as np
87import scipp as sc
9-
10- # from easyscience.base_classes import ObjBase
8+ from easyscience . base_classes import CollectionBase
9+ from easyscience .global_object . undo_redo import NotarizedDict
1110from easyscience .job .theoreticalmodel import TheoreticalModelBase
1211from easyscience .variable import Parameter
1312from scipp import UnitError
1918Numeric = Union [float , int ]
2019
2120
22- class SampleModel (TheoreticalModelBase , MutableMapping ):
21+ class SampleModel (CollectionBase , TheoreticalModelBase ):
2322 """
2423 A model of the scattering from a sample, combining multiple model components.
2524 Optionally applies detailed balancing.
2625
2726 Attributes
2827 ----------
29- components : dict
30- Dictionary of model components keyed by name.
3128 temperature : Parameter
3229 Temperature parameter for detailed balance.
3330 use_detailed_balance : bool
@@ -36,6 +33,11 @@ class SampleModel(TheoreticalModelBase, MutableMapping):
3633 Whether to normalize the detailed balance by temperature.
3734 name : str
3835 Name of the SampleModel.
36+ unit : str or sc.Unit
37+ Unit of the SampleModel.
38+ components : List[ModelComponent]
39+ List of model components in the SampleModel.
40+
3941 """
4042
4143 def __init__ (
@@ -58,8 +60,11 @@ def __init__(
5860 Unit of the temperature.
5961 """
6062
61- self .components : Dict [str , ModelComponent ] = {}
62- super ().__init__ (name = name )
63+ CollectionBase .__init__ (self , name = name )
64+ TheoreticalModelBase .__init__ (self , name = name )
65+ if not isinstance (self ._kwargs , NotarizedDict ):
66+ self ._kwargs = NotarizedDict ()
67+
6368 # If temperature is given, create a Parameter and enable detailed balance.
6469 if temperature is None :
6570 self ._temperature = None
@@ -103,26 +108,24 @@ def add_component(
103108
104109 if name is None :
105110 name = component .name
106- if name in self .components :
111+ if name in self .list_component_names () :
107112 raise ValueError (f"Component with name '{ name } ' already exists." )
108113
109- self .components [name ] = component
114+ component .name = name
115+
116+ self .insert (index = len (self ), value = component )
110117
111118 def remove_component (self , name : str ):
112119 """
113120 Remove a model component by name.
114-
115- Parameters
116- ----------
117- name : str
118- Name of the component to remove.
119121 """
120-
121- if name not in self .components :
122+ # Find index where item.name == name
123+ indices = [i for i , item in enumerate (list (self )) if item .name == name ]
124+ if not indices :
122125 raise KeyError (f"No component named '{ name } ' exists in the model." )
123- del self . components [ name ]
126+ del self [ indices [ 0 ] ]
124127
125- def list_components (self ) -> List [str ]:
128+ def list_component_names (self ) -> List [str ]:
126129 """
127130 List the names of all components in the model.
128131
@@ -132,14 +135,15 @@ def list_components(self) -> List[str]:
132135 Component names.
133136 """
134137
135- return list (self . components . keys ())
138+ return [ item . name for item in list (self )]
136139
137140 def clear_components (self ):
138141 """
139142 Remove all components from the model.
140143 """
141144
142- self .components .clear ()
145+ for _ in range (len (self )):
146+ del self [0 ]
143147
144148 def normalize_area (self ) -> None :
145149 # Useful for convolutions.
@@ -152,7 +156,7 @@ def normalize_area(self) -> None:
152156 area_params = []
153157 total_area = 0.0
154158
155- for component in self . components . values ( ):
159+ for component in list ( self ):
156160 if hasattr (component , "area" ):
157161 area_params .append (component .area )
158162 total_area += component .area .value
@@ -175,9 +179,21 @@ def convert_unit(self, unit: Union[str, sc.Unit]):
175179 Convert the unit of the SampleModel and all its components.
176180 """
177181 self ._unit = unit
178- for component in self .components .values ():
182+ # for component in self.components.values():
183+ for component in list (self ):
179184 component .convert_unit (unit )
180185
186+ @property
187+ def components (self ) -> List [ModelComponent ]:
188+ """
189+ Get the list of components in the SampleModel.
190+
191+ Returns
192+ -------
193+ List[ModelComponent]
194+ """
195+ return list (self )
196+
181197 @property
182198 def unit (self ) -> Optional [Union [str , sc .Unit ]]:
183199 """
@@ -331,7 +347,7 @@ def evaluate(
331347 if not self .components :
332348 raise ValueError ("No components in the model to evaluate." )
333349 result = None
334- for component in self . components . values ( ):
350+ for component in list ( self ):
335351 value = component .evaluate (x )
336352 result = value if result is None else result + value
337353
@@ -368,10 +384,21 @@ def evaluate_component(
368384 np.ndarray
369385 Evaluated values for the specified component.
370386 """
371- if name not in self .components :
387+ if not self .components :
388+ raise ValueError ("No components in the model to evaluate." )
389+
390+ if not isinstance (name , str ):
391+ raise TypeError (
392+ (f"Component name must be a string, got { type (name )} instead." )
393+ )
394+
395+ matches = [comp for comp in list (self ) if comp .name == name ]
396+ if not matches :
372397 raise KeyError (f"No component named '{ name } ' exists." )
373398
374- result = self .components [name ].evaluate (x )
399+ component = matches [0 ]
400+
401+ result = component .evaluate (x )
375402 if (
376403 self .use_detailed_balance
377404 and self ._temperature is not None
@@ -401,11 +428,7 @@ def get_parameters(self) -> List[Parameter]:
401428 temp_params = (self ._temperature ,) if self ._temperature is not None else ()
402429
403430 # Create generator for component parameters
404- comp_params = (
405- param
406- for comp in self .components .values ()
407- for param in comp .get_parameters ()
408- )
431+ comp_params = (param for comp in list (self ) for param in comp .get_parameters ())
409432
410433 # Chain them together and return as list
411434 return list (chain (temp_params , comp_params ))
@@ -418,16 +441,6 @@ def get_fit_parameters(self) -> List[Parameter]:
418441 List[Parameter]: A list of fit parameters.
419442 """
420443
421- # parameters = self.get_parameters()
422- # fit_parameters = []
423-
424- # for parameter in parameters:
425- # is_not_fixed = not getattr(parameter, "fixed", False)
426- # is_independent = getattr(parameter, "_independent", True)
427-
428- # if is_not_fixed and is_independent:
429- # fit_parameters.append(parameter)
430-
431444 def is_fit_parameter (param : Parameter ) -> bool :
432445 """Check if a parameter can be used for fitting."""
433446 return not getattr (param , "fixed" , False ) and getattr (
@@ -477,87 +490,14 @@ def __copy__(self) -> "SampleModel":
477490 new_model .use_detailed_balance = self .use_detailed_balance
478491 new_model .normalize_detailed_balance = self .normalize_detailed_balance
479492
480- for comp in self . components . values ( ):
493+ for comp in list ( self ):
481494 new_model .add_component (component = copy (comp ), name = comp .name )
482495 new_model [comp .name ].name = comp .name # Remove 'copy of ' prefix
483496 for par in new_model [comp .name ].get_parameters ():
484497 par .name = par .name .removeprefix ("copy of " )
485498
486499 return new_model
487500
488- ##############################################
489- # dict-like behaviour #
490- ##############################################
491-
492- def __getitem__ (self , key : str ) -> ModelComponent :
493- """
494- Access a component by name.
495-
496- Parameters
497- ----------
498- key : str
499- Name of the component.
500-
501- Returns
502- -------
503- ModelComponent
504- """
505- return self .components [key ]
506-
507- def __setitem__ (self , key : str , value : ModelComponent ) -> None :
508- """
509- Set or replace a component.
510-
511- Parameters
512- ----------
513- key : str
514- Name of the component.
515- value : ModelComponent
516- The component to assign.
517- """
518- if not isinstance (value , ModelComponent ):
519- raise TypeError ("Value must be an instance of ModelComponent." )
520- self .components [key ] = value
521-
522- def __delitem__ (self , key : str ) -> None :
523- """
524- Remove a component by name.
525- Parameters
526- ----------
527- key : str
528- Name of the component to remove.
529- """
530- if not isinstance (key , str ):
531- raise TypeError ("Key must be a string." )
532-
533- if key not in self .components :
534- raise KeyError (f"No component named '{ key } ' exists in the model." )
535-
536- self .remove_component (key )
537-
538- def __contains__ (self , name : str ) -> bool :
539- """
540- Check if a component exists in the model.
541-
542- Parameters
543- ----------
544- name : str
545- Name of the component.
546-
547- Returns
548- -------
549- bool
550- """
551- return name in self .components
552-
553- def __iter__ (self ) -> iter :
554- """Iterate over component names."""
555- return iter (self .components )
556-
557- def __len__ (self ) -> int :
558- """Return the number of components in the model."""
559- return len (self .components )
560-
561501 def __repr__ (self ) -> str :
562502 """
563503 Return a string representation of the SampleModel.
@@ -566,10 +506,14 @@ def __repr__(self) -> str:
566506 -------
567507 str
568508 """
569- comp_names = ", " .join (self .components .keys ()) or "No components"
570- temp_str = (
571- f" | Temperature: { self ._temperature .value } { self ._temperature .unit } "
572- if self ._use_detailed_balance
573- else ""
574- )
509+ comp_names = ", " .join (c .name for c in self ) or "No components"
510+
511+ temp_str = ""
512+ if (
513+ getattr (self , "_use_detailed_balance" , False )
514+ and getattr (self , "_temperature" , None ) is not None
515+ ):
516+ temp = self ._temperature
517+ temp_str = f" | Temperature: { temp .value } { temp .unit } "
518+
575519 return f"<SampleModel name='{ self .name } ' | Components: { comp_names } { temp_str } >"
0 commit comments