diff --git a/src/pedon/soilmodel.py b/src/pedon/soilmodel.py index 2f11621..4be12f4 100644 --- a/src/pedon/soilmodel.py +++ b/src/pedon/soilmodel.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt from numpy import abs as npabs -from numpy import exp, full, linspace, log, logspace, log10, ones +from numpy import exp, full, linspace, log, logspace, log10, nan from ._typing import FloatArray @@ -129,12 +129,20 @@ def k(self, h: FloatArray, s: FloatArray | None = None) -> FloatArray: return self.k_s * self.k_r(h=h, s=s) def h(self, theta: FloatArray) -> FloatArray: - h = full(theta.shape, self.h_b) - mask = theta >= self.theta_r - h[mask] = self.h_b * ((theta[mask] - self.theta_r) / (self.s(theta[mask]))) ** ( - -1 / self.l - ) - return h + if isinstance(theta, float): + if theta >= self.theta_r: + return self.h_b * ((theta - self.theta_r) / (self.s(theta))) ** ( + -1 / self.l + ) + else: + return self.h_b + else: + h = full(theta.shape, self.h_b) + mask = theta >= self.theta_r + h[mask] = self.h_b * ( + (theta[mask] - self.theta_r) / (self.s(theta[mask])) + ) ** (-1 / self.l) + return h def plot(self, ax: plt.Axes | None = None) -> plt.Axes: return plot_swrc(self, ax=ax) @@ -170,7 +178,7 @@ def k(self, h: FloatArray, s: FloatArray | None = None) -> FloatArray: return self.k_s * self.k_r(h=h, s=s) def h(self, theta: FloatArray) -> FloatArray: - return (abs(theta) / self.a) ** (-1 / self.b) + return (npabs(theta) / self.a) ** (-1 / self.b) def plot(self, ax: plt.Axes | None = None) -> plt.Axes: return plot_swrc(self, ax=ax) @@ -289,7 +297,7 @@ def k(self, h: FloatArray, s: FloatArray | None = None) -> FloatArray: return self.k_s * self.k_r(h=h, s=s) def h(self, theta: FloatArray) -> FloatArray: - return self.a * (exp ** (self.theta_s / theta) ** (1 / self.m) - exp(1)) ** ( + return self.a * (exp(self.theta_s / theta) ** (1 / self.m) - exp(1)) ** ( 1 / self.n ) @@ -309,10 +317,7 @@ def get_soilmodel(soilmodel_name: str) -> Type[SoilModel]: def plot_swrc( - sm: SoilModel, - saturation: bool = False, - ax: plt.Axes | None = None, - **kwargs: dict, + sm: SoilModel, saturation: bool = False, ax: plt.Axes | None = None, **kwargs ) -> plt.Axes: """Plot soil water retention curve""" @@ -342,7 +347,7 @@ def plot_swrc( def plot_hcf( sm: SoilModel, ax: plt.Axes | None = None, - **kwargs: dict, + **kwargs, ) -> plt.Axes: """Plot the hydraulic conductivity function"""