Skip to content

Commit af9b7d9

Browse files
More tests
1 parent d771784 commit af9b7d9

File tree

3 files changed

+421
-109
lines changed

3 files changed

+421
-109
lines changed

examples/component_example.ipynb

Lines changed: 6 additions & 65 deletions
Large diffs are not rendered by default.

src/easydynamics/sample/components.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def copy(self) -> "Gaussian":
286286
return model_copy
287287

288288
def __repr__(self):
289-
return f"Gaussian(name={self.name}, area={self.area}, center={self.center}, width={self.width})"
289+
return f"Gaussian(name = {self.name}, unit = {self.unit},\n area = {self.area},\n center = {self.center},\n width = {self.width})"
290290

291291

292292
class Lorentzian(ModelComponent):
@@ -322,6 +322,9 @@ def __init__(
322322
raise ValueError("The width of a Lorentzian must be greater than zero.")
323323
width = float(width)
324324

325+
if not isinstance(unit, str):
326+
raise TypeError("unit must be a string.")
327+
325328
if isinstance(area, Numeric):
326329
if area < 0:
327330
warnings.warn(
@@ -361,7 +364,7 @@ def __init__(
361364

362365
def evaluate(self, x: Union[Numeric, sc.Variable]) -> Union[float, np.ndarray]:
363366
if self.width.value <= 0:
364-
raise ValueError("Width must be greater than 0 for Lorentzian.")
367+
raise ValueError("Width must be greater than zero for Lorentzian.")
365368
if self.area.value < 0:
366369
warnings.warn(
367370
"The area of the Lorentzian with name {} is negative, which may not be physically meaningful.".format(
@@ -420,7 +423,7 @@ def copy(self) -> "Lorentzian":
420423
return model_copy
421424

422425
def __repr__(self):
423-
return f"Lorentzian(name={self.name}, area={self.area}, center={self.center}, width={self.width})"
426+
return f"Lorentzian(name = {self.name}, unit = {self.unit},\n area = {self.area},\n center = {self.center},\n width = {self.width})"
424427

425428

426429
class Voigt(ModelComponent):
@@ -456,17 +459,20 @@ def __init__(
456459
if not isinstance(lorentzian_width, (Numeric, Parameter)):
457460
raise TypeError("lorentzian_width must be a number or a Parameter.")
458461

462+
if not isinstance(unit, str):
463+
raise TypeError("unit must be a string.")
464+
459465
if isinstance(gaussian_width, Numeric):
460466
if gaussian_width <= 0:
461467
raise ValueError(
462-
"gaussian_width must be greater than 0 for Voigt profile."
468+
"The gaussian_width of a Voigt must be greater than zero."
463469
)
464470
gaussian_width = float(gaussian_width)
465471

466472
if isinstance(lorentzian_width, Numeric):
467473
if lorentzian_width <= 0:
468474
raise ValueError(
469-
"lorentzian_width must be greater than 0 for Voigt profile."
475+
"The lorentzian_width of a Voigt must be greater than zero."
470476
)
471477
lorentzian_width = float(lorentzian_width)
472478

@@ -516,10 +522,12 @@ def __init__(
516522

517523
def evaluate(self, x: Union[Numeric, sc.Variable]) -> Union[float, np.ndarray]:
518524
if self.gaussian_width.value <= 0:
519-
raise ValueError("gaussian_width must be greater than 0 for Voigt profile.")
525+
raise ValueError(
526+
"gaussian_width must be greater than zero for Voigt profile."
527+
)
520528
if self.lorentzian_width.value <= 0:
521529
raise ValueError(
522-
"lorentzian_width must be greater than 0 for Voigt profile."
530+
"lorentzian_width must be greater than zero for Voigt profile."
523531
)
524532
if self.area.value < 0:
525533
warnings.warn(
@@ -582,7 +590,7 @@ def copy(self) -> "Voigt":
582590
return model_copy
583591

584592
def __repr__(self):
585-
return f"Voigt(name={self.name}, area={self.area}, center={self.center}, gaussian_width={self.gaussian_width}, lorentzian_width={self.lorentzian_width})"
593+
return f"Voigt(name = {self.name}, unit = {self.unit},\n area = {self.area},\n center = {self.center},\n gaussian_width = {self.gaussian_width},\n lorentzian_width = {self.lorentzian_width})"
586594

587595

588596
class DeltaFunction(ModelComponent):
@@ -608,6 +616,9 @@ def __init__(
608616
if center is not None and not isinstance(center, (Numeric, Parameter)):
609617
raise TypeError("center must be None, a number or a Parameter.")
610618

619+
if not isinstance(unit, str):
620+
raise TypeError("unit must be a string.")
621+
611622
if isinstance(area, Numeric):
612623
if area < 0:
613624
warnings.warn(
@@ -633,7 +644,7 @@ def __init__(
633644
self.center = center
634645

635646
if isinstance(area, Numeric):
636-
self.area = Parameter(name=name + " area", value=area, unit=unit, min=0.0)
647+
self.area = Parameter(name=name + " area", value=area, unit=unit)
637648
else:
638649
self.area = area
639650

@@ -681,9 +692,7 @@ def copy(self) -> "DeltaFunction":
681692
return model_copy
682693

683694
def __repr__(self):
684-
return (
685-
f"DeltaFunction(name={self.name}, area={self.area}, center={self.center})"
686-
)
695+
return f"DeltaFunction(name = {self.name}, unit = {self.unit},\n area = {self.area},\n center = {self.center}"
687696

688697

689698
class DampedHarmonicOscillator(ModelComponent):
@@ -714,6 +723,9 @@ def __init__(
714723
if not isinstance(width, (Numeric, Parameter)):
715724
raise TypeError("width must be a number or a Parameter.")
716725

726+
if not isinstance(unit, str):
727+
raise TypeError("unit must be a string.")
728+
717729
if isinstance(width, Numeric):
718730
width = float(width)
719731
if width <= 0:
@@ -756,7 +768,7 @@ def __init__(
756768
def evaluate(self, x: Union[Numeric, sc.Variable]) -> Union[float, np.ndarray]:
757769
if self.width.value <= 0:
758770
raise ValueError(
759-
"Width of a Damped Harmonic Oscillator must be greater than 0."
771+
"Width of a Damped Harmonic Oscillator must be greater than zero."
760772
)
761773
if self.area.value < 0:
762774
warnings.warn(
@@ -826,7 +838,7 @@ def copy(self) -> "DampedHarmonicOscillator":
826838
return model_copy
827839

828840
def __repr__(self):
829-
return f"DampedHarmonicOscillator(name={self.name}, area={self.area}, center={self.center}, width={self.width})"
841+
return f"DampedHarmonicOscillator(name = {self.name}, unit = {self.unit},\n area = {self.area},\n center = {self.center},\n width = {self.width})"
830842

831843

832844
class Polynomial(ModelComponent):
@@ -847,6 +859,12 @@ def __init__(
847859
if not isinstance(coefficients, (list, tuple, np.ndarray)):
848860
raise TypeError("coefficients must be a list, tuple or ndarray of floats.")
849861

862+
if not all(isinstance(c, Numeric) for c in coefficients):
863+
raise TypeError("All coefficients must be numbers.")
864+
865+
if not isinstance(unit, str):
866+
raise TypeError("unit must be a string.")
867+
850868
super().__init__(name=name)
851869
if not coefficients:
852870
raise ValueError("At least one coefficient must be provided.")
@@ -908,7 +926,7 @@ def __repr__(self):
908926
coeffs_str = ", ".join(
909927
f"{param.name}={param.value}" for param in self.coefficients
910928
)
911-
return f"Polynomial(name={self.name}, coefficients=[{coeffs_str}])"
929+
return f"Polynomial(name = {self.name}, unit = {self.unit},\n coefficients = [{coeffs_str}])"
912930

913931
def convert_unit(self, unit):
914932
raise NotImplementedError(

0 commit comments

Comments
 (0)