Commit 3184ec8

fix batch, shuffle data once on load
jdilger committed Apr 9, 2024
1 parent 513d036
Showing 2 changed files with 25 additions and 16 deletions.
27 changes: 17 additions & 10 deletions fao_models/dataloader.py
@@ -1,33 +1,38 @@
import tensorflow as tf
import os


def _parse_function(proto):
# Define the parsing schema
feature_description = {
-        'image': tf.io.FixedLenFeature([], tf.string),
-        'label': tf.io.FixedLenFeature([], tf.string),
+        "image": tf.io.FixedLenFeature([], tf.string),
+        "label": tf.io.FixedLenFeature([], tf.string),
}
# Parse the input `tf.train.Example` proto using the schema
example = tf.io.parse_single_example(proto, feature_description)
-    image = tf.io.parse_tensor(example['image'], out_type=tf.float32)
-    label = tf.io.parse_tensor(example['label'], out_type=tf.int64)
+    image = tf.io.parse_tensor(example["image"], out_type=tf.float32)
+    label = tf.io.parse_tensor(example["label"], out_type=tf.int64)
image.set_shape([32, 32, 4]) # Set the shape explicitly if not already defined
label.set_shape([]) # For scalar labels
return image, label


def load_dataset_from_tfrecords(tfrecord_dir, batch_size=32):

pattern = tfrecord_dir + "/*.tfrecord.gz"
files = tf.data.Dataset.list_files(pattern)
dataset = files.interleave(
lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"),
cycle_length=tf.data.AUTOTUNE,
-        block_length=1
+        block_length=1,
)
-    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
-    dataset = dataset.shuffle(buffer_size=1000)
+    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE).batch(
+        batch_size, drop_remainder=True
+    )
+    dataset = dataset.shuffle(buffer_size=100_000, seed=42)
return dataset


def split_dataset(dataset, total_examples, test_split=0.2, batch_size=32):
test_size = int(total_examples * test_split)
train_size = total_examples - test_size
Expand All @@ -36,7 +41,9 @@ def split_dataset(dataset, total_examples, test_split=0.2, batch_size=32):
train_batches = train_size // batch_size
test_batches = test_size // batch_size

-    train_dataset = dataset.take(train_batches).batch(batch_size).prefetch(tf.data.AUTOTUNE)
-    test_dataset = dataset.skip(train_batches).take(test_batches).batch(batch_size).prefetch(tf.data.AUTOTUNE)
+    train_dataset = dataset.take(train_batches).prefetch(tf.data.AUTOTUNE)
+    test_dataset = (
+        dataset.skip(train_batches).take(test_batches).prefetch(tf.data.AUTOTUNE)
+    )

    return train_dataset, test_dataset
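
To make the loader's new contract concrete, here is a minimal usage sketch of the two helpers as they stand after this commit. The import path, data directory, and example count are assumptions for illustration only, not values taken from the repository.

from fao_models import dataloader as dl  # assumed import path

DATA_DIR = "/path/to/tfrecords"  # hypothetical directory of *.tfrecord.gz files
TOTAL_EXAMPLES = 10_000          # hypothetical dataset size
BATCH_SIZE = 32

# The loader now batches (drop_remainder=True) and shuffles once on load.
dataset = dl.load_dataset_from_tfrecords(DATA_DIR, batch_size=BATCH_SIZE)

# split_dataset only takes/skips whole batches and prefetches; it no longer batches.
train_ds, test_ds = dl.split_dataset(
    dataset, TOTAL_EXAMPLES, test_split=0.2, batch_size=BATCH_SIZE
)

for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)  # expected: (32, 32, 32, 4) (32,)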
14 changes: 8 additions & 6 deletions fao_models/model_fit.py
@@ -27,6 +27,7 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def main():

    # initialize new cli parser
Expand All @@ -42,7 +43,7 @@ def main():
args = parser.parse_args()

config_file = args.config

with open(config_file, "r") as file:
config_data = yaml.safe_load(file)

@@ -87,15 +88,14 @@ def main():
logger.info(pformat(config_data))

# Load the dataset without batching
-    dataset = dl.load_dataset_from_tfrecords(data_dir)
+    dataset = dl.load_dataset_from_tfrecords(data_dir, batch_size=batch_size)

# Split the dataset into training and testing
train_dataset, test_dataset = dl.split_dataset(
dataset, total_examples, test_split=data_split, batch_size=batch_size
)
-    train_dataset = train_dataset.shuffle(buffer_size, reshuffle_each_iteration=True)


logger.info("Starting model training...")
LOGS_DIR = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "logs", experiment_name
Expand All @@ -117,13 +117,14 @@ def main():
if early_stopping_patience is not None:
logger.info(f"Using early stopping. Patience: {early_stopping_patience}")
early_stop = tf.keras.callbacks.EarlyStopping(
monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True
monitor="val_loss",
patience=early_stopping_patience,
restore_best_weights=True,
)
callbacks.append(early_stop)
callbacks.append(cm_callback)
callbacks.append(tf.keras.callbacks.TensorBoard(LOGS_DIR))


history = model.fit(
train_dataset,
epochs=epochs,
Expand All @@ -135,5 +136,6 @@ def main():
logger.info("Training history:")
logger.info(pformat(history.history))


if __name__ == "__main__":
    main()
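
As a quick sanity check on the take/skip split now that batching happens in the loader, here is a small arithmetic sketch. The concrete numbers are placeholders, not values from the project's YAML config.

# Placeholder values; in model_fit.py these come from the config file.
total_examples = 10_000
batch_size = 32
data_split = 0.2  # passed to split_dataset as test_split

test_size = int(total_examples * data_split)   # 2000 examples
train_size = total_examples - test_size        # 8000 examples
train_batches = train_size // batch_size       # 250 whole batches
test_batches = test_size // batch_size         # 62 whole batches

# With drop_remainder=True batching done on load, split_dataset hands the
# first 250 batches to training and the following 62 batches to testing.
print(train_batches, test_batches)  # 250 62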
