From 4ee190beed288a47c07cb6e8c56212f1709ed354 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20Ho=CC=88chenberger?= Date: Thu, 23 May 2019 10:51:55 +0200 Subject: [PATCH] NF: Add JSON dump and load, and support "equals" operator --- questplus/qp.py | 64 +++++++++++++++++++++++++++++++++++ questplus/tests/test_qp.py | 68 ++++++++++++++++++++++++++++++++++++++ setup.cfg | 1 + 3 files changed, 133 insertions(+) diff --git a/questplus/qp.py b/questplus/qp.py index 3c9cef1..83e89d3 100644 --- a/questplus/qp.py +++ b/questplus/qp.py @@ -1,6 +1,7 @@ from typing import Optional, Sequence import xarray as xr import numpy as np +import json_tricks from copy import deepcopy from questplus import psychometric_function @@ -294,6 +295,69 @@ def param_estimate(self) -> dict: return param_estimates + def to_json(self) -> str: + self_copy = deepcopy(self) + self_copy.prior = self_copy.prior.to_dict() + self_copy.posterior = self_copy.posterior.to_dict() + self_copy.likelihoods = self_copy.likelihoods.to_dict() + return json_tricks.dumps(self_copy) + + @staticmethod + def from_json(data: str): + loaded = json_tricks.loads(data) + loaded.prior = xr.DataArray.from_dict(loaded.prior) + loaded.posterior = xr.DataArray.from_dict(loaded.posterior) + loaded.likelihoods = xr.DataArray.from_dict(loaded.likelihoods) + return loaded + + def __eq__(self, other): + if not self.likelihoods.equals(other.likelihoods): + return False + + if not self.prior.equals(other.prior): + return False + + if not self.posterior.equals(other.posterior): + return False + + for param_name in self.param_domain.keys(): + if not np.array_equal(self.param_domain[param_name], + other.param_domain[param_name]): + return False + + for stim_property in self.stim_domain.keys(): + if not np.array_equal(self.stim_domain[stim_property], + other.stim_domain[stim_property]): + return False + + for outcome_name in self.outcome_domain.keys(): + if not np.array_equal(self.outcome_domain[outcome_name], + other.outcome_domain[outcome_name]): + return False + + if self.stim_selection != other.stim_selection: + return False + + if self.stim_selection_options != other.stim_selection_options: + return False + + if self.stim_scale != other.stim_scale: + return False + + if self.stim_history != other.stim_history: + return False + + if self.resp_history != other.resp_history: + return False + + if self.param_estimation_method != other.param_estimation_method: + return False + + if self.func != other.func: + return False + + return True + class QuestPlusWeibull(QuestPlus): def __init__(self, *, diff --git a/questplus/tests/test_qp.py b/questplus/tests/test_qp.py index de7382b..567a156 100644 --- a/questplus/tests/test_qp.py +++ b/questplus/tests/test_qp.py @@ -425,6 +425,72 @@ def test_weibull(): expected_mode_threshold) +def test_eq(): + threshold = np.arange(-40, 0 + 1) + slope, guess, lapse = 3.5, 0.5, 0.02 + contrasts = threshold.copy() + + stim_domain = dict(intensity=contrasts) + param_domain = dict(threshold=threshold, slope=slope, + lower_asymptote=guess, lapse_rate=lapse) + outcome_domain = dict(response=['Correct', 'Incorrect']) + + f = 'weibull' + scale = 'dB' + stim_selection_method = 'min_entropy' + param_estimation_method = 'mode' + + q1 = QuestPlus(stim_domain=stim_domain, param_domain=param_domain, + outcome_domain=outcome_domain, func=f, stim_scale=scale, + stim_selection_method=stim_selection_method, + param_estimation_method=param_estimation_method) + + q2 = QuestPlus(stim_domain=stim_domain, param_domain=param_domain, + outcome_domain=outcome_domain, func=f, stim_scale=scale, + stim_selection_method=stim_selection_method, + param_estimation_method=param_estimation_method) + + # Add some random responses. + q1.update(stim=q1.next_stim, outcome=dict(response='Correct')) + q1.update(stim=q1.next_stim, outcome=dict(response='Incorrect')) + q2.update(stim=q2.next_stim, outcome=dict(response='Correct')) + q2.update(stim=q2.next_stim, outcome=dict(response='Incorrect')) + + assert q1 == q2 + + +def test_json(): + threshold = np.arange(-40, 0 + 1) + slope, guess, lapse = 3.5, 0.5, 0.02 + contrasts = threshold.copy() + + stim_domain = dict(intensity=contrasts) + param_domain = dict(threshold=threshold, slope=slope, + lower_asymptote=guess, lapse_rate=lapse) + outcome_domain = dict(response=['Correct', 'Incorrect']) + + f = 'weibull' + scale = 'dB' + stim_selection_method = 'min_entropy' + param_estimation_method = 'mode' + + q = QuestPlus(stim_domain=stim_domain, param_domain=param_domain, + outcome_domain=outcome_domain, func=f, stim_scale=scale, + stim_selection_method=stim_selection_method, + param_estimation_method=param_estimation_method) + + # Add some random responses. + q.update(stim=q.next_stim, outcome=dict(response='Correct')) + q.update(stim=q.next_stim, outcome=dict(response='Incorrect')) + + q_dumped = q.to_json() + q_loaded = QuestPlus.from_json(q_dumped) + + assert q_loaded == q + + q_loaded.update(stim=q_loaded.next_stim, outcome=dict(response='Correct')) + + if __name__ == '__main__': test_threshold() test_threshold_slope() @@ -432,3 +498,5 @@ def test_weibull(): test_mean_sd_lapse() test_spatial_contrast_sensitivity() test_weibull() + test_eq() + test_json() diff --git a/setup.cfg b/setup.cfg index 4be22f4..b8e8b78 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ install_requires = numpy scipy xarray + json_tricks [bdist_wheel] universal = 1