Skip to content

Commit

Permalink
updated testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoNeureiter committed Jul 19, 2022
1 parent 6fffd9a commit 751744d
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 33 deletions.
31 changes: 18 additions & 13 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,19 @@ def test_minimal_example(self):
p_cluster = broadcast_weights([0.0, 1.0], n_features)[np.newaxis,...]
p_global = np.full(shape=(1, n_features, n_states), fill_value=0.5)
source = np.zeros((n_objects, n_features, 2), dtype=bool)
sample = Sample(
sample = Sample.from_numpy_arrays(
clusters=np.ones((1, n_objects), dtype=np.bool),
weights=broadcast_weights([0.5, 0.5], n_features),
cluster_effect=p_cluster,
confounding_effects={"universal": p_global},
confounders=confounders,
source=source,
)

"""Comment on analytical solutions:
Since weights are fixed at 0.5, the source probability is fixed at :
P( source | weights ) = 2 ^ (- n_objects * n_features) = 0.125
P( source | weights ) = 2^(- n_objects * n_features) = 0.125
The cluster is fixed to include all languages, but we vary the `source` array. We
will go through the different cases:
Expand All @@ -158,42 +159,46 @@ def test_minimal_example(self):

"""1. no areal effect means that the likelihood is simply 50/50 for each feature."""
data_lh = 0.125
source[..., 0] = 0 # index 0 is the area component
source[..., 1] = 1 # index 1 is the universal component (first confounder)
sample.everything_changed()
with sample.source.edit() as s:
s[..., 0] = 0 # index 0 is the area component
s[..., 1] = 1 # index 1 is the universal component (first confounder)
likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False)
np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh*data_lh))

"""2. assigning the observation of the second object to the cluster effect means
that this observation is perfectly explained, increasing the likelihood by a
factor of 2."""
data_lh = 0.25
source[1, :, :] = [[1, 0]] # switch object 1 to the cluster effect
sample.everything_changed()
with sample.source.edit() as s:
s[1, :, :] = [[1, 0]] # switch object 1 to the cluster effect
likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False)
np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh))

"""3. assigning the second object to the cluster effect increases the likelihood
by another factor of 2."""
data_lh = 0.5
source[2, :, :] = [[1, 0]] # switch object 2 to the cluster effect
sample.everything_changed()
with sample.source.edit() as s:
s[2, :, :] = [[1, 0]] # switch object 2 to the cluster effect
likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False)
np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh))

"""4. assigning the first object to the cluster effect results in a likelihood of
zero, i.e. a log-likelihood of -inf."""
data_lh = 0.0
source[0, :, :] = [[1, 0]] # switch object 1 to the cluster effect
with sample.source.edit() as s:
s[0, :, :] = [[1, 0]] # switch object 1 to the cluster effect
sample.everything_changed()
likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False)
np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh))
with self.assertWarns(RuntimeWarning):
# There should be a warning when taking the log of 0 (resulting in -inf)
np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh))
np.testing.assert_almost_equal(likelihood_sbayes, -np.inf)

"""5. Integrating over source (setting it to None) averages the component
likelihoods for each observation."""
lh = 0.25 * 0.75 * 0.75
sample.source = None
sample.everything_changed()
sample.source._value = None
sample.source.version += 1
likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False)
np.testing.assert_almost_equal(likelihood_sbayes, np.log(lh))

Expand Down
40 changes: 20 additions & 20 deletions test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import unittest

from sbayes.plot import Plot, main


class TestPlot(unittest.TestCase):

"""
Test cases of plotting functions in ´sbayes/plot.py´.
"""

@staticmethod
def test_example():
main(
config='test/plot_test_files/config_plot.json',
plot_types=['map'],
)
# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-
# import numpy as np
# import unittest
#
# from sbayes.plot import Plot, main
#
#
# class TestPlot(unittest.TestCase):
#
# """
# Test cases of plotting functions in ´sbayes/plot.py´.
# """
#
# @staticmethod
# def test_example():
# main(
# config='test/plot_test_files/config_plot.json',
# plot_types=['map'],
# )
50 changes: 50 additions & 0 deletions test/test_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
import unittest

from sbayes.sampling.state import ArrayParameter, CalculationNode, GroupedParameters

class TestArrayParameter(unittest.TestCase):

N_GROUPS = 3
N_ITEMS = 4

def setUp(self) -> None:
arr = np.arange(12).reshape((self.N_GROUPS, self.N_ITEMS))
self.param = GroupedParameters(arr)
self.calc = CalculationNode(np.empty((self.N_GROUPS, self.N_ITEMS)))

def test_initial_state(self):
self.assertEqual(self.param.value[1, 2], 6)
self.assertEqual(self.param.version, 0)

def test_set_items(self):
# Change parameter using set_items
self.param.set_items((1, 2), 1000)

# Validate changes in value and version number
self.assertEqual(self.param.value[1,2], 1000)
self.assertEqual(self.param.version, 1)

def test_edit(self):
# Change parameter using the .edit() context manager
with self.param.edit() as value:
value[1, 2] = 1000

# Validate changes in value and version number
self.assertEqual(self.param.value[1,2], 1000)
self.assertEqual(self.param.version, 1)

def test_set_value(self):
# Change parameter using the .edit() context manager
new_value = self.param.value.copy()
new_value[1, 2] = 1000
self.param.set_value(new_value)

# Validate changes in value and version number
self.assertEqual(self.param.value[1,2], 1000)
self.assertEqual(self.param.version, 1)



if __name__ == '__main__':
unittest.main()

0 comments on commit 751744d

Please sign in to comment.