Skip to content

Commit 0eb1bba

Browse files
Merge branch 'develop' into SampleModel2
2 parents ce49f05 + 01b9490 commit 0eb1bba

File tree

16 files changed

+869
-838
lines changed

16 files changed

+869
-838
lines changed
Lines changed: 30 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
from __future__ import annotations
22

3-
import warnings
4-
from typing import Union
3+
from typing import Optional, Union
54

65
import numpy as np
76
import scipp as sc
87
from easyscience.variable import Parameter
98

9+
from easydynamics.sample_model.components.mixins import CreateParametersMixin
10+
1011
from .model_component import ModelComponent
1112

1213
Numeric = Union[float, int]
1314

14-
MINIMUM_WIDTH = 1e-10 # To avoid division by zero
15-
1615

17-
class DampedHarmonicOscillator(ModelComponent):
16+
class DampedHarmonicOscillator(CreateParametersMixin, ModelComponent):
1817
"""
1918
Damped Harmonic Oscillator (DHO). 2*area*center^2*width/pi / ( (x^2 - center^2)^2 + (2*width*x)^2 )
2019
@@ -28,97 +27,30 @@ class DampedHarmonicOscillator(ModelComponent):
2827

2928
def __init__(
3029
self,
31-
name: str = "DHO",
32-
center: Numeric = 1.0,
33-
width: Numeric = 1.0,
34-
area: Numeric = 1.0,
35-
unit: Union[str, sc.Unit] = "meV",
30+
name: Optional[str] = "DampedHarmonicOscillator",
31+
area: Optional[Union[Numeric, Parameter]] = 1.0,
32+
center: Optional[Union[Numeric, Parameter]] = 1.0,
33+
width: Optional[Union[Numeric, Parameter]] = 1.0,
34+
unit: Optional[Union[str, sc.Unit]] = "meV",
3635
):
37-
# Validate inputs
38-
if not isinstance(area, Numeric):
39-
raise TypeError("area must be a number.")
40-
area = float(area)
41-
if area < 0:
42-
warnings.warn(
43-
"The area of the Damped Harmonic Oscillator with name {} is negative, which may not be physically meaningful.".format(
44-
name
45-
)
46-
)
47-
48-
if not isinstance(center, Numeric):
49-
raise TypeError("center must be a number.")
50-
51-
center = float(center)
52-
53-
if not isinstance(width, Numeric):
54-
raise TypeError("width must be a number.")
55-
56-
width = float(width)
57-
if width <= 0:
58-
raise ValueError(
59-
"The width of a DampedHarmonicOscillator must be greater than zero."
60-
)
61-
62-
super().__init__(name=name, unit=unit)
63-
64-
# Create Parameters from floats
65-
self._area = Parameter(name=name + " area", value=area, unit=unit)
66-
if area > 0:
67-
self._area.min = 0.0
68-
69-
self._center = Parameter(name=name + " center", value=center, unit=unit)
70-
71-
self._width = Parameter(
72-
name=name + " width", value=width, unit=unit, min=MINIMUM_WIDTH
36+
# Validate inputs and create Parameters if not given
37+
self.validate_unit(unit)
38+
self._unit = unit
39+
40+
# These methods live in ValidationMixin
41+
area = self._create_area_parameter(area=area, name=name, unit=self._unit)
42+
center = self._create_center_parameter(
43+
center=center, name=name, fix_if_none=False, unit=self._unit
7344
)
45+
width = self._create_width_parameter(width=width, name=name, unit=self._unit)
7446

75-
@property
76-
def area(self) -> Parameter:
77-
"""Return the area parameter."""
78-
return self._area
79-
80-
@area.setter
81-
def area(self, value: Numeric):
82-
"""Set the area parameter."""
83-
if not isinstance(value, Numeric):
84-
raise TypeError("area must be a number.")
85-
value = float(value)
86-
if value < 0:
87-
warnings.warn(
88-
"The area of the Damped Harmonic Oscillator with name {} is negative, which may not be physically meaningful.".format(
89-
self.name
90-
)
91-
)
92-
self._area.value = float(value)
93-
94-
@property
95-
def center(self) -> Parameter:
96-
"""Return the center parameter."""
97-
return self._center
98-
99-
@center.setter
100-
def center(self, value: Numeric):
101-
"""Set the center parameter."""
102-
if not isinstance(value, Numeric):
103-
raise TypeError("center must be a number.")
104-
self._center.value = float(value)
105-
106-
@property
107-
def width(self) -> Parameter:
108-
"""Return the width parameter."""
109-
return self._width
110-
111-
@width.setter
112-
def width(self, value: Numeric):
113-
"""Set the width parameter."""
114-
if not isinstance(value, Numeric):
115-
raise TypeError("width must be a number.")
116-
value = float(value)
117-
if value <= 0:
118-
raise ValueError(
119-
"The width of a DampedHarmonicOscillator must be greater than zero."
120-
)
121-
self._width.value = value
47+
super().__init__(
48+
name=name,
49+
unit=unit,
50+
area=area,
51+
center=center,
52+
width=width,
53+
)
12254

12355
def evaluate(
12456
self, x: Union[Numeric, list, np.ndarray, sc.Variable, sc.DataArray]
@@ -129,53 +61,14 @@ def evaluate(
12961

13062
x = self._prepare_x_for_evaluate(x)
13163

132-
normalization = 2 * self._center.value**2 * self._width.value / np.pi
133-
denominator = (x**2 - self._center.value**2) ** 2 + (
64+
normalization = 2 * self.center.value**2 * self.width.value / np.pi
65+
denominator = (x**2 - self.center.value**2) ** 2 + (
13466
2
135-
* self._width.value
67+
* self.width.value
13668
* x # No division by zero here, width>0 enforced in setter
13769
) ** 2
13870

139-
return self._area.value * normalization / (denominator)
140-
141-
def get_parameters(self):
142-
"""
143-
Get all parameters from the model component.
144-
Returns:
145-
List[Parameter]: List of parameters in the component.
146-
"""
147-
return [self._area, self._center, self._width]
148-
149-
def convert_unit(self, unit: Union[str, sc.Unit]):
150-
"""
151-
Convert the unit of the Parameters in the component.
152-
153-
Args:
154-
unit (str or sc.Unit): The new unit to convert to.
155-
"""
156-
157-
self._area.convert_unit(unit)
158-
self._center.convert_unit(unit)
159-
self._width.convert_unit(unit)
160-
self._unit = unit
161-
162-
def __copy__(self) -> DampedHarmonicOscillator:
163-
"""
164-
Return a deep copy of this component with independent parameters.
165-
"""
166-
name = "copy of " + self.name
167-
168-
model_copy = DampedHarmonicOscillator(
169-
name=name,
170-
area=self._area.value,
171-
center=self._center.value,
172-
width=self._width.value,
173-
unit=self._unit,
174-
)
175-
model_copy._area.fixed = self._area.fixed
176-
model_copy._center.fixed = self._center.fixed
177-
model_copy._width.fixed = self._width.fixed
178-
return model_copy
71+
return self.area.value * normalization / (denominator)
17972

18073
def __repr__(self):
181-
return f"DampedHarmonicOscillator(name = {self.name}, unit = {self._unit},\n area = {self._area},\n center = {self._center},\n width = {self._width})"
74+
return f"DampedHarmonicOscillator(name = {self.name}, unit = {self._unit},\n area = {self.area},\n center = {self.center},\n width = {self.width})"
Lines changed: 26 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
from __future__ import annotations
22

3-
import warnings
4-
from typing import Union
3+
from typing import Optional, Union
54

65
import numpy as np
76
import scipp as sc
87
from easyscience.variable import Parameter
98

9+
from easydynamics.sample_model.components.mixins import CreateParametersMixin
10+
1011
from .model_component import ModelComponent
1112

1213
Numeric = Union[float, int]
1314

1415
EPSILON = 1e-8 # small number to avoid floating point issues
1516

1617

17-
class DeltaFunction(ModelComponent):
18+
class DeltaFunction(CreateParametersMixin, ModelComponent):
1819
"""
1920
Delta function. Evaluates to zero everywhere, except in convolutions, where it acts as an identity. This is handled in the ResolutionHandler.
2021
If the center is not provided, it will be centered at 0 and fixed, which is typically what you want in QENS.
@@ -28,72 +29,27 @@ class DeltaFunction(ModelComponent):
2829

