11from deepsensor .data .loader import TaskLoader
2- from deepsensor .data .processor import DataProcessor
2+ from deepsensor .data .processor import (
3+ DataProcessor ,
4+ process_X_mask_for_X ,
5+ xarray_to_coord_array_normalised ,
6+ mask_coord_array_normalised ,
7+ )
38from deepsensor .data .task import Task , flatten_X
49
510from typing import List , Union
@@ -118,6 +123,7 @@ def predict(
118123 X_t : Union [
119124 xr .Dataset , xr .DataArray , pd .DataFrame , pd .Series , pd .Index , np .ndarray
120125 ],
126+ X_t_mask : Union [xr .Dataset , xr .DataArray ] = None ,
121127 X_t_is_normalised : bool = False ,
122128 resolution_factor = 1 ,
123129 n_samples = 0 ,
@@ -131,13 +137,12 @@ def predict(
131137 ):
132138 """Predict on a regular grid or at off-grid locations.
133139
134- TODO:
135- - Test with multiple targets model
136-
137140 Args:
138141 tasks: List of tasks containing context data.
139142 X_t: Target locations to predict at. Can be an xarray object containing
140143 on-grid locations or a pandas object containing off-grid locations.
144+ X_t_mask: Optional 2D mask to apply to X_t. Predictions where the mask is zero/False will be NaN.
145+ The mask is interpolated to the same grid as X_t. Default None (no mask).
141146 X_t_is_normalised: Whether the `X_t` coords are normalised.
142147 If False, will normalise the coords before passing to model. Default False.
143148 resolution_factor: Optional factor to increase the resolution of the
@@ -186,6 +191,10 @@ def predict(
186191 f"X_t must be and xarray, pandas or numpy object. Got { type (X_t )} ."
187192 )
188193
194+ if mode == "off-grid" and X_t_mask is not None :
195+ # TODO: Unit test this
196+ raise ValueError ("X_t_mask can only be used with on-grid predictions." )
197+
189198 if type (tasks ) is Task :
190199 tasks = [tasks ]
191200
@@ -228,7 +237,7 @@ def predict(
228237
229238 # Unnormalise coords to use for xarray/pandas objects for storing predictions
230239 X_t = self .data_processor .map_coords (X_t , unnorm = True )
231- else :
240+ elif not X_t_is_normalised :
232241 # Normalise coords to use for model
233242 X_t_normalised = self .data_processor .map_coords (X_t )
234243
@@ -237,8 +246,15 @@ def predict(
237246 X_t_normalised = increase_spatial_resolution (
238247 X_t_normalised , resolution_factor
239248 )
240- # TODO rename from _arr because not an array here
241- X_t_arr = (X_t_normalised ["x1" ].values , X_t_normalised ["x2" ].values )
249+
250+ if X_t_mask is not None :
251+ X_t_mask = process_X_mask_for_X (X_t_mask , X_t )
252+ X_t_mask_normalised = self .data_processor .map_coords (X_t_mask )
253+ X_t_arr = xarray_to_coord_array_normalised (X_t_normalised )
254+ # Remove points that lie outside the mask
255+ X_t_arr = mask_coord_array_normalised (X_t_arr , X_t_mask_normalised )
256+ else :
257+ X_t_arr = (X_t_normalised ["x1" ].values , X_t_normalised ["x2" ].values )
242258 elif mode == "off-grid" :
243259 X_t_arr = X_t_normalised .reset_index ()[["x1" , "x2" ]].values .T
244260
@@ -379,13 +395,22 @@ def unnormalise_pred_array(arr, **kwargs):
379395 )
380396
381397 if mode == "on-grid" :
382- mean .loc [:, task ["time" ], :, :] = mean_arr
383- std .loc [:, task ["time" ], :, :] = std_arr
384- if n_samples >= 1 :
385- for sample_i in range (n_samples ):
386- samples .loc [:, sample_i , task ["time" ], :, :] = samples_arr [
387- sample_i
388- ]
398+ if X_t_mask is None :
399+ mean .loc [:, task ["time" ], :, :] = mean_arr
400+ std .loc [:, task ["time" ], :, :] = std_arr
401+ if n_samples >= 1 :
402+ for sample_i in range (n_samples ):
403+ samples .loc [:, sample_i , task ["time" ], :, :] = samples_arr [
404+ sample_i
405+ ]
406+ else :
407+ mean .loc [:, task ["time" ], :, :].data [:, X_t_mask .data ] = mean_arr
408+ std .loc [:, task ["time" ], :, :].data [:, X_t_mask .data ] = std_arr
409+ if n_samples >= 1 :
410+ for sample_i in range (n_samples ):
411+ samples .loc [:, sample_i , task ["time" ], :, :].data [
412+ :, X_t_mask .data
413+ ] = samples_arr [sample_i ]
389414 elif mode == "off-grid" :
390415 # TODO multi-target case
391416 mean .loc [task ["time" ]] = mean_arr .T
0 commit comments