Skip to content

Commit f9b7fa1

Browse files
update type hinting
1 parent 39adf19 commit f9b7fa1

File tree

8 files changed

+28
-60
lines changed

8 files changed

+28
-60
lines changed

src/easydynamics/sample_model/components/damped_harmonic_oscillator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Union
24

35
import numpy as np
@@ -138,7 +140,7 @@ def convert_unit(self, unit: str):
138140
self.width.convert_unit(unit)
139141
self.unit = unit
140142

141-
def copy(self) -> "DampedHarmonicOscillator":
143+
def copy(self) -> DampedHarmonicOscillator:
142144
"""
143145
Return a deep copy of this component with independent parameters.
144146
"""

src/easydynamics/sample_model/components/delta_function.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Union
24

35
from easyscience.variable import Parameter
@@ -94,7 +96,7 @@ def convert_unit(self, unit):
9496
self.center.convert_unit(unit)
9597
self.unit = unit
9698

97-
def copy(self) -> "DeltaFunction":
99+
def copy(self) -> DeltaFunction:
98100
"""
99101
Return a deep copy of this component with independent parameters.
100102
"""

src/easydynamics/sample_model/components/gaussian.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Union, List
24

35
import numpy as np
@@ -131,7 +133,7 @@ def convert_unit(self, unit: str):
131133
self.width.convert_unit(unit)
132134
self.unit = unit
133135

134-
def copy(self) -> "Gaussian":
136+
def copy(self) -> Gaussian:
135137
"""
136138
Return a deep copy of this component with independent parameters.
137139
"""

src/easydynamics/sample_model/components/lorentzian.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Union
24

35
import numpy as np
@@ -133,7 +135,7 @@ def convert_unit(self, unit: str):
133135
self.width.convert_unit(unit)
134136
self.unit = unit
135137

136-
def copy(self) -> "Lorentzian":
138+
def copy(self) -> Lorentzian:
137139
model_copy = Lorentzian(
138140
name=self.name,
139141
area=self.area.value,

src/easydynamics/sample_model/components/model_component.py

Lines changed: 8 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
from abc import abstractmethod
24

3-
from typing import Union, List, Optional
5+
from typing import Union, List
46

57
import numpy as np
68

@@ -25,14 +27,11 @@ def fix_all_parameters(self):
2527
"""Fix all parameters in the model component."""
2628

2729
pars = self.get_parameters()
28-
if pars is None or len(pars) == 0:
29-
raise ValueError("No parameters found to fix.")
30-
else:
31-
for p in pars:
32-
p.fixed = True
30+
for p in pars:
31+
p.fixed = True
3332

34-
def fit_all_parameters(self):
35-
"""Fit all parameters in the model component."""
33+
def free_all_parameters(self):
34+
"""Free all parameters in the model component."""
3635
for p in self.get_parameters():
3736
p.fixed = False
3837

@@ -66,49 +65,6 @@ def get_parameter(self, parameter_name: str) -> Parameter:
6665
else:
6766
raise ValueError(f"Parameter '{parameter_name}' not found.")
6867

69-
def set_parameter_value(
70-
self, parameter_name: str, value: float, unit: Optional[str] = None
71-
):
72-
"""
73-
Set the value of a specific parameter by name.
74-
"""
75-
param = self.get_parameter(parameter_name)
76-
if unit is not None:
77-
param.convert_unit(unit)
78-
param.value = value
79-
80-
def set_parameter_bounds(
81-
self,
82-
parameter_name: str,
83-
min: Union[float, None] = None,
84-
max: Union[float, None] = None,
85-
unit: Optional[str] = None,
86-
):
87-
"""
88-
Set the bounds of a specific parameter by name.
89-
"""
90-
param = self.get_parameter(parameter_name)
91-
if unit is not None:
92-
param.convert_unit(unit)
93-
if min is not None:
94-
param.min = min
95-
if max is not None:
96-
param.max = max
97-
98-
def fix_parameter(self, parameter_name: str):
99-
"""
100-
Fix a specific parameter by name.
101-
"""
102-
param = self.get_parameter(parameter_name)
103-
param.fixed = True
104-
105-
def free_parameter(self, parameter_name: str):
106-
"""
107-
Free a specific parameter by name.
108-
"""
109-
param = self.get_parameter(parameter_name)
110-
param.fixed = False
111-
11268
@abstractmethod
11369
def evaluate(self, x: Union[Numeric, sc.Variable]) -> np.ndarray:
11470
"""
@@ -135,7 +91,7 @@ def get_parameters(self) -> List[Parameter]:
13591
pass
13692

13793
@abstractmethod
138-
def copy(self) -> "ModelComponent":
94+
def copy(self) -> ModelComponent:
13995
"""
14096
Return a deep copy of this component with independent parameters.
14197
"""

src/easydynamics/sample_model/components/polynomial.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Union
24

35
import numpy as np
@@ -82,7 +84,7 @@ def get_parameters(self):
8284
"""
8385
return self.coefficients
8486

85-
def copy(self) -> "Polynomial":
87+
def copy(self) -> Polynomial:
8688
"""
8789
Return a deep copy of this component with independent parameters.
8890
"""

src/easydynamics/sample_model/components/voigt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from scipy.special import voigt_profile
24

35
from typing import Union
@@ -160,7 +162,7 @@ def get_parameters(self):
160162
"""
161163
return [self.area, self.center, self.gaussian_width, self.lorentzian_width]
162164

163-
def copy(self) -> "Voigt":
165+
def copy(self) -> Voigt:
164166
model_copy = Voigt(
165167
name=self.name,
166168
area=self.area.value,

tests/unit_tests/sample_model/test_components.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def test_fix_all_parameters_sets_all_to_fixed(self, dummy):
4646
# THEN EXPECT
4747
assert all(p.fixed for p in dummy.get_parameters())
4848

49-
def test_fit_all_parameters_sets_all_to_unfixed(self, dummy):
49+
def test_free_all_parameters_sets_all_to_unfixed(self, dummy):
5050
# WHEN
51-
dummy.fit_all_parameters()
51+
dummy.free_all_parameters()
5252

5353
# THEN EXPECT
5454
assert all(not p.fixed for p in dummy.get_parameters())

0 commit comments

Comments
 (0)