Skip to content

Commit

Permalink
add Overlap, Trim and Roll transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
SerodioJ committed Feb 29, 2024
1 parent a167057 commit 95e68ba
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions dasf/transforms/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,97 @@ def _transform_gpu(self, X):
return X


class Overlap(Transform):
"""
Operator to get chunks with their respective overlaps. Useful when it is desired to use the same chunks with overlaps for multiple operations.
"""

def __init__(self, pad=(1, 1, 1)):
self._pad = pad

def _lazy_transform(self, X):
return da.overlap.overlap(X, depth=self._pad, boundary="nearest")

def _lazy_transform_gpu(self, X):
return self._lazy_transform(X)

def _lazy_transform_cpu(self, X):
return self._lazy_transform(X)

def _transform(self, X, xp):
return xp.pad(
X,
[
(self._pad[0],),
(self._pad[1],),
(self._pad[2],),
],
mode="edge",
)

def _transform_gpu(self, X):
return self._transform(X, cp)

def _transform_cpu(self, X):
return self._transform(X, np)


class Trim(Transform):
"""
Operator to trim dask array that was produced by an Overlap transform or subsequent results from that transform.
"""

def __init__(self, trim=(1, 1, 1)):
self._trim = trim

def _lazy_transform(self, X):
return da.overlap.trim_overlap(
X,
depth=self._trim,
boundary="nearest",
)

def _lazy_transform_gpu(self, X):
return self._lazy_transform(X)

def _lazy_transform_cpu(self, X):
return self._lazy_transform(X)

def _transform(self, X):
sl = [slice(t, -t, None) for t in self._trim]
return X[tuple(sl)]

def _transform_gpu(self, X):
return self._transform(X)

def _transform_cpu(self, X):
return self._transform(X)


class Roll(Transform):
"""
Operator to perform a roll along multiple axis
"""

def __init__(self, shift=(1, 1, 1)):
self._shift = shift

def _transform_generic(self, X, xp):
return xp.roll(X, shift=self._shift, axis=list(range(len(self._shift))))

def _lazy_transform_gpu(self, X):
return X.map_blocks(self._transform_generic, xp=cp)

def _lazy_transform_cpu(self, X):
return X.map_blocks(self._transform_generic, xp=np)

def _transform_gpu(self, X):
return self._transform_generic(X, cp)

def _transform_cpu(self, X):
return self._transform_generic(X, np)


class ApplyPatchesBase(Transform):
"""
Base Class for ApplyPatches Functionalities
Expand Down

0 comments on commit 95e68ba

Please sign in to comment.