Skip to content

Commit

Permalink
bug fixes; prepare for next release (#212)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #212

Fixes a recent change to the way tell messages are parsed caused a crash when sending tensors; adds statsmodels and updates botorch version; increments version number

Reviewed By: mshvartsman

Differential Revision: D41696161

fbshipit-source-id: cd7006a54ce33a8f2145601c0c3742e3ecca182a
  • Loading branch information
crasanders authored and facebook-github-bot committed Dec 7, 2022
1 parent e208d20 commit 47d541a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
15 changes: 13 additions & 2 deletions aepsych/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,10 +800,21 @@ def tell(self, outcome, config, model_data=True):
param_value=str(param_value),
)

# Check if we get single or multiple outcomes
# Multiple outcomes come in the form of iterables that aren't strings or single-element tensors
if isinstance(outcome, Iterable) and type(outcome) != str:
for i, outcome_value in enumerate(outcome):
if isinstance(outcome_value, Iterable) and type(outcome_value) != str:
if len(outcome_value) == 1:
if (
isinstance(outcome_value, Iterable)
and type(outcome_value) != str
):
if (
isinstance(outcome_value, torch.Tensor)
and outcome_value.dim() < 2
):
outcome_value = outcome_value.item()

elif len(outcome_value) == 1:
outcome_value = outcome_value[0]
else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion aepsych/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

__version__ = "0.2.0"
__version__ = "0.3.0"
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ scipy
sklearn
gpytorch>=1.8.1
statsmodels
git+https://github.com/pytorch/botorch@main#egg=botorch
botorch>=0.8.0
SQLAlchemy
dill
pandas
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"scipy",
"sklearn",
"gpytorch>=1.9.0",
"botorch>=0.7.2",
"botorch>=0.8.0",
"SQLAlchemy",
"dill",
"pandas",
Expand All @@ -27,6 +27,7 @@
"aepsych_client",
"voila==0.3.6",
"ipywidgets==7.6.5",
"statsmodels",
]

DEV_REQUIRES = [
Expand Down

0 comments on commit 47d541a

Please sign in to comment.