Skip to content

Commit

Permalink
Update to_tuple to account for labels > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
mjevans26 committed Jan 4, 2021
1 parent db6af9c commit 5da2b44
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions azure/train_acd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
# Map each feature (band) name to its parsed-column spec for TFRecord decoding.
FEATURES_DICT = dict(zip(FEATURES, COLUMNS))

# create training dataset
# NOTE(review): the next four assignments are the pre-/post-commit versions of the
# same two statements from the diff view; only the later pair takes effect.
train_files = glob.glob(os.path.join(args.data_folder, 'train'))
eval_files = glob.glob(os.path.join(args.data_folder, 'eval'))
# Post-commit paths: gzipped UNET 256px TFRecord shards, filenames starting with a capital letter.
train_files = glob.glob(os.path.join(args.data_folder, 'training', 'UNET_256_[A-Z]*.gz'))
eval_files = glob.glob(os.path.join(args.data_folder, 'eval', 'UNET_256_[A-Z]*.gz'))

# Sanity check: how many training shards were found (0 means the glob/path is wrong).
print(len(train_files))

# NOTE(review): this call passes the keyword `buffer=`, but this same commit renames
# the parameter in utils/processing.py to `buff` — verify the signatures agree.
training = processing.get_training_dataset(train_files, ftDict = FEATURES_DICT, buffer = args.buffer, batch = args.batch)
evaluation = processing.get_eval_dataset(eval_files, ftDict = FEATURES_DICT)
Expand Down Expand Up @@ -93,7 +95,7 @@
x = training,
epochs = args.epochs,
#TODO: make command line argument for size
steps_per_epoch = int(63*16),
steps_per_epoch = 63,
validation_data = evaluation,
callbacks = [checkpoint, tensorboard]
)
Expand Down
8 changes: 5 additions & 3 deletions utils/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def to_tuple(inputs, features, response):
#split input bands and labels
bands = stacked[:,:,:len(features)]
labels = stacked[:,:,len(features):]
# in case labels are >1
labels = tf.where(tf.greater(labels, 1.0), 1.0, labels)
# perform color augmentation on input features
bands = aug_color(bands)
# standardize each patch of bands
Expand Down Expand Up @@ -161,7 +163,7 @@ def tupelize(inputs):
dataset = dataset.map(tupelize, num_parallel_calls=5)
return dataset

def get_training_dataset(files, ftDict, buff=None, batch=None, buffer=None):
    """Get the preprocessed, shuffled, batched, infinitely-repeating training dataset.

    Args:
        files: List of TFRecord file paths to read training patches from.
        ftDict: Dict mapping feature (band) names to their parsed-column specs.
        buff: Shuffle buffer size (number of elements).
        batch: Batch size.
        buffer: Deprecated alias for ``buff``. This commit renamed the parameter
            to ``buff``, but the caller in azure/train_acd.py still passes
            ``buffer=``; accepting both keeps that call site working.

    Returns:
        A tf.data.Dataset of training data.

    Raises:
        TypeError: If neither a shuffle buffer size nor a batch size is provided.
    """
    # Honor the legacy keyword so existing callers don't break.
    if buffer is not None:
        buff = buffer
    if buff is None or batch is None:
        raise TypeError('get_training_dataset requires a shuffle buffer size and a batch size')
    dataset = get_dataset(files, ftDict)
    # Shuffle within the buffer, batch, and repeat forever (Keras controls epoch length
    # via steps_per_epoch).
    dataset = dataset.shuffle(buff).batch(batch).repeat()
    return dataset

def get_eval_dataset(files, ftDict):
    """Get the preprocessed evaluation dataset.

    Args:
        files: List of TFRecord file paths to read evaluation patches from.
        ftDict: Dict mapping feature (band) names to their parsed-column specs.

    Returns:
        A tf.data.Dataset of evaluation data, batched one patch at a time
        (no repeat, so a single pass covers the whole evaluation set).
    """
    eval_ds = get_dataset(files, ftDict)
    return eval_ds.batch(1)

0 comments on commit 5da2b44

Please sign in to comment.