diff --git a/jobs/fashion_mnist_train/job.py b/jobs/fashion_mnist_train/job.py index 5570959..05d98f3 100644 --- a/jobs/fashion_mnist_train/job.py +++ b/jobs/fashion_mnist_train/job.py @@ -17,11 +17,14 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any): - run = wandb.init(project=project, entity=entity) + run = wandb.init(project=project, entity=entity, config={ + "epochs": 10, + "learning_rate": 0.001, + "steps_per_epoch": 10, + }) # get config, could be set from sweep scheduler train_config = run.config - # get training parameters from config epochs = train_config.get("epochs", 10) learning_rate = train_config.get("learning_rate", 0.001) @@ -98,7 +101,6 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any): # Log plot wandb.log({"prediction-chart": plt}) - # Code from: https://www.geeksforgeeks.org/fashion-mnist-with-python-keras-and-deep-learning/ def model_arch(): """Define the architecture of the model"""