Skip to content

Commit

Permalink
correct shape (log)pmf
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Jan 26, 2024
1 parent bb413ae commit 6c61716
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions preliz/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ def cdf(self, x, *args, **kwds):

def pmf(self, x, *args, **kwds):
if psi_not_valid(self.psi):
return np.nan
return np.full(len(x), np.nan)
x = np.array(x, ndmin=1)
result = np.zeros_like(x, dtype=float)
result[x == 0] = (1 - self.psi) + self.psi * (1 - self.p) ** self.n
Expand All @@ -1262,7 +1262,7 @@ def pmf(self, x, *args, **kwds):

def logpmf(self, x, *args, **kwds):
if psi_not_valid(self.psi):
return np.nan
return np.full(len(x), np.nan)
result = np.zeros_like(x, dtype=float)
result[x == 0] = np.log((1 - self.psi) + self.psi * (1 - self.p) ** self.n)
result[x != 0] = np.log(self.psi) + stats.binom(self.n, self.p, *args, **kwds).logpmf(
Expand Down Expand Up @@ -1321,7 +1321,7 @@ def cdf(self, x, *args, **kwds):

def pmf(self, x, *args, **kwds):
if psi_not_valid(self.psi):
return np.nan
return np.full(len(x), np.nan)
x = np.array(x, ndmin=1)
result = np.zeros_like(x, dtype=float)
result[x == 0] = (1 - self.psi) + self.psi * (self.n / (self.n + self.mu)) ** self.n
Expand All @@ -1330,7 +1330,7 @@ def pmf(self, x, *args, **kwds):

def logpmf(self, x, *args, **kwds):
if psi_not_valid(self.psi):
return np.nan
return np.full(len(x), np.nan)
result = np.zeros_like(x, dtype=float)
result[x == 0] = np.log((1 - self.psi) + self.psi * (self.n / (self.n + self.mu)) ** self.n)
result[x != 0] = np.log(self.psi) + stats.nbinom(self.n, self.p, *args, **kwds).logpmf(
Expand Down Expand Up @@ -1387,7 +1387,7 @@ def cdf(self, x, *args, **kwds):

def pmf(self, x, *args, **kwds):
if psi_not_valid(self.psi):
return np.nan
return np.full(len(x), np.nan)
x = np.array(x, ndmin=1)
result = np.zeros_like(x, dtype=float)
result[x == 0] = (1 - self.psi) + self.psi * np.exp(-self.mu)
Expand All @@ -1396,7 +1396,7 @@ def pmf(self, x, *args, **kwds):

def logpmf(self, x, *args, **kwds):
if psi_not_valid(self.psi):
return np.nan
return np.full(len(x), np.nan)
result = np.zeros_like(x, dtype=float)
result[x == 0] = np.log(np.exp(-self.mu) * self.psi - self.psi + 1)
result[x != 0] = np.log(self.psi) + stats.poisson(self.mu, *args, **kwds).logpmf(x[x != 0])
Expand Down

0 comments on commit 6c61716

Please sign in to comment.