From a9c4bcdf760d407a1493b8b852c7a60cb3290558 Mon Sep 17 00:00:00 2001 From: Kyle Goyette Date: Mon, 26 Feb 2024 09:41:59 -0800 Subject: [PATCH 1/2] fix --- jobs/fashion_mnist_train/job.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index 5570959..daee60d 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -17,15 +17,18 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any): - run = wandb.init(project=project, entity=entity) + run = wandb.init(project=project, entity=entity, config={ + "epochs": 10, + "learning_rate": 0.001, + "steps_per_epoch": 10, + }) # get config, could be set from sweep scheduler train_config = run.config - # get training parameters from config epochs = train_config.get("epochs", 10) learning_rate = train_config.get("learning_rate", 0.001) - steps_per_epoch = train_config.get("steps_per_epoch", 10) + # load data (train_X, train_y), (test_X, test_y) = fashion_mnist.load_data() @@ -98,7 +101,6 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any): # Log plot wandb.log({"prediction-chart": plt}) - # Code from: https://www.geeksforgeeks.org/fashion-mnist-with-python-keras-and-deep-learning/ def model_arch(): """Define the architecture of the model""" From 8027d64fc91e2b927609452a0e2db6ce6936bfa9 Mon Sep 17 00:00:00 2001 From: Kyle Goyette Date: Mon, 26 Feb 2024 09:46:47 -0800 Subject: [PATCH 2/2] wip --- jobs/fashion_mnist_train/job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index daee60d..05d98f3 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -28,7 +28,7 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any): # get training parameters from config epochs = train_config.get("epochs", 10) learning_rate = train_config.get("learning_rate", 0.001) - + steps_per_epoch = train_config.get("steps_per_epoch", 10) # load data (train_X, train_y), (test_X, test_y) = fashion_mnist.load_data()