Skip to content

Commit

Permalink
Extract n patches from each sample
Browse files Browse the repository at this point in the history
  • Loading branch information
tibuch committed Nov 7, 2019
1 parent 49cf8b1 commit 0bf5078
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions n2v/internals/N2V_DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ def generate_patches_from_list(self, data, num_patches_per_img=None, shape=(256,
"""
patches = []
for img in data:
p = self.generate_patches(img, num_patches=num_patches_per_img, shape=shape, augment=augment)
patches.append(p)
for s in range(img.shape[0]):
p = self.generate_patches(img[s][np.newaxis], num_patches=num_patches_per_img, shape=shape, augment=augment)
patches.append(p)

patches = np.concatenate(patches, axis=0)

Expand Down Expand Up @@ -210,26 +211,24 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2
patches = []
if n_dims == 2:
for i in range(num_patches):
s = np.random.randint(0, data.shape[0])
y, x = np.random.randint(0, data.shape[1] - shape[0] + 1), np.random.randint(0,
data.shape[
2] - shape[
1] + 1)
patches.append(data[s, y:y + shape[0], x:x + shape[1]])
patches.append(data[0, y:y + shape[0], x:x + shape[1]])

if len(patches) > 1:
return np.stack(patches)
else:
return np.array(patches)[np.newaxis]
elif n_dims == 3:
for i in range(num_patches):
s = np.random.randint(0, data.shape[0])
z, y, x = np.random.randint(0, data.shape[1] - shape[0] + 1), np.random.randint(0,
data.shape[
2] - shape[
1] + 1), np.random.randint(
0, data.shape[3] - shape[2] + 1)
patches.append(data[s, z:z + shape[0], y:y + shape[1], x:x + shape[2]])
patches.append(data[0, z:z + shape[0], y:y + shape[1], x:x + shape[2]])

if len(patches) > 1:
return np.stack(patches)
Expand Down

0 comments on commit 0bf5078

Please sign in to comment.