Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleGoyette committed Feb 26, 2024
1 parent c1497c7 commit a9c4bcd
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions jobs/fashion_mnist_train/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@


def train(project: Optional[str], entity: Optional[str], **kwargs: Any):
run = wandb.init(project=project, entity=entity)
run = wandb.init(project=project, entity=entity, config={
"epochs": 10,
"learning_rate": 0.001,
"steps_per_epoch": 10,
})

# get config, could be set from sweep scheduler
train_config = run.config

# get training parameters from config
epochs = train_config.get("epochs", 10)
learning_rate = train_config.get("learning_rate", 0.001)
steps_per_epoch = train_config.get("steps_per_epoch", 10)


# load data
(train_X, train_y), (test_X, test_y) = fashion_mnist.load_data()
Expand Down Expand Up @@ -98,7 +101,6 @@ def train(project: Optional[str], entity: Optional[str], **kwargs: Any):
# Log plot
wandb.log({"prediction-chart": plt})


# Code from: https://www.geeksforgeeks.org/fashion-mnist-with-python-keras-and-deep-learning/
def model_arch():
"""Define the architecture of the model"""
Expand Down

0 comments on commit a9c4bcd

Please sign in to comment.