Skip to content

Commit

Permalink
Add params_dict attribute (#521)
Browse files Browse the repository at this point in the history
* add params_dict attribute

* change returns
  • Loading branch information
aloctavodia authored Aug 15, 2024
1 parent c5c58b6 commit b39b18a
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 9 deletions.
7 changes: 7 additions & 0 deletions preliz/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def __repr__(self):
else:
return name

@property
def params_dict(self):
if self.is_frozen:
return dict(zip(self.param_names, self.params))
else:
return None

def summary(self, mass=0.94, interval="hdi", fmt=".2f"):
"""
Namedtuple with the mean, median, standard deviation, and lower and upper bounds
Expand Down
6 changes: 3 additions & 3 deletions preliz/tests/test_maxent.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
],
)
def test_maxent(dist, lower, upper, mass, support, result):
_, opt = maxent(dist, lower, upper, mass)
maxent(dist, lower, upper, mass)

assert_almost_equal(dist.support, support, 0)

Expand All @@ -176,8 +176,8 @@ def test_maxent(dist, lower, upper, mass, support, result):
"HyperGeometric",
"ZeroInflatedBinomial",
]: # optimization fails to converge, but results are reasonable
assert opt.success
assert_allclose(opt.x, result, atol=0.001)
assert dist.opt.success
assert_allclose(dist.opt.x, result, atol=0.001)


def test_maxent_plot():
Expand Down
4 changes: 2 additions & 2 deletions preliz/tests/test_quartile.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,6 @@
],
)
def test_quartile(distribution, q1, q2, q3, result):
_, opt = quartile(distribution, q1, q2, q3)
quartile(distribution, q1, q2, q3)

assert_allclose(opt.x, result, atol=0.01)
assert_allclose(distribution.opt.x, result, atol=0.01)
8 changes: 6 additions & 2 deletions preliz/unidimensional/maxent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def maxent(
Returns
-------
axes: matplotlib axes
dict: dict with the parameters of the distribution
axes: matplotlib axes (only if `plot=True`)
See Also
--------
Expand Down Expand Up @@ -110,6 +111,7 @@ def maxent(
)

opt = optimize_max_ent(distribution, lower, upper, mass, none_idx, fixed)
distribution.opt = opt

r_error, computed_mass = relative_error(distribution, lower, upper, mass)

Expand All @@ -127,7 +129,9 @@ def maxent(
else:
cid = -1
ax.plot([lower, upper], [0, 0], "o", color=ax.get_lines()[cid].get_c(), alpha=0.5)
return ax, opt
return distribution, ax

return distribution


def end_points_ints(lower, upper):
Expand Down
9 changes: 7 additions & 2 deletions preliz/unidimensional/quartile.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def quartile(
Returns
-------
axes: matplotlib axes
dict: dict with the parameters of the distribution
axes: matplotlib axes (only if `plot=True`)
See Also
--------
Expand Down Expand Up @@ -98,6 +99,8 @@ def quartile(

opt = optimize_quartile(distribution, quartiles, none_idx, fixed)

distribution.opt = opt

r_error, _ = relative_error(distribution, q1, q3, 0.5)

if r_error > 0.01:
Expand All @@ -113,4 +116,6 @@ def quartile(
else:
cid = -1
ax.plot(quartiles, [0, 0, 0], "o", color=ax.get_lines()[cid].get_c(), alpha=0.5)
return ax, opt
return distribution, ax

return distribution

0 comments on commit b39b18a

Please sign in to comment.