Skip to content

Commit

Permalink
use moment for older pymc
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed Aug 20, 2024
1 parent 230900a commit cbe5db9
Showing 1 changed file with 48 additions and 23 deletions.
71 changes: 48 additions & 23 deletions thejoker/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,30 @@ def rng_fn(cls, rng, a, b, size):
uu = rng.uniform(size=size)
return np.exp(uu * _fac + np.log(a))

uniformlog = UniformLogRV()

class UniformLog(pm.Continuous):
rv_op = uniformlog

@classmethod
def dist(cls, a, b, **kwargs):
a = pt.as_tensor_variable(a)
b = pt.as_tensor_variable(b)
return super().dist([a, b], **kwargs)

def support_point(rv, size, a, b):
a, b = pt.broadcast_arrays(a, b)
return 0.5 * (a + b)

def logp(value, a, b):
_fac = pt.log(b) - pt.log(a)
res = -pt.as_tensor_variable(value) - pt.log(_fac)
return check_parameters(
res,
(a > 0) & (a < b),
msg="a > 0 and a < b",
)

else: # old behavior

class UniformLogRV(RandomVariable):
Expand All @@ -41,31 +65,32 @@ def rng_fn(cls, rng, a, b, size):
uu = rng.uniform(size=size)
return np.exp(uu * _fac + np.log(a))

uniformlog = UniformLogRV()

uniformlog = UniformLogRV()

class UniformLog(pm.Continuous):
rv_op = uniformlog

class UniformLog(pm.Continuous):
rv_op = uniformlog

@classmethod
def dist(cls, a, b, **kwargs):
a = pt.as_tensor_variable(a)
b = pt.as_tensor_variable(b)
return super().dist([a, b], **kwargs)

def support_point(rv, size, a, b):
a, b = pt.broadcast_arrays(a, b)
return 0.5 * (a + b)

def logp(value, a, b):
_fac = pt.log(b) - pt.log(a)
res = -pt.as_tensor_variable(value) - pt.log(_fac)
return check_parameters(
res,
(a > 0) & (a < b),
msg="a > 0 and a < b",
)
@classmethod
def dist(cls, a, b, **kwargs):
a = pt.as_tensor_variable(a)
b = pt.as_tensor_variable(b)
return super().dist([a, b], **kwargs)

def support_point(rv, size, a, b):
a, b = pt.broadcast_arrays(a, b)
return 0.5 * (a + b)

# TODO: remove this once new pymc version is released
moment = support_point

def logp(value, a, b):
_fac = pt.log(b) - pt.log(a)
res = -pt.as_tensor_variable(value) - pt.log(_fac)
return check_parameters(
res,
(a > 0) & (a < b),
msg="a > 0 and a < b",
)


class FixedCompanionMassRV(NormalRV):
Expand Down

0 comments on commit cbe5db9

Please sign in to comment.