From ffb7279f2df5f79b59e624fc2f8fc6e0dbd0373f Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:08:43 -0800 Subject: [PATCH] Use the same dtype for input indices and targets in preprocess_sparse_image --- gpax/utils/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gpax/utils/utils.py b/gpax/utils/utils.py index 2f5e72e..9581820 100644 --- a/gpax/utils/utils.py +++ b/gpax/utils/utils.py @@ -156,6 +156,7 @@ def preprocess_sparse_image(sparse_image): and an array of full indices of the shape (N_full, D) for reconstructing the full image. D is the image dimensionality (D=2 for a 2D image) """ + dtype = sparse_image.dtype # Find non-zero element indices non_zero_indices = onp.nonzero(sparse_image) # Create the GP input using the indices @@ -164,7 +165,7 @@ def preprocess_sparse_image(sparse_image): targets = sparse_image[non_zero_indices] # Generate indices for the entire image full_indices = onp.array(onp.meshgrid(*[onp.arange(dim) for dim in sparse_image.shape])).T.reshape(-1, sparse_image.ndim) - return gp_input, targets, full_indices + return gp_input.astype(dtype), targets.astype(dtype), full_indices.astype(dtype) def initialize_inducing_points(X, ratio=0.1, method='uniform', key=None):