Skip to content

Commit

Permalink
Merge: Enable single point precision via env vars (#226)
Browse files Browse the repository at this point in the history
Added
- env var for torch
- env var for numpy
- these are booleans that enforce usage of single precision if True,
False by default

Addressing suggestion from #223 

Notes:
- Did some very light testing via fulltest and setting the tox env vars
to use single precision there
- some tests expectedly fail due to numerical instability -> no action
required
- some tests fail because hypothesis generate incompatible floats -->
likely fixed by using `st.floats(...,width=32)` if single precision is
desired --> not done under the assumption we dont want tests enabled for
single precision, would require flexible checking everywhere
- some tests fail as constraints complain about combining Single and
Double precision --> likely because `rhs` and `coefficients` are defined
as python floats, hence 64 bit. This means using double precision with
env vars also implies providing explicit 32bit floats in such fields
(like `coefficients` or `rhs`), otherwise it wont work. This is not
something I'd fix but perhaps mention in the userguide
  • Loading branch information
Scienfitz authored May 16, 2024
2 parents 24bd1c1 + cc45fa4 commit cbc1b0c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `torch`, `gpytorch` and `botorch` are lazy-loaded for improved startup time
- If an exception is encountered during simulation, incomplete results are returned
with a warning instead of passing through the uncaught exception
- Environment variables `BAYBE_NUMPY_USE_SINGLE_PRECISION` and
`BAYBE_TORCH_USE_SINGLE_PRECISION` to enforce single point precision usage

### Removed
- `model_params` attribute from `Surrogate` base class, `GaussianProcessSurrogate` and
Expand Down
1 change: 1 addition & 0 deletions baybe/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def to_botorch(
if p in param_names
]

# TODO: Cast rhs to correct precision once BoTorch also supports single point.
return (
torch.tensor(param_indices),
torch.tensor(self.coefficients, dtype=DTypeFloatTorch),
Expand Down
12 changes: 11 additions & 1 deletion baybe/utils/numerical.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
"""Utilities for numeric operations."""

import os
from collections.abc import Sequence

import numpy as np
import numpy.typing as npt

DTypeFloatNumpy = np.float64
from baybe.utils.boolean import strtobool

VARNAME_NUMPY_USE_SINGLE_PRECISION = "BAYBE_NUMPY_USE_SINGLE_PRECISION"
"""Environment variable name for enforcing single precision in numpy."""

DTypeFloatNumpy = (
np.float32
if strtobool(os.environ.get(VARNAME_NUMPY_USE_SINGLE_PRECISION, "False"))
else np.float64
)
"""Floating point data type used for numpy arrays."""

DTypeFloatONNX = np.float32
Expand Down
14 changes: 13 additions & 1 deletion baybe/utils/torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
"""Torch utilities shipped as separate module for lazy-loading."""


import os

import torch

DTypeFloatTorch = torch.float64
from baybe.utils.boolean import strtobool

VARNAME_TORCH_USE_SINGLE_PRECISION = "BAYBE_TORCH_USE_SINGLE_PRECISION"
"""Environment variable name for enforcing single precision in torch."""

DTypeFloatTorch = (
torch.float32
if strtobool(os.environ.get(VARNAME_TORCH_USE_SINGLE_PRECISION, "False"))
else torch.float64
)
"""Floating point data type used for torch tensors."""

0 comments on commit cbc1b0c

Please sign in to comment.