Skip to content

Commit

Permalink
Add ewm_parameters (#726)
Browse files Browse the repository at this point in the history
  • Loading branch information
dycw authored Sep 11, 2024
1 parent 94dc5c2 commit 93c390f
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ python-dateutil==2.9.0.post0
# via pandas
python-dotenv==1.0.1
# via dycw-utilities (pyproject.toml)
pytz==2024.1
pytz==2024.2
# via pandas
pyyaml==6.0.2
# via optuna
Expand Down
2 changes: 1 addition & 1 deletion requirements/jupyter.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ python-dateutil==2.9.0.post0
# pandas
python-json-logger==2.0.7
# via jupyter-events
pytz==2024.1
pytz==2024.2
# via pandas
pyyaml==6.0.2
# via jupyter-events
Expand Down
2 changes: 1 addition & 1 deletion requirements/streamlit.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pytest-rerunfailures==14.0
# via dycw-utilities (pyproject.toml)
python-dateutil==2.9.0.post0
# via pandas
pytz==2024.1
pytz==2024.2
# via pandas
referencing==0.35.1
# via
Expand Down
67 changes: 67 additions & 0 deletions src/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@
MIN_UINT64,
CheckIntegerError,
NumberOfDecimalsError,
_EWMParameters,
_EWMParametersAlphaError,
_EWMParametersArgumentsError,
_EWMParametersCOMError,
_EWMParametersHalfLifeError,
_EWMParametersSpanError,
check_integer,
ewm_parameters,
is_at_least,
is_at_least_or_nan,
is_at_most,
Expand Down Expand Up @@ -123,6 +130,66 @@ def test_max_error(self) -> None:
check_integer(1, max=0)


class TestEWMParameters:
expected: ClassVar[_EWMParameters] = _EWMParameters(
com=1.0, span=3.0, half_life=1.0, alpha=0.5
)

def test_com(self) -> None:
result = ewm_parameters(com=1.0)
assert result == self.expected

def test_span(self) -> None:
result = ewm_parameters(span=3.0)
assert result == self.expected

def test_half_life(self) -> None:
result = ewm_parameters(half_life=1.0)
assert result == self.expected

def test_alpha(self) -> None:
result = ewm_parameters(alpha=0.5)
assert result == self.expected

def test_error_com(self) -> None:
with raises(
_EWMParametersCOMError,
match=escape(r"Center of mass (γ) must be positive; got 0.0"), # noqa: RUF001
):
_ = ewm_parameters(com=0.0)

def test_error_span(self) -> None:
with raises(
_EWMParametersSpanError,
match=escape("Span (θ) must be greater than 1; got 1.0"),
):
_ = ewm_parameters(span=1.0)

def test_error_half_life(self) -> None:
with raises(
_EWMParametersHalfLifeError,
match=escape("Half-life (λ) must be positive; got 0.0"),
):
_ = ewm_parameters(half_life=0.0)

@mark.parametrize("alpha", [param(0.0), param(1.0)], ids=str)
def test_error_alpha(self, *, alpha: float) -> None:
with raises(
_EWMParametersAlphaError,
match=r"Smoothing factor \(α\) must be between 0 and 1 \(exclusive\); got \d\.0", # noqa: RUF001
):
_ = ewm_parameters(alpha=alpha)

def test_error_arguments(self) -> None:
with raises(
_EWMParametersArgumentsError,
match=escape(
r"Exactly one of center of mass (γ), span (θ), half-life (λ) and smoothing factor (α) must be given; got γ=None, θ=None, λ=None and α=None" # noqa: RUF001
),
):
_ = ewm_parameters()


class TestIsAtLeast:
@mark.parametrize(
("x", "y", "expected"),
Expand Down
2 changes: 1 addition & 1 deletion src/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

__version__ = "0.53.0"
__version__ = "0.53.1"
118 changes: 117 additions & 1 deletion src/utilities/math.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from math import isclose, isfinite, isnan, log10
from math import exp, isclose, isfinite, isnan, log, log10
from typing import Literal, overload

from typing_extensions import override
Expand All @@ -20,6 +20,120 @@
# functions


@dataclass(frozen=True, kw_only=True)
class _EWMParameters:
"""A set of EWM parameters."""

com: float
span: float
half_life: float
alpha: float


def ewm_parameters(
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
) -> _EWMParameters:
"""Compute a set of EWM parameters."""
if (com is not None) and (span is None) and (half_life is None) and (alpha is None):
if com <= 0:
raise _EWMParametersCOMError(com=com)
alpha = 1 / (1 + com)
return _EWMParameters(
com=com,
span=_ewm_parameters_alpha_to_span(alpha),
half_life=_ewm_parameters_alpha_to_half_life(alpha),
alpha=alpha,
)
if (com is None) and (span is not None) and (half_life is None) and (alpha is None):
if span <= 1:
raise _EWMParametersSpanError(span=span)
alpha = 2 / (span + 1)
return _EWMParameters(
com=_ewm_parameters_alpha_to_com(alpha),
span=span,
half_life=_ewm_parameters_alpha_to_half_life(alpha),
alpha=alpha,
)
if (com is None) and (span is None) and (half_life is not None) and (alpha is None):
if half_life <= 0:
raise _EWMParametersHalfLifeError(half_life=half_life)
alpha = 1 - exp(-log(2) / half_life)
return _EWMParameters(
com=_ewm_parameters_alpha_to_com(alpha),
span=_ewm_parameters_alpha_to_span(alpha),
half_life=half_life,
alpha=alpha,
)
if (com is None) and (span is None) and (half_life is None) and (alpha is not None):
if not (0 < alpha < 1):
raise _EWMParametersAlphaError(alpha=alpha)
return _EWMParameters(
com=_ewm_parameters_alpha_to_com(alpha),
span=_ewm_parameters_alpha_to_span(alpha),
half_life=_ewm_parameters_alpha_to_half_life(alpha),
alpha=alpha,
)
raise _EWMParametersArgumentsError(
com=com, span=span, half_life=half_life, alpha=alpha
)


@dataclass(kw_only=True)
class EWMParametersError(Exception):
com: float | None = None
span: float | None = None
half_life: float | None = None
alpha: float | None = None


@dataclass(kw_only=True)
class _EWMParametersCOMError(EWMParametersError):
@override
def __str__(self) -> str:
return f"Center of mass (γ) must be positive; got {self.com}" # noqa: RUF001


@dataclass(kw_only=True)
class _EWMParametersSpanError(EWMParametersError):
@override
def __str__(self) -> str:
return f"Span (θ) must be greater than 1; got {self.span}"


class _EWMParametersHalfLifeError(EWMParametersError):
@override
def __str__(self) -> str:
return f"Half-life (λ) must be positive; got {self.half_life}"


class _EWMParametersAlphaError(EWMParametersError):
@override
def __str__(self) -> str:
return f"Smoothing factor (α) must be between 0 and 1 (exclusive); got {self.alpha}" # noqa: RUF001


class _EWMParametersArgumentsError(EWMParametersError):
@override
def __str__(self) -> str:
return f"Exactly one of center of mass (γ), span (θ), half-life (λ) and smoothing factor (α) must be given; got γ={self.com}, θ={self.span}, λ={self.half_life} and α={self.alpha}" # noqa: RUF001


def _ewm_parameters_alpha_to_com(alpha: float, /) -> float:
return 1 / alpha - 1


def _ewm_parameters_alpha_to_span(alpha: float, /) -> float:
return 2 / alpha - 1


def _ewm_parameters_alpha_to_half_life(alpha: float, /) -> float:
return -log(2) / log(1 - alpha)


def is_equal(x: float, y: float, /) -> bool:
"""Check if x == y."""
return (x == y) or (isnan(x) and isnan(y))
Expand Down Expand Up @@ -488,7 +602,9 @@ def __str__(self) -> str:
"MIN_UINT32",
"MIN_UINT64",
"CheckIntegerError",
"EWMParametersError",
"check_integer",
"ewm_parameters",
"is_at_least",
"is_at_least_or_nan",
"is_at_most",
Expand Down

0 comments on commit 93c390f

Please sign in to comment.