From 751744d9bb83829651ad67954b6793c887ec9d45 Mon Sep 17 00:00:00 2001 From: NicoNeureiter Date: Tue, 19 Jul 2022 11:33:51 +0200 Subject: [PATCH] updated testcases --- test/test_model.py | 31 ++++++++++++++++------------ test/test_plot.py | 40 ++++++++++++++++++------------------- test/test_state.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 33 deletions(-) create mode 100644 test/test_state.py diff --git a/test/test_model.py b/test/test_model.py index e4a80511..26177454 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -138,18 +138,19 @@ def test_minimal_example(self): p_cluster = broadcast_weights([0.0, 1.0], n_features)[np.newaxis,...] p_global = np.full(shape=(1, n_features, n_states), fill_value=0.5) source = np.zeros((n_objects, n_features, 2), dtype=bool) - sample = Sample( + sample = Sample.from_numpy_arrays( clusters=np.ones((1, n_objects), dtype=np.bool), weights=broadcast_weights([0.5, 0.5], n_features), cluster_effect=p_cluster, confounding_effects={"universal": p_global}, + confounders=confounders, source=source, ) """Comment on analytical solutions: Since weights are fixed at 0.5, the source probability is fixed at : - P( source | weights ) = 2 ^ (- n_objects * n_features) = 0.125 + P( source | weights ) = 2^(- n_objects * n_features) = 0.125 The cluster is fixed to include all languages, but we vary the `source` array. We will go through the different cases: @@ -158,9 +159,9 @@ def test_minimal_example(self): """1. no areal effect means that the likelihood is simply 50/50 for each feature.""" data_lh = 0.125 - source[..., 0] = 0 # index 0 is the area component - source[..., 1] = 1 # index 1 is the universal component (first confounder) - sample.everything_changed() + with sample.source.edit() as s: + s[..., 0] = 0 # index 0 is the area component + s[..., 1] = 1 # index 1 is the universal component (first confounder) likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False) np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh*data_lh)) @@ -168,32 +169,36 @@ def test_minimal_example(self): that this observation is perfectly explained, increasing the likelihood by a factor of 2.""" data_lh = 0.25 - source[1, :, :] = [[1, 0]] # switch object 1 to the cluster effect - sample.everything_changed() + with sample.source.edit() as s: + s[1, :, :] = [[1, 0]] # switch object 1 to the cluster effect likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False) np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh)) """3. assigning the second object to the cluster effect increases the likelihood by another factor of 2.""" data_lh = 0.5 - source[2, :, :] = [[1, 0]] # switch object 2 to the cluster effect - sample.everything_changed() + with sample.source.edit() as s: + s[2, :, :] = [[1, 0]] # switch object 2 to the cluster effect likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False) np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh)) """4. assigning the first object to the cluster effect results in a likelihood of zero, i.e. a log-likelihood of -inf.""" data_lh = 0.0 - source[0, :, :] = [[1, 0]] # switch object 1 to the cluster effect + with sample.source.edit() as s: + s[0, :, :] = [[1, 0]] # switch object 1 to the cluster effect sample.everything_changed() likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False) - np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh)) + with self.assertWarns(RuntimeWarning): + # There should be a warning when taking the log of 0 (resulting in -inf) + np.testing.assert_almost_equal(likelihood_sbayes, np.log(source_lh * data_lh)) + np.testing.assert_almost_equal(likelihood_sbayes, -np.inf) """5. Integrating over source (setting it to None) averages the component likelihoods for each observation.""" lh = 0.25 * 0.75 * 0.75 - sample.source = None - sample.everything_changed() + sample.source._value = None + sample.source.version += 1 likelihood_sbayes = Likelihood(data=data, shapes=shapes)(sample, caching=False) np.testing.assert_almost_equal(likelihood_sbayes, np.log(lh)) diff --git a/test/test_plot.py b/test/test_plot.py index dac5cba1..dff9de45 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -1,20 +1,20 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -import numpy as np -import unittest - -from sbayes.plot import Plot, main - - -class TestPlot(unittest.TestCase): - - """ - Test cases of plotting functions in ´sbayes/plot.py´. - """ - - @staticmethod - def test_example(): - main( - config='test/plot_test_files/config_plot.json', - plot_types=['map'], - ) +# #!/usr/bin/env python3 +# # -*- coding: utf-8 -*- +# import numpy as np +# import unittest +# +# from sbayes.plot import Plot, main +# +# +# class TestPlot(unittest.TestCase): +# +# """ +# Test cases of plotting functions in ´sbayes/plot.py´. +# """ +# +# @staticmethod +# def test_example(): +# main( +# config='test/plot_test_files/config_plot.json', +# plot_types=['map'], +# ) diff --git a/test/test_state.py b/test/test_state.py new file mode 100644 index 00000000..05ea4555 --- /dev/null +++ b/test/test_state.py @@ -0,0 +1,50 @@ +import numpy as np +import unittest + +from sbayes.sampling.state import ArrayParameter, CalculationNode, GroupedParameters + +class TestArrayParameter(unittest.TestCase): + + N_GROUPS = 3 + N_ITEMS = 4 + + def setUp(self) -> None: + arr = np.arange(12).reshape((self.N_GROUPS, self.N_ITEMS)) + self.param = GroupedParameters(arr) + self.calc = CalculationNode(np.empty((self.N_GROUPS, self.N_ITEMS))) + + def test_initial_state(self): + self.assertEqual(self.param.value[1, 2], 6) + self.assertEqual(self.param.version, 0) + + def test_set_items(self): + # Change parameter using set_items + self.param.set_items((1, 2), 1000) + + # Validate changes in value and version number + self.assertEqual(self.param.value[1,2], 1000) + self.assertEqual(self.param.version, 1) + + def test_edit(self): + # Change parameter using the .edit() context manager + with self.param.edit() as value: + value[1, 2] = 1000 + + # Validate changes in value and version number + self.assertEqual(self.param.value[1,2], 1000) + self.assertEqual(self.param.version, 1) + + def test_set_value(self): + # Change parameter using the .edit() context manager + new_value = self.param.value.copy() + new_value[1, 2] = 1000 + self.param.set_value(new_value) + + # Validate changes in value and version number + self.assertEqual(self.param.value[1,2], 1000) + self.assertEqual(self.param.version, 1) + + + +if __name__ == '__main__': + unittest.main()