From ffb7279f2df5f79b59e624fc2f8fc6e0dbd0373f Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Wed, 7 Feb 2024 14:08:43 -0800
Subject: [PATCH] Use the same dtype for input indices and targets in
 preprocess_sparse_image

---
 gpax/utils/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/gpax/utils/utils.py b/gpax/utils/utils.py
index 2f5e72e..9581820 100644
--- a/gpax/utils/utils.py
+++ b/gpax/utils/utils.py
@@ -156,6 +156,7 @@ def preprocess_sparse_image(sparse_image):
     and an array of full indices of the shape (N_full, D) for reconstructing the full image. D is
     the image dimensionality (D=2 for a 2D image)
     """
+    dtype = sparse_image.dtype
     # Find non-zero element indices
     non_zero_indices = onp.nonzero(sparse_image)
     # Create the GP input using the indices
@@ -164,7 +165,7 @@ def preprocess_sparse_image(sparse_image):
     targets = sparse_image[non_zero_indices]
     # Generate indices for the entire image
     full_indices = onp.array(onp.meshgrid(*[onp.arange(dim) for dim in sparse_image.shape])).T.reshape(-1, sparse_image.ndim)
-    return gp_input, targets, full_indices
+    return gp_input.astype(dtype), targets.astype(dtype), full_indices.astype(dtype)
 
 
 def initialize_inducing_points(X, ratio=0.1, method='uniform', key=None):