diff --git a/n2v/internals/N2V_DataGenerator.py b/n2v/internals/N2V_DataGenerator.py index 33ca496..692531f 100644 --- a/n2v/internals/N2V_DataGenerator.py +++ b/n2v/internals/N2V_DataGenerator.py @@ -183,8 +183,8 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2 patches = [] if n_dims == 2: if data.shape[1] > shape[0] and data.shape[2] > shape[1]: - for y in range(0, data.shape[1] - shape[0], shape[0]): - for x in range(0, data.shape[2] - shape[1], shape[1]): + for y in range(0, data.shape[1] - shape[0] + 1, shape[0]): + for x in range(0, data.shape[2] - shape[1] + 1, shape[1]): patches.append(data[:, y:y + shape[0], x:x + shape[1]]) return np.concatenate(patches) @@ -194,14 +194,13 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2 print("'shape' is too big.") elif n_dims == 3: if data.shape[1] > shape[0] and data.shape[2] > shape[1] and data.shape[3] > shape[2]: - for z in range(0, data.shape[1] - shape[0], shape[0]): - for y in range(0, data.shape[2] - shape[1], shape[1]): - for x in range(0, data.shape[3] - shape[2], shape[2]): + for z in range(0, data.shape[1] - shape[0] + 1, shape[0]): + for y in range(0, data.shape[2] - shape[1] + 1, shape[1]): + for x in range(0, data.shape[3] - shape[2] + 1, shape[2]): patches.append(data[:, z:z + shape[0], y:y + shape[1], x:x + shape[2]]) return np.concatenate(patches) - elif data.shape[1] == shape[0] and data.shape[2] == shape[1] and data.shape[3] == shape[ - 2]: + elif data.shape[1] == shape[0] and data.shape[2] == shape[1] and data.shape[3] == shape[2]: return data else: print("'shape' is too big.") diff --git a/n2v/version.py b/n2v/version.py index 850505a..13b7089 100644 --- a/n2v/version.py +++ b/n2v/version.py @@ -1 +1 @@ -__version__ = '0.1.10' +__version__ = '0.1.11' diff --git a/tests/functional/test_training2D_RGB.py b/tests/functional/test_training2D_RGB.py index bbaabe0..86247ac 100755 --- a/tests/functional/test_training2D_RGB.py +++ b/tests/functional/test_training2D_RGB.py @@ -7,20 +7,20 @@ from n2v.utils.n2v_utils import manipulate_val_data from n2v.internals.N2V_DataGenerator import N2V_DataGenerator from matplotlib import pyplot as plt -import urllib +import urllib.request import os import zipfile # create a folder for our data if not os.path.isdir('./data'): os.mkdir('data') - # check if data has been downloaded already - zipPath="data/RGB.zip" - if not os.path.exists(zipPath): - # download and unzip data - data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zipPath) - with zipfile.ZipFile(zipPath, 'r') as zip_ref: - zip_ref.extractall("data") +# check if data has been downloaded already +zipPath = "data/RGB.zip" +if not os.path.exists(zipPath): + # download and unzip data + data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zipPath) + with zipfile.ZipFile(zipPath, 'r') as zip_ref: + zip_ref.extractall("data") # For training, we will load __one__ low-SNR RGB image and use the N2V_DataGenerator to extract non-overlapping patches datagen = N2V_DataGenerator() @@ -29,7 +29,7 @@ # The function will return a list of images (numpy arrays). # In the 'dims' parameter we specify the order of dimensions in the image files we are reading: # 'C' stands for channels (color) -imgs = datagen.load_imgs_from_directory(directory="data/", filter='*.png', dims='YXC') +imgs = datagen.load_imgs_from_directory(directory="./data", filter='*.png', dims='YXC') print('shape of loaded images: ',imgs[0].shape) # Remove alpha channel diff --git a/tests/functional/test_training2D_SEM.py b/tests/functional/test_training2D_SEM.py index d6df035..b674640 100755 --- a/tests/functional/test_training2D_SEM.py +++ b/tests/functional/test_training2D_SEM.py @@ -13,7 +13,7 @@ # create a folder for our data if not os.path.isdir('./data'): os.mkdir('./data') -zipPath="data/SEM.zip" +zipPath = "data/SEM.zip" if not os.path.exists(zipPath): # download and unzip data data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/pXgfbobntrw06lC/download', zipPath) diff --git a/tests/test_Noise2VoidDataGenerator.py b/tests/test_Noise2VoidDataGenerator.py new file mode 100644 index 0000000..a6cfd8a --- /dev/null +++ b/tests/test_Noise2VoidDataGenerator.py @@ -0,0 +1,47 @@ +from n2v.internals.N2V_DataGenerator import N2V_DataGenerator +import urllib.request +import os +import zipfile + + +def test_generate_patches_2D(): + + if not os.path.isdir('data'): + os.mkdir('data') + zip_path = "data/RGB.zip" + if not os.path.exists(zip_path): + data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zip_path) + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall('data') + + datagen = N2V_DataGenerator() + + imgs = datagen.load_imgs_from_directory(directory="data", filter='*.png', dims='YXC') + imgs[0] = imgs[0][..., :3] + patches = datagen.generate_patches_from_list(imgs, shape=(1100, 2800)) + assert len(patches) == 1 + patches = datagen.generate_patches_from_list(imgs, shape=(550, 1400)) + assert len(patches) == 4 + patches = datagen.generate_patches_from_list(imgs, shape=(110, 280)) + assert len(patches) == 100 + +def test_generate_patches_3D(): + + if not os.path.isdir('data'): + os.mkdir('data') + zip_path = 'data/flywing-data.zip' + if not os.path.exists(zip_path): + # download and unzip data + data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/RKStdwKo4FlFrxE/download', zip_path) + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall('data') + + datagen = N2V_DataGenerator() + + imgs = datagen.load_imgs_from_directory(directory="data", filter='*.tif', dims='ZYX') + print(imgs[0].shape) + patches = datagen.generate_patches_from_list(imgs[:1], shape=(35, 520, 692)) + assert len(patches) == 1 + patches = datagen.generate_patches_from_list(imgs[:1], shape=(5, 52, 174)) + assert len(patches) == 210 + diff --git a/tests/test_Noise2VoidDataWrapper.py b/tests/test_Noise2VoidDataWrapper.py index c95dfb0..03abfc8 100644 --- a/tests/test_Noise2VoidDataWrapper.py +++ b/tests/test_Noise2VoidDataWrapper.py @@ -1,4 +1,4 @@ -from n2v.internals.N2V_DataWrapper import N2V_DataWrapper +from n2v.internals.N2V_DataWrapper import N2V_DataWrapper import numpy as np