From 47d541ab77f123df5597fd85fbb0358415f99058 Mon Sep 17 00:00:00 2001 From: Craig Sanders Date: Wed, 7 Dec 2022 09:35:39 -0800 Subject: [PATCH] bug fixes; prepare for next release (#212) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/212 Fixes a recent change to the way tell messages are parsed caused a crash when sending tensors; adds statsmodels and updates botorch version; increments version number Reviewed By: mshvartsman Differential Revision: D41696161 fbshipit-source-id: cd7006a54ce33a8f2145601c0c3742e3ecca182a --- aepsych/server/server.py | 15 +++++++++++++-- aepsych/version.py | 2 +- requirements.txt | 2 +- setup.py | 3 ++- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/aepsych/server/server.py b/aepsych/server/server.py index 2ccb9f641..871bf458e 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -800,10 +800,21 @@ def tell(self, outcome, config, model_data=True): param_value=str(param_value), ) + # Check if we get single or multiple outcomes + # Multiple outcomes come in the form of iterables that aren't strings or single-element tensors if isinstance(outcome, Iterable) and type(outcome) != str: for i, outcome_value in enumerate(outcome): - if isinstance(outcome_value, Iterable) and type(outcome_value) != str: - if len(outcome_value) == 1: + if ( + isinstance(outcome_value, Iterable) + and type(outcome_value) != str + ): + if ( + isinstance(outcome_value, torch.Tensor) + and outcome_value.dim() < 2 + ): + outcome_value = outcome_value.item() + + elif len(outcome_value) == 1: outcome_value = outcome_value[0] else: raise ValueError( diff --git a/aepsych/version.py b/aepsych/version.py index a41644764..5c3fa5c4a 100644 --- a/aepsych/version.py +++ b/aepsych/version.py @@ -5,4 +5,4 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -__version__ = "0.2.0" +__version__ = "0.3.0" diff --git a/requirements.txt b/requirements.txt index 5d30be104..f42c84b1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ scipy sklearn gpytorch>=1.8.1 statsmodels -git+https://github.com/pytorch/botorch@main#egg=botorch +botorch>=0.8.0 SQLAlchemy dill pandas diff --git a/setup.py b/setup.py index 0da7afcf0..d2ae3a245 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ "scipy", "sklearn", "gpytorch>=1.9.0", - "botorch>=0.7.2", + "botorch>=0.8.0", "SQLAlchemy", "dill", "pandas", @@ -27,6 +27,7 @@ "aepsych_client", "voila==0.3.6", "ipywidgets==7.6.5", + "statsmodels", ] DEV_REQUIRES = [