2930
def __init__(
3031
self,
31-
name: str = "DeltaFunction",
32-
center: Union[None, Numeric] = None,
33-
area: Numeric = 1.0,
32+
name: Optional[str] = "DeltaFunction",
33+
center: Optional[Union[None, Numeric, Parameter]] = None,
34+
area: Optional[Union[Numeric, Parameter]] = 1.0,
3435
unit: Union[str, sc.Unit] = "meV",
3536
):
36-
# Validate inputs
37-
if not isinstance(area, Numeric):
38-
raise TypeError("area must be a number.")
39-
40-
if area < 0:
41-
warnings.warn(
42-
"The area of the Delta function with name {} is negative, which may not be physically meaningful.".format(
43-
name
44-
)
45-
)
46-
area = float(area)
47-
48-
if center is not None and not isinstance(center, Numeric):
49-
raise TypeError("center must be None or a number.")
50-
51-
if isinstance(center, Numeric):
52-
center = float(center)
53-
54-
super().__init__(name=name, unit=unit)
55-
# Create Parameters from floats, or set Parameters if already provided
56-
self._area = Parameter(name=name + " area", value=area, unit=unit)
57-
if area > 0:
58-
self._area.min = 0.0
59-
60-
if center is None:
61-
self._center = Parameter(
62-
name=name + " center", value=0.0, unit=unit, fixed=True
63-
)
64-
else:
65-
self._center = Parameter(name=name + " center", value=center, unit=unit)
66-
67-
@property
68-
def area(self) -> Parameter:
69-
"""Return the area parameter."""
70-
return self._area
71-
72-
@area.setter
73-
def area(self, value: Numeric):
74-
"""Set the area parameter."""
75-
if not isinstance(value, Numeric):
76-
raise TypeError("area must be a number.")
77-
value = float(value)
78-
if value < 0:
79-
warnings.warn(
80-
"The area of the Delta function with name {} is negative, which may not be physically meaningful.".format(
81-
self.name
82-
)
83-
)
84-
self._area.value = value
85-
86-
@property
87-
def center(self) -> Parameter:
88-
"""Return the center parameter."""
89-
return self._center
90-
91-
@center.setter
92-
def center(self, value: Numeric):
93-
"""Set the center parameter."""
94-
if not isinstance(value, Numeric):
95-
raise TypeError("center must be a number.")
96-
self._center.value = float(value)
37+
# Validate inputs and create Parameters if not given
38+
self.validate_unit(unit)
39+
self._unit = unit
40+
41+
# These methods live in ValidationMixin
42+
area = self._create_area_parameter(area=area, name=name, unit=self._unit)
43+
center = self._create_center_parameter(
44+
center=center, name=name, fix_if_none=True, unit=self._unit
45+
)
46+
47+
super().__init__(
48+
name=name,
49+
unit=unit,
50+
area=area,
51+
center=center,
52+
)
9753

