Skip to content

Commit ca4948c

Browse files
committed
Add X_t_mask kwarg to DeepSensorModel.predict
1 parent 814ec86 commit ca4948c

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

deepsensor/model/model.py

Lines changed: 39 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,10 @@
11
from deepsensor.data.loader import TaskLoader
2-
from deepsensor.data.processor import DataProcessor
2+
from deepsensor.data.processor import (
3+
DataProcessor,
4+
process_X_mask_for_X,
5+
xarray_to_coord_array_normalised,
6+
mask_coord_array_normalised,
7+
)
38
from deepsensor.data.task import Task, flatten_X
49

510
from typing import List, Union
@@ -118,6 +123,7 @@ def predict(
118123
X_t: Union[
119124
xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index, np.ndarray
120125
],
126+
X_t_mask: Union[xr.Dataset, xr.DataArray] = None,
121127
X_t_is_normalised: bool = False,
122128
resolution_factor=1,
123129
n_samples=0,
@@ -131,13 +137,12 @@ def predict(
131137
):
132138
"""Predict on a regular grid or at off-grid locations.
133139
134-
TODO:
135-
- Test with multiple targets model
136-
137140
Args:
138141
tasks: List of tasks containing context data.
139142
X_t: Target locations to predict at. Can be an xarray object containing
140143
on-grid locations or a pandas object containing off-grid locations.
144+
X_t_mask: Optional 2D mask to apply to X_t (zero/False will be NaNs). Will be interpolated
145+
to the same grid as X_t. Default None (no mask).
141146
X_t_is_normalised: Whether the `X_t` coords are normalised.
142147
If False, will normalise the coords before passing to model. Default False.
143148
resolution_factor: Optional factor to increase the resolution of the
@@ -186,6 +191,10 @@ def predict(
186191
f"X_t must be and xarray, pandas or numpy object. Got {type(X_t)}."
187192
)
188193

194+
if mode == "off-grid" and X_t_mask is not None:
195+
# TODO: Unit test this
196+
raise ValueError("X_t_mask can only be used with on-grid predictions.")
197+
189198
if type(tasks) is Task:
190199
tasks = [tasks]
191200

@@ -228,7 +237,7 @@ def predict(
228237

229238
# Unnormalise coords to use for xarray/pandas objects for storing predictions
230239
X_t = self.data_processor.map_coords(X_t, unnorm=True)
231-
else:
240+
elif not X_t_is_normalised:
232241
# Normalise coords to use for model
233242
X_t_normalised = self.data_processor.map_coords(X_t)
234243

@@ -237,8 +246,15 @@ def predict(
237246
X_t_normalised = increase_spatial_resolution(
238247
X_t_normalised, resolution_factor
239248
)
240-
# TODO rename from _arr because not an array here
241-
X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values)
249+
250+
if X_t_mask is not None:
251+
X_t_mask = process_X_mask_for_X(X_t_mask, X_t)
252+
X_t_mask_normalised = self.data_processor.map_coords(X_t_mask)
253+
X_t_arr = xarray_to_coord_array_normalised(X_t_normalised)
254+
# Remove points that lie outside the mask
255+
X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised)
256+
else:
257+
X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values)
242258
elif mode == "off-grid":
243259
X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T
244260

@@ -379,13 +395,22 @@ def unnormalise_pred_array(arr, **kwargs):
379395
)
380396

381397
if mode == "on-grid":
382-
mean.loc[:, task["time"], :, :] = mean_arr
383-
std.loc[:, task["time"], :, :] = std_arr
384-
if n_samples >= 1:
385-
for sample_i in range(n_samples):
386-
samples.loc[:, sample_i, task["time"], :, :] = samples_arr[
387-
sample_i
388-
]
398+
if X_t_mask is None:
399+
mean.loc[:, task["time"], :, :] = mean_arr
400+
std.loc[:, task["time"], :, :] = std_arr
401+
if n_samples >= 1:
402+
for sample_i in range(n_samples):
403+
samples.loc[:, sample_i, task["time"], :, :] = samples_arr[
404+
sample_i
405+
]
406+
else:
407+
mean.loc[:, task["time"], :, :].data[:, X_t_mask.data] = mean_arr
408+
std.loc[:, task["time"], :, :].data[:, X_t_mask.data] = std_arr
409+
if n_samples >= 1:
410+
for sample_i in range(n_samples):
411+
samples.loc[:, sample_i, task["time"], :, :].data[
412+
:, X_t_mask.data
413+
] = samples_arr[sample_i]
389414
elif mode == "off-grid":
390415
# TODO multi-target case
391416
mean.loc[task["time"]] = mean_arr.T

0 commit comments

Comments
 (0)