Skip to content

Commit

Permalink
Expose kernel_factory attribute as kernel_or_factory
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed May 16, 2024
1 parent 39b6fb1 commit aaed381
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 6 additions & 1 deletion baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@ class GaussianProcessSurrogate(Surrogate):

# Object variables
kernel_factory: KernelFactory = field(
factory=DefaultKernelFactory, converter=to_kernel_factory
alias="kernel_or_factory",
factory=DefaultKernelFactory,
converter=to_kernel_factory,
)
"""The factory used to create the kernel of the Gaussian process.
Accepts either a :class:`baybe.kernels.base.Kernel` or a
:class:`.kernel_factory.KernelFactory`.
When passing a :class:`baybe.kernels.base.Kernel`, it gets automatically wrapped
into a :class:`.kernel_factory.PlainKernelFactory`."""

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def fixture_default_surrogate_model(request, onnx_surrogate, kernel):
"""The default surrogate model to be used if not specified differently."""
if hasattr(request, "param") and request.param == "onnx":
return onnx_surrogate
return GaussianProcessSurrogate(kernel_factory=kernel)
return GaussianProcessSurrogate(kernel_or_factory=kernel)


@pytest.fixture(name="initial_recommender")
Expand Down

0 comments on commit aaed381

Please sign in to comment.