Skip to content

Commit

Permalink
major bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
phborba committed Nov 25, 2020
1 parent 3e53fa4 commit cfd6cbb
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions segmentation_models_trainer/dataset_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,20 @@ def process_csv_entry(entry):
separator=''
)
)
label = decode_img(label, width, length, channels=1)
label = decode_img(
label,
width,
length,
channels=1,
interpolation_method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
)
# load the raw data from the file as a string
img = tf.io.read_file(
entry['label_path'][0] if self.base_path == '' \
entry['image_path'][0] if self.base_path == '' \
else tf.strings.join(
[
self.base_path,
entry['label_path'][0]
entry['image_path'][0]
],
separator=''
)
Expand All @@ -182,13 +188,13 @@ def process_csv_entry(entry):
return img, label


def decode_img(img, width, length, channels=3):
def decode_img(img, width, length, channels=3, interpolation_method=tf.image.ResizeMethod.BILINEAR):
# convert the compressed string to a 3D uint8 tensor
img = tf.image.decode_png(img, channels=channels)
# Use `convert_image_dtype` to convert to floats in the [0,1] range.
img = tf.image.convert_image_dtype(img, IMAGE_DTYPE[self.img_dtype])
# resize the image to the desired size.
return tf.image.resize(img, [width, length])
return tf.image.resize(img, [width, length], method=interpolation_method)

def prepare_for_training(ds, batch_size):
if self.cache:
Expand Down

0 comments on commit cfd6cbb

Please sign in to comment.