# Training / demo script appended to graph_cellular_automa.py.
# Reconstructed from a whitespace-mangled git patch; formatting normalized.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms


def train(model, train_loader, optimizer, criterion):
    """Run one training epoch of `model` over `train_loader`.

    Args:
        model: the ``nn.Module`` to optimize (put into train mode here).
        train_loader: iterable of ``(data, target)`` mini-batches.
        optimizer: a ``torch.optim`` optimizer over ``model.parameters()``.
        criterion: loss callable ``criterion(output, target)``.
    """
    model.train()
    for data, target in train_loader:  # batch index was unused; drop enumerate
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


if __name__ == "__main__":
    # Guarded so that importing this module does not download MNIST or train.

    # Hyperparameters
    embedding_dim = 16
    hidden_dim = 32
    learning_rate = 0.001
    batch_size = 64
    epochs = 10

    # MNIST pipeline with the standard normalization constants.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    # Model, optimizer, loss.
    model = NDP(embedding_dim, hidden_dim)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): this loop cannot work as written. NDP.forward takes
    # (node_embeddings, adjacency_matrix) and returns a 3-tuple (see the usage
    # example below), but train() calls model(data) with a single image batch
    # and feeds the tuple to CrossEntropyLoss — this raises at runtime.
    # TODO: define a graph-appropriate objective instead of MNIST classification.
    for epoch in range(epochs):
        train(model, train_loader, optimizer, criterion)
        print(f"Epoch {epoch + 1}/{epochs} completed")

    # Usage example: the interface NDP actually exposes.
    node_embeddings = torch.rand((10, embedding_dim))  # 10 nodes
    adjacency_matrix = torch.rand((10, 10))  # dummy adjacency for 10 nodes

    updated_embeddings, replication_decisions, edge_weights = model(
        node_embeddings, adjacency_matrix
    )

    print(updated_embeddings.shape)
    print(replication_decisions.shape)
    print(edge_weights.shape)