9854
def evaluate(
9955
self, x: Union[Numeric, list, np.ndarray, sc.Variable, sc.DataArray]
@@ -105,8 +61,8 @@ def evaluate(
10561
# x assumed sorted, 1D numpy array
10662
x = self._prepare_x_for_evaluate(x)
10763
model = np.zeros_like(x, dtype=float)
108-
center = self._center.value
109-
area = self._area.value
64+
center = self.center.value
65+
area = self.area.value
11066

11167
if x.min() - EPSILON <= center <= x.max() + EPSILON:
11268
# nearest index
@@ -131,40 +87,5 @@ def evaluate(
13187

13288
return model
13389

134-
def get_parameters(self):
135-
"""
136-
Get all parameters from the model component.
137-
Returns:
138-
List[Parameter]: List of parameters in the component.
139-
"""
140-
return [self._area, self._center]
141-
142-
def convert_unit(self, unit: Union[str, sc.Unit]):
143-
"""
144-
Convert the unit of the Parameters in the component.
145-
146-
Args:
147-
unit (str or sc.Unit): The new unit to convert to.
148-
"""
149-
self._area.convert_unit(unit)
150-
self._center.convert_unit(unit)
151-
self._unit = unit
152-
153-
def __copy__(self) -> DeltaFunction:
154-
"""
155-
Return a deep copy of this component with independent parameters.
156-
"""
157-
name = "copy of " + self.name
158-
159-
model_copy = DeltaFunction(
160-
name=name,
161-
area=self._area.value,
162-
center=self._center.value,
163-
unit=self._unit,
164-
)
165-
model_copy._area.fixed = self._area.fixed
166-
model_copy._center.fixed = self._center.fixed
167-
return model_copy
168-
16990
def __repr__(self):
170-
return f"DeltaFunction(name = {self.name}, unit = {self._unit},\n area = {self._area},\n center = {self._center}"
91+
return f"DeltaFunction(name = {self.name}, unit = {self._unit},\n area = {self.area},\n center = {self.center}"

0 commit comments

Comments
 (0)