diff --git a/segmentation_models_trainer/dataset_loader/dataset.py b/segmentation_models_trainer/dataset_loader/dataset.py index 0f23735..8439464 100644 --- a/segmentation_models_trainer/dataset_loader/dataset.py +++ b/segmentation_models_trainer/dataset_loader/dataset.py @@ -165,14 +165,20 @@ def process_csv_entry(entry): separator='' ) ) - label = decode_img(label, width, length, channels=1) + label = decode_img( + label, + width, + length, + channels=1, + interpolation_method=tf.image.ResizeMethod.NEAREST_NEIGHBOR + ) # load the raw data from the file as a string img = tf.io.read_file( - entry['label_path'][0] if self.base_path == '' \ + entry['image_path'][0] if self.base_path == '' \ else tf.strings.join( [ self.base_path, - entry['label_path'][0] + entry['image_path'][0] ], separator='' ) @@ -182,13 +188,13 @@ def process_csv_entry(entry): return img, label - def decode_img(img, width, length, channels=3): + def decode_img(img, width, length, channels=3, interpolation_method=tf.image.ResizeMethod.BILINEAR): # convert the compressed string to a 3D uint8 tensor img = tf.image.decode_png(img, channels=channels) # Use `convert_image_dtype` to convert to floats in the [0,1] range. img = tf.image.convert_image_dtype(img, IMAGE_DTYPE[self.img_dtype]) # resize the image to the desired size. - return tf.image.resize(img, [width, length]) + return tf.image.resize(img, [width, length], method=interpolation_method) def prepare_for_training(ds, batch_size): if self.cache: