Skip to content

Commit

Permalink
able to build package
Browse files Browse the repository at this point in the history
  • Loading branch information
PFLeget committed Sep 18, 2024
1 parent c3c8494 commit 2577c4c
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions python/lsst/meas/algorithms/gp_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def median_with_mad_clipping(data, mad_multiplier=2.0):


@jit
def jax_rbf_kernel(x1, x2, sigma, correlation_length, y_err):
def jax_rbf_kernel(x1, x2, sigma, correlation_length):
"""
Computes the radial basis function (RBF) kernel matrix.
Expand All @@ -87,8 +87,6 @@ def jax_rbf_kernel(x1, x2, sigma, correlation_length, y_err):
The scale parameter of the kernel.
correlation_length : `float`
The correlation length parameter of the kernel.
y_err : `float`
Measurement error for the input values.
Returns:
--------
Expand All @@ -97,10 +95,9 @@ def jax_rbf_kernel(x1, x2, sigma, correlation_length, y_err):
"""
distance_squared = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2))
y_err = jnp.ones(len(x[:, 0])) * y_err
kernel += jnp.eye(len(y_err)) * (y_err**2)
return kernel


@jit
def jax_get_alpha(y, kernel):
"""
Expand Down Expand Up @@ -182,14 +179,17 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
def fit(self, x_train, y_train):
y = y_train - self.mean
self._x = x_train
kernel = jax_rbf_kernel(x_train, self.std, self.correlation_length, self.white_noise)
kernel = jax_rbf_kernel(x_train, self.std, self.correlation_length)
y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise
kernel += jnp.eye(len(y_err)) * (y_err**2)
self._alpha = jax_get_alpha(y, kernel)

def predict(self, x_predict):
kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length, 0)
y_pred = jax_get_gp_predict(kernel_rect, self._alpha)
return y_pred + self.mean


class GaussianProcessTreegp:
"""
Gaussian Process Treegp class for Gaussian Process interpolation.
Expand Down

0 comments on commit 2577c4c

Please sign in to comment.