Support batch 1-d convolution in ht.signal.convolve (#1515)
* define conditions for batch processing

* set cond for distributed batch processing

* test batch convolve

* implement batch convolve

* add link to conv1d documentation

* reinstate exception check

* expand docs, add examples

* fix tests

* cast local tensors to torch dtype

* expand tests

* test n-D batch convolutions and empty tensors

* bypass empty tensors

* edit docs

* add comment re: batch_processing variable

* test with 4-D batch signal
ClaudiaComito authored Jul 4, 2024
1 parent 064f495 commit 225dc96
Showing 2 changed files with 162 additions and 20 deletions.
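For orientation, a minimal usage sketch of the new batch behavior (mirroring the docstring examples in the diff below, and assuming a working Heat install):

import heat as ht

a = ht.arange(50, dtype=ht.float64, split=0).reshape(10, 5)  # 10 signals of length 5
v = ht.arange(3)                                             # one shared 1-D filter
out = ht.convolve(a, v)  # mode='full': each row convolved with v, shape (10, 7)
# with a batch of filters of shape (10, 3), each signal instead gets its own filter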
141 changes: 124 additions & 17 deletions heat/core/signal.py
@@ -16,13 +16,16 @@
def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
"""
Returns the discrete, linear convolution of two one-dimensional `DNDarray`s or scalars.
Unlike `numpy.convolve`, if ``a`` and/or ``v`` have more than one dimension, a batch convolution along the last dimension is attempted. See `Examples` below.
Parameters
----------
a : DNDarray or scalar
One-dimensional signal `DNDarray` of shape (N,) or scalar.
One- or N-dimensional signal ``DNDarray`` of shape (..., N), or scalar. If ``a`` has more than one dimension, it will be treated as a batch of 1D signals.
Distribution along the batch dimension is required for distributed batch processing. See the examples for details.
v : DNDarray or scalar
One-dimensional filter weight `DNDarray` of shape (M,) or scalar.
One- or N-dimensional filter weight `DNDarray` of shape (..., M), or scalar. If ``v`` has more than one dimension, it will be treated as a batch of 1D filter weights.
The batch dimension(s) of ``v`` must match the batch dimension(s) of ``a``.
mode : str
Can be 'full', 'valid', or 'same'. Default is 'full'.
'full':
@@ -69,6 +72,34 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
[0/3] DNDarray([0., 1., 3., 3.])
[1/3] DNDarray([3., 3., 3., 3.])
[2/3] DNDarray([3., 3., 3., 2.])
>>> a = ht.arange(50, dtype=ht.float64, split=0)
>>> a = a.reshape(10, 5) # 10 signals of length 5
>>> v = ht.arange(3)
>>> ht.convolve(a, v) # batch processing: 10 signals convolved with filter v
DNDarray([[ 0., 0., 1., 4., 7., 10., 8.],
[ 0., 5., 16., 19., 22., 25., 18.],
[ 0., 10., 31., 34., 37., 40., 28.],
[ 0., 15., 46., 49., 52., 55., 38.],
[ 0., 20., 61., 64., 67., 70., 48.],
[ 0., 25., 76., 79., 82., 85., 58.],
[ 0., 30., 91., 94., 97., 100., 68.],
[ 0., 35., 106., 109., 112., 115., 78.],
[ 0., 40., 121., 124., 127., 130., 88.],
[ 0., 45., 136., 139., 142., 145., 98.]], dtype=ht.float64, device=cpu:0, split=0)
>>> v = ht.random.randint(0, 3, (10, 3), split=0) # 10 filters of length 3
>>> ht.convolve(a, v) # batch processing: 10 signals convolved with 10 filters
DNDarray([[ 0., 0., 2., 4., 6., 8., 0.],
[ 5., 6., 7., 8., 9., 0., 0.],
[ 20., 42., 56., 61., 66., 41., 14.],
[ 0., 15., 16., 17., 18., 19., 0.],
[ 20., 61., 64., 67., 70., 48., 0.],
[ 50., 52., 104., 108., 112., 56., 58.],
[ 0., 30., 61., 63., 65., 67., 34.],
[ 35., 106., 109., 112., 115., 78., 0.],
[ 0., 40., 81., 83., 85., 87., 44.],
[ 0., 0., 45., 46., 47., 48., 49.]], dtype=ht.float64, device=cpu:0, split=0)
"""
if np.isscalar(a):
a = array([a])
@@ -88,34 +119,110 @@
a = a.astype(promoted_type)
v = v.astype(promoted_type)

if len(a.shape) != 1 or len(v.shape) != 1:
raise ValueError("Only 1-dimensional input DNDarrays are allowed")
if mode == "same" and v.shape[0] % 2 == 0:
# check if the filter is longer than the signal and swap them if necessary
if v.shape[-1] > a.shape[-1]:
a, v = v, a

# assess whether to perform batch processing, default is False (no batch processing)
batch_processing = False
if a.ndim > 1:
# batch processing requires 1D filter OR matching batch dimensions for signal and filter
batch_dims = a.shape[:-1]
# verify that the filter shape is consistent with the signal
if v.ndim > 1:
if v.shape[:-1] != batch_dims:
raise ValueError(
f"Batch dimensions of signal and filter must match. Signal: {a.shape}, Filter: {v.shape}"
)
if a.is_distributed():
if a.split == a.ndim - 1:
raise ValueError(
"Please distribute the signal along the batch dimension, not the signal dimension. For in-place redistribution use the `DNDarray.resplit_()` method with `axis=0`"
)
if v.is_distributed():
if v.ndim == 1:
# gather filter to all ranks
v.resplit_(axis=None)
else:
v.resplit_(axis=a.split)
batch_processing = True
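# e.g. (sketch): a of shape (10, 1000) with split=0 takes this batch path,
# while the same shape with split=1 raises above, because the split axis is the signal axis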

if not batch_processing and v.ndim > 1:
raise ValueError(
f"1-D convolution only supported for 1-dimensional signal and kernel. Signal: {a.shape}, Filter: {v.shape}"
)

if mode == "same" and v.shape[-1] % 2 == 0:
raise ValueError("Mode 'same' cannot be used with even-sized kernel")
if not v.is_balanced():
raise ValueError("Only balanced kernel weights are allowed")

if v.shape[0] > a.shape[0]:
a, v = v, a

# compute halo size
halo_size = torch.max(v.lshape_map[:, 0]).item() // 2

# pad DNDarray with zeros according to mode
# calculate pad size according to mode
if mode == "full":
pad_size = v.shape[0] - 1
gshape = v.shape[0] + a.shape[0] - 1
pad_size = v.shape[-1] - 1
gshape = v.shape[-1] + a.shape[-1] - 1
elif mode == "same":
pad_size = v.shape[0] // 2
gshape = a.shape[0]
pad_size = v.shape[-1] // 2
gshape = a.shape[-1]
elif mode == "valid":
pad_size = 0
gshape = a.shape[0] - v.shape[0] + 1
gshape = a.shape[-1] - v.shape[-1] + 1
else:
raise ValueError(f"Supported modes are 'full', 'valid', 'same', got {mode}")

if batch_processing:
# all operations are local torch operations, only the last dimension is convolved
local_a = a.larray
local_v = v.larray
# flip filter for convolution, as PyTorch's conv1d computes cross-correlations
local_v = torch.flip(local_v, [-1])
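# e.g. (sketch): convolving with [0, 1, 2] equals cross-correlating with the
# reversed filter [2, 1, 0], which is what conv1d actually computes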
local_batch_dims = tuple(local_a.shape[:-1])

# reshape signal and filter to 3D for PyTorch's conv1d function
# see https://pytorch.org/docs/stable/generated/torch.nn.functional.conv1d.html
local_a = local_a.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_a.device), dim=0).item(),
local_a.shape[-1],
)
channels = local_a.shape[0]
if v.ndim > 1:
local_v = local_v.reshape(
torch.prod(torch.tensor(local_batch_dims, device=local_v.device), dim=0).item(),
local_v.shape[-1],
)
local_v = local_v.unsqueeze(1)
else:
local_v = local_v.unsqueeze(0).unsqueeze(0).expand(local_a.shape[0], 1, -1)
# add batch dimension to signal
local_a = local_a.unsqueeze(0)
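# shapes at this point (sketch, B local signals, filters of length M):
# local_a is (1, B, N_local) and local_v is (B, 1, M), so conv1d with
# groups=B below applies exactly one filter per signal channel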

# on GPU, promote to at least single-precision float
if local_a.is_cuda:
float_type = torch.promote_types(local_a.dtype, torch.float32)
local_a = local_a.to(float_type)
local_v = local_v.to(float_type)

# apply torch convolution operator if local signal isn't empty
if torch.prod(torch.tensor(local_a.shape, device=local_a.device)) > 0:
local_convolved = fc.conv1d(local_a, local_v, padding=pad_size, groups=channels)
else:
empty_shape = tuple(local_a.shape[:-1] + (gshape,))
local_convolved = torch.empty(empty_shape, dtype=local_a.dtype, device=local_a.device)

# unpack 3D result into original shape
local_convolved = local_convolved.squeeze(0)
local_convolved = local_convolved.reshape(local_batch_dims + (gshape,))

# wrap result in DNDarray
convolved = array(local_convolved, is_split=a.split, device=a.device, comm=a.comm)
return convolved

# pad signal with zeros
a = pad(a, pad_size, "constant", 0)

# compute halo size
halo_size = torch.max(v.lshape_map[:, -1]).item() // 2

if a.is_distributed():
if (v.lshape_map[:, 0] > a.lshape_map[:, 0]).any():
raise ValueError(
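The batching trick above is plain PyTorch. A self-contained sketch (hypothetical shapes, independent of Heat) of how `groups` turns a single conv1d call into B independent 1-D convolutions:

import torch
import torch.nn.functional as fc

B, N, M = 4, 10, 3                               # batch size, signal length, filter length
signals = torch.randn(B, N)
filters = torch.randn(B, M)

# flip so that conv1d's cross-correlation becomes a true convolution
weight = torch.flip(filters, [-1]).unsqueeze(1)  # (B, 1, M)
inp = signals.unsqueeze(0)                       # (1, B, N)
out = fc.conv1d(inp, weight, padding=M - 1, groups=B).squeeze(0)  # (B, N + M - 1)
# row i now equals the 'full' convolution of signals[i] with filters[i]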
41 changes: 38 additions & 3 deletions heat/core/tests/test_signal.py
@@ -36,9 +36,10 @@ def test_convolve(self):
ht.convolve(dis_signal, filter_wrong_type, mode="full")
with self.assertRaises(ValueError):
ht.convolve(dis_signal, kernel_odd, mode="invalid")
with self.assertRaises(ValueError):
s = dis_signal.reshape((2, -1))
ht.convolve(s, kernel_odd)
if dis_signal.comm.size > 1:
with self.assertRaises(ValueError):
s = dis_signal.reshape((2, -1)).resplit(axis=1)
ht.convolve(s, kernel_odd)
with self.assertRaises(ValueError):
k = ht.eye(3)
ht.convolve(dis_signal, k)
@@ -119,3 +120,37 @@ def test_convolve(self):

conv = ht.convolve(1, 5)
self.assertTrue(ht.equal(ht.array([5]), conv))

# test batched convolutions, distributed along the first axis
signal = ht.random.randn(1000, dtype=ht.float64)
batch_signal = ht.empty((10, 1000), dtype=ht.float64, split=0)
batch_signal.larray[:] = signal.larray
kernel = ht.random.randn(19, dtype=ht.float64)
batch_convolved = ht.convolve(batch_signal, kernel, mode="same")
self.assertTrue(ht.equal(ht.convolve(signal, kernel, mode="same"), batch_convolved[0]))

# distributed kernel
dis_kernel = ht.array(kernel, split=0)
batch_convolved = ht.convolve(batch_signal, dis_kernel)
self.assertTrue(ht.equal(ht.convolve(signal, kernel), batch_convolved[0]))
batch_kernel = ht.empty((10, 19), dtype=ht.float64, split=1)
batch_kernel.larray[:] = dis_kernel.larray
batch_convolved = ht.convolve(batch_signal, batch_kernel, mode="full")
self.assertTrue(ht.equal(ht.convolve(signal, kernel, mode="full"), batch_convolved[0]))

# n-D batch convolution
batch_signal = ht.empty((4, 3, 3, 1000), dtype=ht.float64, split=1)
batch_signal.larray[:, :, :] = signal.larray
batch_convolved = ht.convolve(batch_signal, kernel, mode="valid")
self.assertTrue(
ht.equal(ht.convolve(signal, kernel, mode="valid"), batch_convolved[1, 2, 0])
)

# test batch-convolve exceptions
batch_kernel_wrong_shape = ht.random.randn(3, 19, dtype=ht.float64)
with self.assertRaises(ValueError):
ht.convolve(batch_signal, batch_kernel_wrong_shape)
if kernel.comm.size > 1:
batch_signal_wrong_split = batch_signal.resplit(-1)
with self.assertRaises(ValueError):
ht.convolve(batch_signal_wrong_split, kernel)
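A cross-check of the batch path against NumPy, sketched here rather than taken from the committed tests:

import numpy as np
import heat as ht

signal = ht.random.randn(100, dtype=ht.float64)
batch_signal = ht.empty((4, 100), dtype=ht.float64, split=0)
batch_signal.larray[:] = signal.larray  # every row holds the same signal
kernel = ht.random.randn(5, dtype=ht.float64)

batched = ht.convolve(batch_signal, kernel, mode="full").numpy()
reference = np.convolve(signal.numpy(), kernel.numpy(), mode="full")
assert all(np.allclose(row, reference) for row in batched)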
