Skip to content

Commit

Permalink
Implemented normalisation for weights with sum larger than one. Also … (
Browse files Browse the repository at this point in the history
#185)

* Implemented normalisation for weights with sum larger than one. Also added a check for negative weights

* add unit test

* update test with new syntax

* resolve conflict
  • Loading branch information
BStoelzner authored Oct 31, 2023
1 parent a38c210 commit 8a7f69e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/qp/mixmod_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, means, stds, weights, *args, **kwargs):
stds: array_like
The standard deviations of the Gaussians
weights : array_like
The weights to attach to the Gaussians
The weights to attach to the Gaussians. Weights should sum up to one. If not, the weights are interpreted as relative weights.
"""
self._scipy_version_warning()
self._means = reshape_to_pdf_size(means, -1)
Expand All @@ -58,6 +58,9 @@ def __init__(self, means, stds, weights, *args, **kwargs):
kwargs['shape'] = means.shape[:-1]
self._ncomps = means.shape[-1]
super().__init__(*args, **kwargs)
if np.any(self._weights<0):
raise ValueError('All weights need to be larger than zero')
self._weights = self._weights/self._weights.sum(axis=1)[:,None]
self._addobjdata('weights', self._weights)
self._addobjdata('stds', self._stds)
self._addobjdata('means', self._means)
Expand Down
8 changes: 8 additions & 0 deletions tests/qp/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,13 @@ def test_iterator(self):
test_vals = ens_i.pdf(test_grid)
assert np.allclose(check_vals, test_vals)

def test_mixmod_with_negative_weights(self):
"""Verify that an exception is raised when setting up a mixture model with negative weights"""
means = np.array([0.5,1.1, 2.9])
sigmas = np.array([0.15,0.13,0.14])
weights = np.array([1,0.5,-0.25])
with self.assertRaises(ValueError):
_ = qp.mixmod(weights=weights, means=means, stds=sigmas)

if __name__ == '__main__':
unittest.main()

0 comments on commit 8a7f69e

Please sign in to comment.