Skip to content

Commit

Permalink
fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
martinvonk committed Oct 4, 2023
1 parent e6cdfd0 commit 9d2add9
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/pedon/soilmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import matplotlib.pyplot as plt
from numpy import abs as npabs
from numpy import exp, full, linspace, log, logspace, log10, ones
from numpy import exp, full, linspace, log, logspace, log10, nan

from ._typing import FloatArray

Expand Down Expand Up @@ -129,12 +129,20 @@ def k(self, h: FloatArray, s: FloatArray | None = None) -> FloatArray:
return self.k_s * self.k_r(h=h, s=s)

def h(self, theta: FloatArray) -> FloatArray:
h = full(theta.shape, self.h_b)
mask = theta >= self.theta_r
h[mask] = self.h_b * ((theta[mask] - self.theta_r) / (self.s(theta[mask]))) ** (
-1 / self.l
)
return h
if isinstance(theta, float):
if theta >= self.theta_r:
return self.h_b * ((theta - self.theta_r) / (self.s(theta))) ** (
-1 / self.l
)
else:
return self.h_b
else:
h = full(theta.shape, self.h_b)
mask = theta >= self.theta_r
h[mask] = self.h_b * (
(theta[mask] - self.theta_r) / (self.s(theta[mask]))
) ** (-1 / self.l)
return h

def plot(self, ax: plt.Axes | None = None) -> plt.Axes:
return plot_swrc(self, ax=ax)
Expand Down Expand Up @@ -170,7 +178,7 @@ def k(self, h: FloatArray, s: FloatArray | None = None) -> FloatArray:
return self.k_s * self.k_r(h=h, s=s)

def h(self, theta: FloatArray) -> FloatArray:
return (abs(theta) / self.a) ** (-1 / self.b)
return (npabs(theta) / self.a) ** (-1 / self.b)

def plot(self, ax: plt.Axes | None = None) -> plt.Axes:
return plot_swrc(self, ax=ax)
Expand Down Expand Up @@ -289,7 +297,7 @@ def k(self, h: FloatArray, s: FloatArray | None = None) -> FloatArray:
return self.k_s * self.k_r(h=h, s=s)

def h(self, theta: FloatArray) -> FloatArray:
return self.a * (exp ** (self.theta_s / theta) ** (1 / self.m) - exp(1)) ** (
return self.a * (exp(self.theta_s / theta) ** (1 / self.m) - exp(1)) ** (
1 / self.n
)

Expand All @@ -309,10 +317,7 @@ def get_soilmodel(soilmodel_name: str) -> Type[SoilModel]:


def plot_swrc(
sm: SoilModel,
saturation: bool = False,
ax: plt.Axes | None = None,
**kwargs: dict,
sm: SoilModel, saturation: bool = False, ax: plt.Axes | None = None, **kwargs
) -> plt.Axes:
"""Plot soil water retention curve"""

Expand Down Expand Up @@ -342,7 +347,7 @@ def plot_swrc(
def plot_hcf(
sm: SoilModel,
ax: plt.Axes | None = None,
**kwargs: dict,
**kwargs,
) -> plt.Axes:
"""Plot the hydraulic conductivity function"""

Expand Down

0 comments on commit 9d2add9

Please sign in to comment.