diff --git a/test/test_images.py b/test/test_images.py index 89e96bdc..6d7eb690 100644 --- a/test/test_images.py +++ b/test/test_images.py @@ -96,6 +96,13 @@ def test_subsample(eng): assert allclose(vals, truth) +def test_decimate(eng): + data = fromlist(list(arange(24).reshape((4, 3, 2))), engine=eng) + vals = data.decimate(2).toarray() + truth = [arange(3, 9).reshape((3, 2)), arange(15, 21).reshape((3, 2))] + assert allclose(vals, truth) + + def test_median_filter_2d(eng): data = fromlist([arange(24).reshape((4, 6))], engine=eng) assert data.median_filter(2).toarray().shape == (4, 6) diff --git a/thunder/images/images.py b/thunder/images/images.py index 9c1872ef..2b9fe04a 100644 --- a/thunder/images/images.py +++ b/thunder/images/images.py @@ -106,7 +106,7 @@ def toseries(self, size='150'): return Series(self.values.swap((0,), tuple(range(n)), size=size), index=index) if self.mode == 'local': - return Series(self.values.transpose(tuple(range(1, n+1)) + (0,)), index=index) + return Series(self.values.transpose(tuple(range(1, n + 1)) + (0,)), index=index) def tolocal(self): """ @@ -255,7 +255,7 @@ def max_projection(self, axis=2): """ if axis >= size(self.dims): raise Exception('Axis for projection (%s) exceeds ' - 'image dimensions (%s-%s)' % (axis, 0, size(self.dims)-1)) + 'image dimensions (%s-%s)' % (axis, 0, size(self.dims) - 1)) newdims = list(self.dims) del newdims[axis] @@ -274,7 +274,7 @@ def max_min_projection(self, axis=2): """ if axis >= size(self.dims): raise Exception('Axis for projection (%s) exceeds ' - 'image dimensions (%s-%s)' % (axis, 0, size(self.dims)-1)) + 'image dimensions (%s-%s)' % (axis, 0, size(self.dims) - 1)) newdims = list(self.dims) del newdims[axis] @@ -308,6 +308,34 @@ def roundup(a, b): return self.map(lambda v: v[slices], dims=newdims) + def decimate(self, factor): + """ + Decimate images by an integer factor. + + Parameters + ---------- + factor : positive int + Number of images to average together. Corresponds to running mean filtering + with window length 'factor' followed by subsampling by 'factor' + """ + if self.mode == 'spark': + from thunder.images.readers import fromrdd + decimated = self.tordd().map(lambda k, v: (int(k[0]) / factor, (v, 1))) + decimated = decimated.reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1])) + decimated = decimated.map(lambda k, v: (k, (v[0] / float(v[1])))) + return fromrdd(decimated) + else: + from thunder.images.readers import fromarray + from numpy import vstack + T = self.shape[0] + decimated = (self.values[:T - T % factor] + .reshape((T // factor, factor) + self.shape[1:]).mean(1)) + if T % factor: + remainder = self.values[-T % factor:].mean(0).reshape((1,) + self.shape[1:]) + return fromarray(vstack([decimated, remainder])) + else: + return fromarray(decimated) + def gaussian_filter(self, sigma=2, order=0): """ Spatially smooth images with a gaussian filter.