Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add postprocess option to ts2img #24

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/repurpose/img2ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ def calc(self):

for img_stack_dict, timestamps in self.img_bulk():
# =================================================================
logging.info(f"Finished reading bulk with {len(timestamps)} images")

start_time = datetime.now()

# temporarily drop grids, due to an issue when pickling them...
Expand All @@ -451,6 +453,7 @@ def calc(self):

keys = list(img_stack_dict.keys())
for key in keys:
#print(key)
# rename variable in output dataset
if self.variable_rename is None:
var_new_name = str(key)
Expand All @@ -473,6 +476,7 @@ def calc(self):
ITER_KWARGS = {'cell': [], 'celldata': []}

for cell in np.unique(target_grid.activearrcell):
#print(cell)
cell_idx = np.where(cells == cell)[0]

if len(cell_idx) == 0:
Expand All @@ -486,7 +490,6 @@ def calc(self):
np.atleast_2d(img_stack_dict[k])[:, cell_idx], 0, 1)
img_stack_dict[k] = np.delete(img_stack_dict[k], cell_idx,
axis=1)

cells = np.delete(cells, cell_idx)

ITER_KWARGS['celldata'].append(celldata)
Expand Down
46 changes: 31 additions & 15 deletions src/repurpose/ts2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ def _convert(converter: 'Ts2Img',
if ts is None:
continue
if preprocess_func is not None:
preprocess_kwargs = preprocess_kwargs or {}
ts = preprocess_func(ts, **preprocess_kwargs)
preprocess_func = np.atleast_1d(preprocess_func)
if preprocess_kwargs is None:
preprocess_kwargs = [{}] * len(preprocess_func)
preprocess_kwargs = np.atleast_1d(preprocess_kwargs)
if len(preprocess_func) != len(preprocess_kwargs):
raise ValueError("Length of preprocess_func and "
"preprocess_kwargs is different")
for func, kwargs in zip(preprocess_func, preprocess_kwargs):
ts = func(ts, **kwargs)
if np.any(np.isin(ts.columns, Ts2Img._protected_vars)):
raise ValueError(
f"Time series contains protected variables. "
Expand Down Expand Up @@ -263,7 +270,7 @@ def calc(self, path_out, format_out='slice', preprocess=None,
- stack: write all time steps into one file. In this case if there
is a {datetime} placeholder in the fn_template, then the time
range is inserted.
preprocess: callable, optional (default: None)
preprocess: callable or list[Callable], optional (default: None)
Function that is applied to each time series before converting it.
The first argument is the data frame that the reader returns.
Additional keyword arguments can be passed via `preprocess_kwargs`.
Expand All @@ -279,25 +286,25 @@ def preprocess_add(df: pd.DataFrame, **preprocess_kwargs) \
df['var3'] = df['var1'] + df['var2']
return df
```
preprocess_kwargs: dict, optional (default: None)
preprocess_kwargs: dict or list[dict], optional (default: None)
Keyword arguments for the preprocess function. If None are given,
then the preprocessing function is called with only the input
data frame and no additional arguments (see example above).
postprocess: Callable, optional (default: None)
Function that is applied to the image stack after loading the data
and before writing it to disk. The function must take xarray
postprocess: Callable or list[Callable], optional (default: None)
Function(s) applied to the image stack after loading the data
and before writing it to disk. The function must take an xarray
Dataset as the first argument and return an xarray Dataset of the
same form as the input data.
same form.
A simple example for a postprocessing function to add a new variable
from the sum of two existing variables:
```
def preprocess_add(stack: xr.Dataset, **postprocess_kwargs) \
def postprocess_add(stack: xr.Dataset, **postprocess_kwargs) \
-> xr.Dataset:
stack = stack.assign(var3=lambda x: x['var0'] + x['var2'])
return stack
```
postprocess_kwargs: dict, optional (default: None)
Keyword arguments for the postprocess function. If None are given,
postprocess_kwargs: dict or list[dict], optional (default: None)
Keyword arguments for the postprocess function(s). If None are given,
then the postprocess function is called with only the input
image stack and no additional arguments (see example above).
fn_template: str, optional (default: "{datetime}.nc")
Expand Down Expand Up @@ -381,10 +388,6 @@ def preprocess_add(stack: xr.Dataset, **postprocess_kwargs) \

self.stack = self.stack.drop_isel(time=idx_empty)

if postprocess is not None:
postprocess_kwargs = postprocess_kwargs or {}
self.stack = postprocess(self.stack, **postprocess_kwargs)

if var_fillvalues is not None:
for var, fillvalue in var_fillvalues.items():
self.stack[var].values = np.nan_to_num(
Expand All @@ -410,6 +413,19 @@ def preprocess_add(stack: xr.Dataset, **postprocess_kwargs) \
for var in var_attrs:
self.stack[var].attrs.update(var_attrs[var])

if postprocess is not None:
# this is done as late as possible
postprocess = np.atleast_1d(postprocess)
if postprocess_kwargs is None:
postprocess_kwargs = [{}] * len(postprocess)
postprocess_kwargs = np.atleast_1d(postprocess_kwargs)
if len(postprocess) != len(postprocess_kwargs):
raise ValueError("postprocess and postprocess_kwargs "
"have different lengths")

for func, kwargs in zip(postprocess, postprocess_kwargs):
self.stack = func(self.stack, **kwargs)

if self.stack['time'].size == 0:
warnings.warn("No images in stack to write to disk.")
self.stack = None
Expand Down
Loading