diff --git a/.github/workflows/black.yaml b/.github/workflows/black.yaml index 3747c79..6eaffe9 100644 --- a/.github/workflows/black.yaml +++ b/.github/workflows/black.yaml @@ -1,6 +1,6 @@ name: Enforce coding style -on: [push, pull_request] +on: [ push, pull_request ] jobs: lint: diff --git a/experiment_gnn.ipynb b/experiment_gnn.ipynb index 0adda3c..81b2f67 100644 --- a/experiment_gnn.ipynb +++ b/experiment_gnn.ipynb @@ -16,7 +16,11 @@ "import torchvision.transforms as transforms\n", "\n", "from mantra.simplicial import SimplicialDataset\n", - "from mantra.transforms import TriangulationToFaceTransform, OrientableToClassTransform, DegreeTransform\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", "from validation.validate_homology import validate_betti_numbers\n", "\n", "import torch" @@ -69,37 +73,46 @@ ], "source": [ "tr = transforms.Compose(\n", - " [TriangulationToFaceTransform(),FaceToEdge(remove_faces=False),DegreeTransform(),OrientableToClassTransform()]\n", - " )\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " ]\n", + ")\n", "\n", - "dataset = SimplicialDataset(root=\"./data\",transform=tr)\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", "\n", "\n", "print()\n", - "print(f'Dataset: {dataset}:')\n", - "print('====================')\n", - "print(f'Number of graphs: {len(dataset)}')\n", - "print(f'Number of features: {dataset.num_features}')\n", - "print(f'Number of classes: {dataset.num_classes}')\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", "\n", "data = dataset[0] # Get the first graph object.\n", "\n", "print()\n", "print(data)\n", - "print('=============================================================')\n", + "print(\"=============================================================\")\n", "\n", "# Gather some statistics about the first graph.\n", - "print(f'Number of nodes: {len(data.x)}')\n", - "print(f'Number of edges: {data.num_edges}')\n", - "print(f'Average node degree: {data.num_edges / len(data.x):.2f}')\n", - "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n", - "print(f'Has self-loops: {data.has_self_loops()}')\n", - "print(f'Is undirected: {data.is_undirected()}')\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", "\n", - "print('=============================================================')\n", - "print(f'Number of orientable Manifolds: {sum(dataset.orientable)}')\n", - "print(f'Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}')\n", - "print(f'Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}')" + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" ] }, { @@ -117,15 +130,13 @@ } ], "source": [ - "\n", - "\n", "dataset = dataset.shuffle()\n", "\n", "train_dataset = dataset[:150]\n", "test_dataset = dataset[150:]\n", "\n", - "print(f'Number of training graphs: {len(train_dataset)}')\n", - "print(f'Number of test graphs: {len(test_dataset)}')" + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")" ] }, { @@ -135,8 +146,7 @@ "outputs": [], "source": [ "train_loader = DataLoader(train_dataset)\n", - "test_loader = DataLoader(test_dataset)\n", - "\n" + "test_loader = DataLoader(test_dataset)" ] }, { @@ -174,7 +184,7 @@ " self.lin = Linear(hidden_channels, dataset.num_classes)\n", "\n", " def forward(self, x, edge_index, batch):\n", - " # 1. Obtain node embeddings \n", + " # 1. Obtain node embeddings\n", " x = self.conv1(x, edge_index)\n", " x = x.relu()\n", " x = self.conv2(x, edge_index)\n", @@ -187,9 +197,10 @@ " # 3. Apply a final classifier\n", " x = F.dropout(x, p=0.5, training=self.training)\n", " x = self.lin(x)\n", - " \n", + "\n", " return x\n", "\n", + "\n", "model = GCN(hidden_channels=64)\n", "print(model)" ] @@ -306,19 +317,19 @@ "evalue": "", "output_type": "error", "traceback": [ - "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", - "Cell \u001B[1;32mIn[77], line 29\u001B[0m\n\u001B[0;32m 27\u001B[0m train()\n\u001B[0;32m 28\u001B[0m train_acc \u001B[38;5;241m=\u001B[39m test(train_loader)\n\u001B[1;32m---> 29\u001B[0m test_acc \u001B[38;5;241m=\u001B[39m \u001B[43mtest\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtest_loader\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mEpoch: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mepoch\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m03d\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, Train Acc: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtrain_acc\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, Test Acc: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtest_acc\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m'\u001B[39m)\n", - "Cell \u001B[1;32mIn[77], line 20\u001B[0m, in \u001B[0;36mtest\u001B[1;34m(loader)\u001B[0m\n\u001B[0;32m 18\u001B[0m correct \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m 19\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m data \u001B[38;5;129;01min\u001B[39;00m loader: \u001B[38;5;66;03m# Iterate in batches over the training/test dataset.\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m out \u001B[38;5;241m=\u001B[39m \u001B[43mmodel\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdata\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdata\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43medge_index\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdata\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbatch\u001B[49m\u001B[43m)\u001B[49m \n\u001B[0;32m 21\u001B[0m pred \u001B[38;5;241m=\u001B[39m out\u001B[38;5;241m.\u001B[39margmax(dim\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m) \u001B[38;5;66;03m# Use the class with highest probability.\u001B[39;00m\n\u001B[0;32m 22\u001B[0m correct \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;28mint\u001B[39m((pred \u001B[38;5;241m==\u001B[39m data\u001B[38;5;241m.\u001B[39my)\u001B[38;5;241m.\u001B[39msum()) \u001B[38;5;66;03m# Check against ground-truth labels.\u001B[39;00m\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1516\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1517\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1518\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1522\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1523\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1524\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1525\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1526\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1527\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1529\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m 1530\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n", - "Cell \u001B[1;32mIn[76], line 18\u001B[0m, in \u001B[0;36mGCN.forward\u001B[1;34m(self, x, edge_index, batch)\u001B[0m\n\u001B[0;32m 16\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, x, edge_index, batch):\n\u001B[0;32m 17\u001B[0m \u001B[38;5;66;03m# 1. Obtain node embeddings \u001B[39;00m\n\u001B[1;32m---> 18\u001B[0m x \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconv1\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43medge_index\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 19\u001B[0m x \u001B[38;5;241m=\u001B[39m x\u001B[38;5;241m.\u001B[39mrelu()\n\u001B[0;32m 20\u001B[0m x \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mconv2(x, edge_index)\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1516\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1517\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1518\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1522\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1523\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1524\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1525\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1526\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1527\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1529\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m 1530\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:222\u001B[0m, in \u001B[0;36mGCNConv.forward\u001B[1;34m(self, x, edge_index, edge_weight)\u001B[0m\n\u001B[0;32m 220\u001B[0m cache \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_cached_edge_index\n\u001B[0;32m 221\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m cache \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m--> 222\u001B[0m edge_index, edge_weight \u001B[38;5;241m=\u001B[39m \u001B[43mgcn_norm\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# yapf: disable\u001B[39;49;00m\n\u001B[0;32m 223\u001B[0m \u001B[43m \u001B[49m\u001B[43medge_index\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43medge_weight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msize\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnode_dim\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 224\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mimproved\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43madd_self_loops\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mflow\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdtype\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 225\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcached:\n\u001B[0;32m 226\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_cached_edge_index \u001B[38;5;241m=\u001B[39m (edge_index, edge_weight)\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:91\u001B[0m, in \u001B[0;36mgcn_norm\u001B[1;34m(edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype)\u001B[0m\n\u001B[0;32m 88\u001B[0m num_nodes \u001B[38;5;241m=\u001B[39m maybe_num_nodes(edge_index, num_nodes)\n\u001B[0;32m 90\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m add_self_loops:\n\u001B[1;32m---> 91\u001B[0m edge_index, edge_weight \u001B[38;5;241m=\u001B[39m \u001B[43madd_remaining_self_loops\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 92\u001B[0m \u001B[43m \u001B[49m\u001B[43medge_index\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43medge_weight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfill_value\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnum_nodes\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 94\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m edge_weight \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 95\u001B[0m edge_weight \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mones((edge_index\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m1\u001B[39m), ), dtype\u001B[38;5;241m=\u001B[39mdtype,\n\u001B[0;32m 96\u001B[0m device\u001B[38;5;241m=\u001B[39medge_index\u001B[38;5;241m.\u001B[39mdevice)\n", - "File \u001B[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\loop.py:342\u001B[0m, in \u001B[0;36madd_remaining_self_loops\u001B[1;34m(edge_index, edge_attr, fill_value, num_nodes)\u001B[0m\n\u001B[0;32m 339\u001B[0m N \u001B[38;5;241m=\u001B[39m maybe_num_nodes(edge_index, num_nodes)\n\u001B[0;32m 340\u001B[0m mask \u001B[38;5;241m=\u001B[39m edge_index[\u001B[38;5;241m0\u001B[39m] \u001B[38;5;241m!=\u001B[39m edge_index[\u001B[38;5;241m1\u001B[39m]\n\u001B[1;32m--> 342\u001B[0m loop_index \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43marange\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mN\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdtype\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlong\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43medge_index\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 343\u001B[0m loop_index \u001B[38;5;241m=\u001B[39m loop_index\u001B[38;5;241m.\u001B[39munsqueeze(\u001B[38;5;241m0\u001B[39m)\u001B[38;5;241m.\u001B[39mrepeat(\u001B[38;5;241m2\u001B[39m, \u001B[38;5;241m1\u001B[39m)\n\u001B[0;32m 345\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m edge_attr \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n", - "\u001B[1;31mKeyboardInterrupt\u001B[0m: " + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[77], line 29\u001b[0m\n\u001b[0;32m 27\u001b[0m train()\n\u001b[0;32m 28\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m test(train_loader)\n\u001b[1;32m---> 29\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", + "Cell \u001b[1;32mIn[77], line 20\u001b[0m, in \u001b[0;36mtest\u001b[1;34m(loader)\u001b[0m\n\u001b[0;32m 18\u001b[0m correct \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m loader: \u001b[38;5;66;03m# Iterate in batches over the training/test dataset.\u001b[39;00m\n\u001b[1;32m---> 20\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m \n\u001b[0;32m 21\u001b[0m pred \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Use the class with highest probability.\u001b[39;00m\n\u001b[0;32m 22\u001b[0m correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m((pred \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my)\u001b[38;5;241m.\u001b[39msum()) \u001b[38;5;66;03m# Check against ground-truth labels.\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "Cell \u001b[1;32mIn[76], line 18\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, batch)\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, edge_index, batch):\n\u001b[0;32m 17\u001b[0m \u001b[38;5;66;03m# 1. Obtain node embeddings \u001b[39;00m\n\u001b[1;32m---> 18\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 19\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[0;32m 20\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv2(x, edge_index)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:222\u001b[0m, in \u001b[0;36mGCNConv.forward\u001b[1;34m(self, x, edge_index, edge_weight)\u001b[0m\n\u001b[0;32m 220\u001b[0m cache \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index\n\u001b[0;32m 221\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 222\u001b[0m edge_index, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43mgcn_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# yapf: disable\u001b[39;49;00m\n\u001b[0;32m 223\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimproved\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_self_loops\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 225\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcached:\n\u001b[0;32m 226\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index \u001b[38;5;241m=\u001b[39m (edge_index, edge_weight)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:91\u001b[0m, in \u001b[0;36mgcn_norm\u001b[1;34m(edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype)\u001b[0m\n\u001b[0;32m 88\u001b[0m num_nodes \u001b[38;5;241m=\u001b[39m maybe_num_nodes(edge_index, num_nodes)\n\u001b[0;32m 90\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m add_self_loops:\n\u001b[1;32m---> 91\u001b[0m edge_index, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43madd_remaining_self_loops\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 92\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 94\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m edge_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 95\u001b[0m edge_weight \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mones((edge_index\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m), ), dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[0;32m 96\u001b[0m device\u001b[38;5;241m=\u001b[39medge_index\u001b[38;5;241m.\u001b[39mdevice)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\loop.py:342\u001b[0m, in \u001b[0;36madd_remaining_self_loops\u001b[1;34m(edge_index, edge_attr, fill_value, num_nodes)\u001b[0m\n\u001b[0;32m 339\u001b[0m N \u001b[38;5;241m=\u001b[39m maybe_num_nodes(edge_index, num_nodes)\n\u001b[0;32m 340\u001b[0m mask \u001b[38;5;241m=\u001b[39m edge_index[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m!=\u001b[39m edge_index[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m--> 342\u001b[0m loop_index \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlong\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 343\u001b[0m loop_index \u001b[38;5;241m=\u001b[39m loop_index\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mrepeat(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m 345\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m edge_attr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], @@ -327,32 +338,42 @@ "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "\n", + "\n", "def train():\n", " model.train()\n", "\n", " for data in train_loader: # Iterate in batches over the training dataset.\n", - " out = model(data.x, data.edge_index, data.batch) # Perform a single forward pass.\n", - " loss = criterion(out, data.y) # Compute the loss.\n", - " loss.backward() # Derive gradients.\n", - " optimizer.step() # Update parameters based on gradients.\n", - " optimizer.zero_grad() # Clear gradients.\n", + " out = model(\n", + " data.x, data.edge_index, data.batch\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", "\n", "def test(loader):\n", - " model.eval()\n", + " model.eval()\n", "\n", - " correct = 0\n", - " for data in loader: # Iterate in batches over the training/test dataset.\n", - " out = model(data.x, data.edge_index, data.batch) \n", - " pred = out.argmax(dim=1) # Use the class with highest probability.\n", - " correct += int((pred == data.y).sum()) # Check against ground-truth labels.\n", - " return correct / len(loader.dataset) # Derive ratio of correct predictions.\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data.x, data.edge_index, data.batch)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", "\n", "\n", "for epoch in range(1, 171):\n", " train()\n", " train_acc = test(train_loader)\n", " test_acc = test(test_loader)\n", - " print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')" + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" ] } ], diff --git a/experiment_mlp.ipynb b/experiment_mlp.ipynb new file mode 100644 index 0000000..f3dfd64 --- /dev/null +++ b/experiment_mlp.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 9\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Ernst\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'name', 'face', 'orientable', 'torsion_coefficients', 'dimension', 'betti_numbers', 'genus', 'n_vertices'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 9], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=8,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Ernst\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'name', 'face', 'orientable', 'torsion_coefficients', 'dimension', 'betti_numbers', 'genus', 'n_vertices'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([ True, True, True, False, True, True, True, False, True, False,\n", + " False, True, True, True, True, True, True, False, True, False,\n", + " False, False, True, False, False, False, False, False, False, False,\n", + " False, False, False, False, True, True, True, True, True, False,\n", + " False, False, True, True, True, False, True, True, True, True,\n", + " False, False, True, False, True, True, True, True, True, True,\n", + " True, False, True, True, True, True, False, False, False, False,\n", + " True, True, True, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, True, True,\n", + " True, False, False, False, False, False, False, True, False, False,\n", + " True, True, True, False, False, True, False, False, False, False,\n", + " False, False, False, False, False, True, True, False, False, False,\n", + " True, False, False, False, False, False, False, True, False, False,\n", + " False, False, False, True, True, False, True, True, True, True,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, False, False, False, False, False, False,\n", + " False, False, True, True, True, False, False, True, True, True,\n", + " True, True, False, False, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, True, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, True, True, True, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, True,\n", + " True, False, False, False, True, False, False, False, True, False,\n", + " False, False, True, False, False, False, False, False, False, False,\n", + " False, True, True, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, True, False, False, True, False, False, True, True,\n", + " True, False, False, False, False, False, False, True, True, False,\n", + " False, True, False, False, False, False, False, False, False, False,\n", + " False, True, True, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, True, True, True, True, True,\n", + " True, False, False, False, False, False, True, True, True, True,\n", + " True, False, False, False, True, True, True, True, True, True,\n", + " True, True, False, False, True, False, False, True, True, True,\n", + " False, True, True, True, True, False, False, False, False, True,\n", + " False, False, False, False, False, False, False, False, True, True,\n", + " True, True, True, True, True, True, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, False, False, False, False, False, False, True, True, False,\n", + " False, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, True, True, True, False, False, False, False, False,\n", + " True, True, True, False, False, True, True, True, True, True,\n", + " False, False, False, True, True, True, True, True, True, True,\n", + " True, False, False, False, False, False, True, False, True, False,\n", + " True, False, False, False, True, True, True, False, False, False,\n", + " False, False, False, False, False, True, False, False, False, True,\n", + " True, False, False, False, False, True, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, True,\n", + " True, True, True, True, True, True, True, False, True, False,\n", + " False, False, False, False, False, True, True, True, True, False,\n", + " True, True])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.tensor([data.y for data in dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Ernst\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'betti_numbers', 'face', 'n_vertices', 'name', 'dimension', 'orientable', 'torsion_coefficients', 'genus'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([1, 1, 0, 0, 1, 1, 0, 0, 0, 1])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "import torch.nn as nn\n", + "from torch_geometric.nn import GCNConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "from torch_scatter import segment_coo\n", + "\n", + "class PermInvariant(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super().__init__()\n", + " # torch.manual_seed(12345)\n", + " self.classification = nn.Sequential( \n", + " nn.Linear(9,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,2),\n", + " nn.ReLU()\n", + " )\n", + "\n", + " def forward(self, batch):\n", + " x = self.classification(batch.x)\n", + " # print(batch.x)\n", + " # print(x)\n", + " return segment_coo(x,batch.batch,reduce=\"sum\")\n", + "\n", + "\n", + "model = PermInvariant(hidden_channels=64)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.7313, Test Acc: 0.7200\n", + "Epoch: 002, Train Acc: 0.7384, Test Acc: 0.7333\n", + "Epoch: 003, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 004, Train Acc: 0.8025, Test Acc: 0.7800\n", + "Epoch: 005, Train Acc: 0.7972, Test Acc: 0.7733\n", + "Epoch: 006, Train Acc: 0.7972, Test Acc: 0.7733\n", + "Epoch: 007, Train Acc: 0.7900, Test Acc: 0.7667\n", + "Epoch: 008, Train Acc: 0.7918, Test Acc: 0.7667\n", + "Epoch: 009, Train Acc: 0.7865, Test Acc: 0.7667\n", + "Epoch: 010, Train Acc: 0.7794, Test Acc: 0.7600\n", + "Epoch: 011, Train Acc: 0.7794, Test Acc: 0.7600\n", + "Epoch: 012, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 013, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 014, Train Acc: 0.7794, Test Acc: 0.7600\n", + "Epoch: 015, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 016, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 017, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 018, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 019, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 020, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 021, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 022, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 023, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 024, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 025, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 026, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 027, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 028, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 029, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 030, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 031, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 032, Train Acc: 0.7811, Test Acc: 0.7600\n", + "Epoch: 033, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 034, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 035, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 036, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 037, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 038, Train Acc: 0.7829, Test Acc: 0.7667\n", + "Epoch: 039, Train Acc: 0.7829, Test Acc: 0.7667\n" + ] + } + ], + "source": [ + "model = PermInvariant(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.9635, 0.8106, 0.9215, 0.9906, 1.5490, 1.0627, 1.0325, 1.2307, 1.2846,\n", + " 2.0083], grad_fn=)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(batch)[:,0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_scconv.ipynb b/experiment_scconv.ipynb deleted file mode 100644 index 418e6b9..0000000 --- a/experiment_scconv.ipynb +++ /dev/null @@ -1,490 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "outputs": [], - "source": [ - "import math\n", - "%load_ext autoreload\n", - "%autoreload 2" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-04-26T17:24:32.434796Z", - "start_time": "2024-04-26T17:24:32.392352Z" - } - }, - "id": "e1447e3b250fa124", - "execution_count": 1 - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "from mantra.simplicial import SimplicialDataset\n", - "import torch\n", - "from torch import nn\n", - "import torchmetrics\n", - "import torchvision.transforms as transforms\n", - "from mantra.transforms import SimplicialComplexTransform\n", - "from mantra.dataloaders import SimplicialDataLoader\n", - "from topomodelx.nn.simplicial.scnn import SCNN\n", - "from torch_geometric.nn import pool\n", - "import lightning as L\n", - "from typing import Literal\n", - "from torch.utils.data import random_split\n", - "import math\n", - "from mantra.utils import transfer_simplicial_complex_batch_to_device" - ], - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2024-04-26T17:39:52.886233Z", - "start_time": "2024-04-26T17:39:52.773509Z" - } - }, - "id": "initial_id", - "execution_count": 40 - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "from mantra.transforms import SimplicialComplexDegreeTransform, SimplicialComplexEdgeCoadjacencyDegreeTransform, \\\n", - " SimplicialComplexEdgeAdjacencyDegreeTransform, SimplicialComplexTriangleCoadjacencyDegreeTransform, \\\n", - " OrientableToClassSimplicialComplexTransform, DimTwoHodgeLaplacianSimplicialComplexTransform, \\\n", - " DimOneHodgeLaplacianDownSimplicialComplexTransform, DimOneHodgeLaplacianUpSimplicialComplexTransform, \\\n", - " DimZeroHodgeLaplacianSimplicialComplexTransform\n", - "\n", - "tr = transforms.Compose(\n", - " [SimplicialComplexTransform(), \n", - " SimplicialComplexDegreeTransform(),\n", - " SimplicialComplexEdgeCoadjacencyDegreeTransform(),\n", - " SimplicialComplexEdgeAdjacencyDegreeTransform(),\n", - " SimplicialComplexTriangleCoadjacencyDegreeTransform(),\n", - " DimZeroHodgeLaplacianSimplicialComplexTransform(),\n", - " DimOneHodgeLaplacianUpSimplicialComplexTransform(),\n", - " DimOneHodgeLaplacianDownSimplicialComplexTransform(),\n", - " DimTwoHodgeLaplacianSimplicialComplexTransform(),\n", - " OrientableToClassSimplicialComplexTransform()]\n", - " )" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-04-26T17:24:42.521090Z", - "start_time": "2024-04-26T17:24:42.398648Z" - } - }, - "id": "258c13ed23fc24a9", - "execution_count": 3 - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "dataset = SimplicialDataset(root=\"./data\", transform=tr)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-04-26T17:25:08.895803Z", - "start_time": "2024-04-26T17:25:08.759940Z" - } - }, - "id": "b7a030f18bfe7035", - "execution_count": 4 - }, - { - "cell_type": "markdown", - "source": [ - "## Train the Neural Network\n", - "We specify the model with our pre-made neighborhood structures and specify an optimizer." - ], - "metadata": { - "collapsed": false - }, - "id": "d9f15969d862d021" - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "rank = 1 # simplex level. We'll use the features of the rank-simplices.\n", - "conv_order_down = 2 # TODO: No idea of what this parameter does\n", - "conv_order_up = 2 # TODO: No idea of what this parameter does\n", - "hidden_channels = 4\n", - "out_channels = 1 # num classes\n", - "num_layers = 3\n", - "# Check the rank has an appropriate value.\n", - "assert 0 <= rank <= 2, \"rank must be 0, 1 or 2.\"\n", - "# select the simplex level\n", - "if rank == 0:\n", - " conv_order_down = 0\n", - "# configure parameters\n", - "in_channels = dataset[0].x[rank].shape[1]" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-04-26T17:25:11.543336Z", - "start_time": "2024-04-26T17:25:11.414486Z" - } - }, - "id": "4253885c6d26b76e", - "execution_count": 5 - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "class SCNNNetwork(L.LightningModule):\n", - " def __init__(self, rank, in_channels, hidden_channels, out_channels,\n", - " conv_order_down, conv_order_up, n_layers=3):\n", - " super().__init__()\n", - " self.rank = rank\n", - " self.base_model = SCNN(in_channels=in_channels,\n", - " hidden_channels=hidden_channels,\n", - " conv_order_down=conv_order_down, \n", - " conv_order_up=conv_order_up, \n", - " n_layers=n_layers)\n", - " self.liner_readout = torch.nn.Linear(hidden_channels, out_channels)\n", - " # Accuracy metrics\n", - " self.training_accuracy = torchmetrics.classification.BinaryAccuracy()\n", - " self.validation_accuracy = torchmetrics.classification.BinaryAccuracy()\n", - " self.test_accuracy = torchmetrics.classification.BinaryAccuracy()\n", - " \n", - " def forward(self, x, laplacian_down, laplacian_up, signal_belongings):\n", - " x = self.base_model(x, laplacian_down, laplacian_up)\n", - " x = self.liner_readout(x)\n", - " x_mean = pool.global_mean_pool(x, signal_belongings)\n", - " x_mean[torch.isnan(x_mean)] = 0\n", - " return x_mean\n", - " \n", - " def transfer_batch_to_device(self, batch, device, dataloader_idx):\n", - " return transfer_simplicial_complex_batch_to_device(batch, device, dataloader_idx)\n", - " \n", - " def general_step(self, batch, batch_idx, step: Literal['train', 'test', 'validation']):\n", - " s_complexes, signal_belongings, batch_len = batch\n", - " x = s_complexes.signals[self.rank]\n", - " if rank == 0:\n", - " laplacian_down = None\n", - " laplacian_up = s_complexes.neighborhood_matrices[f'0_laplacian']\n", - " elif rank == 1:\n", - " laplacian_down = s_complexes.neighborhood_matrices[f'1_laplacian_down']\n", - " laplacian_up = s_complexes.neighborhood_matrices[f'1_laplacian_up']\n", - " elif rank == 2:\n", - " laplacian_down = s_complexes.neighborhood_matrices[f'2_laplacian']\n", - " laplacian_up = None\n", - " else:\n", - " raise ValueError(\"rank must be 0, 1 or 2.\")\n", - " y = s_complexes.other_features['y'].float()\n", - " signal_belongings = signal_belongings[self.rank]\n", - " x_hat = self(x, laplacian_down, laplacian_up, signal_belongings)\n", - " # Squeeze x_hat to match the shape of y\n", - " x_hat = x_hat.squeeze()\n", - " loss = nn.functional.binary_cross_entropy_with_logits(x_hat, y)\n", - " self.log('train_loss', loss, prog_bar=True, batch_size=batch_len, on_step=False, on_epoch=True)\n", - " self.log_accuracies(x_hat, y, batch_len, step)\n", - " return loss\n", - " \n", - " def log_accuracies(self, x_hat, y, batch_len, step: Literal['train', 'test', 'validation']):\n", - " # Apply the sigmoid function to x_hat to get the probabilities\n", - " x_hat = torch.sigmoid(x_hat)\n", - " if step == 'train':\n", - " self.training_accuracy(x_hat, y)\n", - " self.log('train_accuracy', self.training_accuracy, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_len)\n", - " elif step == 'test':\n", - " self.test_accuracy(x_hat, y)\n", - " self.log('test_accuracy', self.test_accuracy, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_len)\n", - " elif step == 'validation':\n", - " self.validation_accuracy(x_hat, y)\n", - " self.log('validation_accuracy', self.validation_accuracy, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_len)\n", - " \n", - " def test_step(self, batch, batch_idx):\n", - " return self.general_step(batch, batch_idx, 'test')\n", - " \n", - " def validation_step(self, batch, batch_idx):\n", - " return self.general_step(batch, batch_idx, 'validation')\n", - " \n", - " def training_step(self, batch, batch_idx):\n", - " return self.general_step(batch, batch_idx, 'train')\n", - " \n", - " def configure_optimizers(self):\n", - " optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", - " return optimizer\n", - " " - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-04-26T17:52:24.975079Z", - "start_time": "2024-04-26T17:52:24.844516Z" - } - }, - "id": "d8a128eef9a65999", - "execution_count": 57 - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "model = SCNNNetwork(rank=rank,\n", - " in_channels=in_channels,\n", - " hidden_channels=hidden_channels,\n", - " out_channels=out_channels,\n", - " conv_order_down=conv_order_down,\n", - " conv_order_up=conv_order_up,\n", - " n_layers=num_layers)\n", - "\n", - "loss_fn = torch.nn.MSELoss()" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-04-26T17:52:25.881301Z", - "start_time": "2024-04-26T17:52:25.759380Z" - } - }, - "id": "490ce8ad6e1d4ccf", - "execution_count": 58 - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "# Split the dataset\n", - "test_percentage = 0.2\n", - "batch_size = 16\n", - "test_len = math.floor(len(dataset) * test_percentage)\n", - "train_ds, test_ds = random_split(dataset, [len(dataset) - test_len, test_len])\n", - "train_dl = SimplicialDataLoader(train_ds, batch_size=batch_size, shuffle=True)\n", - "test_dl = SimplicialDataLoader(test_ds, batch_size=batch_size, shuffle=False)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-04-26T17:52:27.137037Z", - "start_time": "2024-04-26T17:52:27.021248Z" - } - }, - "id": "bc2adabedffd109a", - "execution_count": 59 - }, - { - "cell_type": "code", - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: False\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "\n", - " | Name | Type | Params\n", - "-------------------------------------------------------\n", - "0 | base_model | SCNN | 200 \n", - "1 | liner_readout | Linear | 5 \n", - "2 | training_accuracy | BinaryAccuracy | 0 \n", - "3 | validation_accuracy | BinaryAccuracy | 0 \n", - "4 | test_accuracy | BinaryAccuracy | 0 \n", - "-------------------------------------------------------\n", - "205 Trainable params\n", - "0 Non-trainable params\n", - "205 Total params\n", - "0.001 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "text/plain": "Sanity Checking: | | 0/? [00:00 29\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", - "Cell \u001b[1;32mIn[21], line 20\u001b[0m, in \u001b[0;36mtest\u001b[1;34m(loader)\u001b[0m\n\u001b[0;32m 18\u001b[0m correct \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m loader: \u001b[38;5;66;03m# Iterate in batches over the training/test dataset.\u001b[39;00m\n\u001b[1;32m---> 20\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m \n\u001b[0;32m 21\u001b[0m pred \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Use the class with highest probability.\u001b[39;00m\n\u001b[0;32m 22\u001b[0m correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m((pred \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my)\u001b[38;5;241m.\u001b[39msum()) \u001b[38;5;66;03m# Check against ground-truth labels.\u001b[39;00m\n", + "Cell \u001b[1;32mIn[10], line 29\u001b[0m\n\u001b[0;32m 27\u001b[0m train()\n\u001b[0;32m 28\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m test(train_loader)\n\u001b[1;32m---> 29\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", + "Cell \u001b[1;32mIn[10], line 20\u001b[0m, in \u001b[0;36mtest\u001b[1;34m(loader)\u001b[0m\n\u001b[0;32m 18\u001b[0m correct \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m loader: \u001b[38;5;66;03m# Iterate in batches over the training/test dataset.\u001b[39;00m\n\u001b[1;32m---> 20\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m \n\u001b[0;32m 21\u001b[0m pred \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Use the class with highest probability.\u001b[39;00m\n\u001b[0;32m 22\u001b[0m correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m((pred \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my)\u001b[38;5;241m.\u001b[39msum()) \u001b[38;5;66;03m# Check against ground-truth labels.\u001b[39;00m\n", "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "Cell \u001b[1;32mIn[20], line 14\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, edge_attr, batch)\u001b[0m\n\u001b[0;32m 12\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv1(x, edge_index,edge_attr)\n\u001b[0;32m 13\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[1;32m---> 14\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 15\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[0;32m 16\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv3(x, edge_index,edge_attr)\n", + "Cell \u001b[1;32mIn[7], line 14\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, edge_attr, batch)\u001b[0m\n\u001b[0;32m 12\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv1(x, edge_index,edge_attr)\n\u001b[0;32m 13\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[1;32m---> 14\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 15\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[0;32m 16\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv3(x, edge_index,edge_attr)\n", "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gat_conv.py:324\u001b[0m, in \u001b[0;36mGATConv.forward\u001b[1;34m(self, x, edge_index, edge_attr, size, return_attention_weights)\u001b[0m\n\u001b[0;32m 321\u001b[0m num_nodes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(size) \u001b[38;5;28;01mif\u001b[39;00m size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m num_nodes\n\u001b[0;32m 322\u001b[0m edge_index, edge_attr \u001b[38;5;241m=\u001b[39m remove_self_loops(\n\u001b[0;32m 323\u001b[0m edge_index, edge_attr)\n\u001b[1;32m--> 324\u001b[0m edge_index, edge_attr \u001b[38;5;241m=\u001b[39m \u001b[43madd_self_loops\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 325\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 326\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 327\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(edge_index, SparseTensor):\n\u001b[0;32m 328\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39medge_dim \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\loop.py:488\u001b[0m, in \u001b[0;36madd_self_loops\u001b[1;34m(edge_index, edge_attr, fill_value, num_nodes)\u001b[0m\n\u001b[0;32m 485\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected sparse tensor layout (got \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlayout\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 487\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m edge_attr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 488\u001b[0m loop_attr \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_loop_attr\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m#\u001b[39;49;00m\n\u001b[0;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_sparse\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 490\u001b[0m edge_attr \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([edge_attr, loop_attr], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m 492\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m full_edge_index, edge_attr\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\loop.py:767\u001b[0m, in \u001b[0;36mcompute_loop_attr\u001b[1;34m(edge_index, edge_attr, num_nodes, is_sparse, fill_value)\u001b[0m\n\u001b[0;32m 765\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(fill_value, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m 766\u001b[0m col \u001b[38;5;241m=\u001b[39m edge_index[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m is_sparse \u001b[38;5;28;01melse\u001b[39;00m edge_index[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m--> 767\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mscatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 769\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo valid \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfill_value\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m provided\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\_scatter.py:83\u001b[0m, in \u001b[0;36mscatter\u001b[1;34m(src, index, dim, dim_size, reduce)\u001b[0m\n\u001b[0;32m 80\u001b[0m count \u001b[38;5;241m=\u001b[39m count\u001b[38;5;241m.\u001b[39mclamp(\u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m 82\u001b[0m index \u001b[38;5;241m=\u001b[39m broadcast(index, src, dim)\n\u001b[1;32m---> 83\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43msrc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnew_zeros\u001b[49m\u001b[43m(\u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscatter_add_\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msrc\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 85\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out \u001b[38;5;241m/\u001b[39m broadcast(count, out, dim)\n\u001b[0;32m 87\u001b[0m \u001b[38;5;66;03m# For \"min\" and \"max\" reduction, we prefer `scatter_reduce_` on CPU or\u001b[39;00m\n\u001b[0;32m 88\u001b[0m \u001b[38;5;66;03m# in case the input does not require gradients:\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gat_conv.py:239\u001b[0m, in \u001b[0;36mGATConv.forward\u001b[1;34m(self, x, edge_index, edge_attr, size, return_attention_weights)\u001b[0m\n\u001b[0;32m 236\u001b[0m num_nodes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(size) \u001b[38;5;28;01mif\u001b[39;00m size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m num_nodes\n\u001b[0;32m 237\u001b[0m edge_index, edge_attr \u001b[38;5;241m=\u001b[39m remove_self_loops(\n\u001b[0;32m 238\u001b[0m edge_index, edge_attr)\n\u001b[1;32m--> 239\u001b[0m edge_index, edge_attr \u001b[38;5;241m=\u001b[39m \u001b[43madd_self_loops\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 240\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 241\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 242\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(edge_index, SparseTensor):\n\u001b[0;32m 243\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39medge_dim \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\loop.py:263\u001b[0m, in \u001b[0;36madd_self_loops\u001b[1;34m(edge_index, edge_attr, fill_value, num_nodes)\u001b[0m\n\u001b[0;32m 261\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(fill_value, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m 262\u001b[0m col \u001b[38;5;241m=\u001b[39m edge_index[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m is_sparse \u001b[38;5;28;01melse\u001b[39;00m edge_index[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m--> 263\u001b[0m loop_attr \u001b[38;5;241m=\u001b[39m \u001b[43mscatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 264\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo valid \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfill_value\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m provided\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } @@ -439,32 +446,42 @@ "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "\n", + "\n", "def train():\n", " model.train()\n", "\n", " for data in train_loader: # Iterate in batches over the training dataset.\n", - " out = model(data.x, data.edge_index, data.edge_attr, data.batch) # Perform a single forward pass.\n", - " loss = criterion(out, data.y) # Compute the loss.\n", - " loss.backward() # Derive gradients.\n", - " optimizer.step() # Update parameters based on gradients.\n", - " optimizer.zero_grad() # Clear gradients.\n", + " out = model(\n", + " data.x, data.edge_index, data.edge_attr, data.batch\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", "\n", "def test(loader):\n", - " model.eval()\n", + " model.eval()\n", "\n", - " correct = 0\n", - " for data in loader: # Iterate in batches over the training/test dataset.\n", - " out = model(data.x, data.edge_index, data.edge_attr,data.batch) \n", - " pred = out.argmax(dim=1) # Use the class with highest probability.\n", - " correct += int((pred == data.y).sum()) # Check against ground-truth labels.\n", - " return correct / len(loader.dataset) # Derive ratio of correct predictions.\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data.x, data.edge_index, data.edge_attr, data.batch)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", "\n", "\n", "for epoch in range(1, 171):\n", " train()\n", " train_acc = test(train_loader)\n", " test_acc = test(test_loader)\n", - " print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')" + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" ] }, { diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/lightning_modules/BaseModuleOrientability.py b/experiments/lightning_modules/BaseModuleOrientability.py new file mode 100644 index 0000000..a49a9fa --- /dev/null +++ b/experiments/lightning_modules/BaseModuleOrientability.py @@ -0,0 +1,59 @@ +from typing import Literal + +import lightning as L +import torch +import torchmetrics + + +class BaseOrientabilityModule(L.LightningModule): + def __init__(self): + super().__init__() + # Accuracy metrics + self.training_accuracy = torchmetrics.classification.BinaryAccuracy() + self.validation_accuracy = torchmetrics.classification.BinaryAccuracy() + self.test_accuracy = torchmetrics.classification.BinaryAccuracy() + + def log_accuracies( + self, x_hat, y, batch_len, step: Literal["train", "test", "validation"] + ): + # Apply the sigmoid function to x_hat to get the probabilities + x_hat = torch.sigmoid(x_hat) + if step == "train": + self.training_accuracy(x_hat, y) + self.log( + "train_accuracy", + self.training_accuracy, + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch_len, + ) + elif step == "test": + self.test_accuracy(x_hat, y) + self.log( + "test_accuracy", + self.test_accuracy, + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch_len, + ) + elif step == "validation": + self.validation_accuracy(x_hat, y) + self.log( + "validation_accuracy", + self.validation_accuracy, + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch_len, + ) + + def test_step(self, batch, batch_idx): + return self.general_step(batch, batch_idx, "test") + + def validation_step(self, batch, batch_idx): + return self.general_step(batch, batch_idx, "validation") + + def training_step(self, batch, batch_idx): + return self.general_step(batch, batch_idx, "train") diff --git a/experiments/lightning_modules/GraphCommonModuleOrientability.py b/experiments/lightning_modules/GraphCommonModuleOrientability.py new file mode 100644 index 0000000..03ceecb --- /dev/null +++ b/experiments/lightning_modules/GraphCommonModuleOrientability.py @@ -0,0 +1,33 @@ +from torch import nn + +from experiments.lightning_modules.BaseModuleOrientability import ( + BaseOrientabilityModule, +) + + +class GraphCommonModuleOrientability(BaseOrientabilityModule): + def __init__(self, base_model): + super().__init__() + self.base_model = base_model + + def forward(self, x, edge_index, batch): + x = self.base_model(x, edge_index, batch) + return x + + def general_step(self, batch, batch_idx, step: str): + batch_len = len(batch.y) + x_hat = self(batch.x, batch.edge_index, batch.batch) + # Squeeze x_hat to match the shape of y + x_hat = x_hat.squeeze() + y = batch.y.float() + loss = nn.functional.binary_cross_entropy_with_logits(x_hat, y) + self.log( + f"{step}_loss", + loss, + prog_bar=True, + batch_size=batch_len, + on_step=False, + on_epoch=True, + ) + self.log_accuracies(x_hat, y, batch_len, step) + return loss diff --git a/experiments/lightning_modules/__init__.py b/experiments/lightning_modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/orientability/__init__.py b/experiments/orientability/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/orientability/graphs/GATSimplex2Vec.py b/experiments/orientability/graphs/GATSimplex2Vec.py new file mode 100644 index 0000000..b4c18bc --- /dev/null +++ b/experiments/orientability/graphs/GATSimplex2Vec.py @@ -0,0 +1,87 @@ +import lightning as L +import torch +import torchvision.transforms as transforms +from torch.utils.data import Subset +from torch_geometric.data import DataLoader +from torch_geometric.transforms import FaceToEdge + +from experiments.lightning_modules.GraphCommonModuleOrientability import ( + GraphCommonModuleOrientability, +) +from mantra.simplicial import SimplicialDataset +from mantra.transforms import ( + TriangulationToFaceTransform, + SetNumNodesTransform, + DegreeTransform, + OrientableToClassTransform, + Simplex2VecTransform, +) +from models.graphs.GAT import GATNetwork + + +class GATSimplexToVecModule(GraphCommonModuleOrientability): + def __init__( + self, + hidden_channels, + num_node_features, + out_channels, + num_heads, + num_hidden_layers, + learning_rate=0.0001, + ): + base_model = GATNetwork( + hidden_channels=hidden_channels, + num_node_features=num_node_features, + out_channels=out_channels, + num_heads=num_heads, + num_hidden_layers=num_hidden_layers, + ) + super().__init__(base_model=base_model) + self.learning_rate = learning_rate + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +def load_dataset_with_transformations(): + tr = transforms.Compose( + [ + TriangulationToFaceTransform(), + SetNumNodesTransform(), + FaceToEdge(remove_faces=False), + DegreeTransform(), + OrientableToClassTransform(), + Simplex2VecTransform(), + ] + ) + dataset = SimplicialDataset(root="./data", transform=tr) + return dataset + + +def single_experiment_orientability_gat_simplex2vec(): + # =============================== + # Training parameters + # =============================== + hidden_channels = 64 + num_hidden_layers = 2 + num_heads = 4 + batch_size = 32 + max_epochs = 100 + learning_rate = 0.0001 + # =============================== + dataset = load_dataset_with_transformations() + model = GATSimplexToVecModule( + hidden_channels=hidden_channels, + num_node_features=dataset.num_node_features, + out_channels=1, + num_heads=num_heads, + num_hidden_layers=num_hidden_layers, + learning_rate=learning_rate, + ) + train_ds = Subset(dataset, dataset.train_orientability_indices) + test_ds = Subset(dataset, dataset.test_orientability_indices) + train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True) + test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) + trainer = L.Trainer(max_epochs=max_epochs, log_every_n_steps=1) + trainer.fit(model, train_dl, test_dl) diff --git a/experiments/orientability/graphs/GCN.py b/experiments/orientability/graphs/GCN.py new file mode 100644 index 0000000..6087234 --- /dev/null +++ b/experiments/orientability/graphs/GCN.py @@ -0,0 +1,81 @@ +import lightning as L +import torch +import torchvision.transforms as transforms +from torch.utils.data import Subset +from torch_geometric.loader import DataLoader +from torch_geometric.transforms import FaceToEdge + +from experiments.lightning_modules.GraphCommonModuleOrientability import ( + GraphCommonModuleOrientability, +) +from mantra.simplicial import SimplicialDataset +from mantra.transforms import ( + TriangulationToFaceTransform, + DegreeTransform, + OrientableToClassTransform, +) +from models.graphs.GCN import GCNetwork + + +class GCNModule(GraphCommonModuleOrientability): + def __init__( + self, + hidden_channels, + num_node_features, + out_channels, + num_hidden_layers, + learning_rate=0.01, + ): + base_model = GCNetwork( + hidden_channels=hidden_channels, + num_node_features=num_node_features, + out_channels=out_channels, + num_hidden_layers=num_hidden_layers, + ) + super().__init__(base_model=base_model) + self.learning_rate = learning_rate + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.base_model.parameters(), lr=self.learning_rate + ) + return optimizer + + +def load_dataset_with_transformations(): + tr = transforms.Compose( + [ + TriangulationToFaceTransform(), + FaceToEdge(remove_faces=False), + DegreeTransform(), + OrientableToClassTransform(), + ] + ) + dataset = SimplicialDataset(root="./data", transform=tr) + return dataset + + +def single_experiment_orientability_gnn(): + # =============================== + # Training parameters + # =============================== + hidden_channels = 64 + num_hidden_layers = 2 + batch_size = 32 + max_epochs = 100 + learning_rate = 0.1 + # =============================== + dataset = load_dataset_with_transformations() + model = GCNModule( + hidden_channels=hidden_channels, + num_node_features=dataset.num_node_features, + out_channels=1, # Binary classification + num_hidden_layers=num_hidden_layers, + learning_rate=learning_rate, + ) + train_ds = Subset(dataset, dataset.train_orientability_indices) + test_ds = Subset(dataset, dataset.test_orientability_indices) + train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True) + test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) + trainer = L.Trainer(max_epochs=max_epochs, log_every_n_steps=1) + trainer.fit(model, train_dl, test_dl) diff --git a/experiments/orientability/graphs/__init__.py b/experiments/orientability/graphs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/orientability/simplicial_complexes/SCNN.py b/experiments/orientability/simplicial_complexes/SCNN.py new file mode 100644 index 0000000..12c60bc --- /dev/null +++ b/experiments/orientability/simplicial_complexes/SCNN.py @@ -0,0 +1,169 @@ +from typing import Literal + +import lightning as L +import torch +import torchvision.transforms as transforms +from torch import nn +from torch.utils.data import Subset + +from experiments.lightning_modules.BaseModuleOrientability import ( + BaseOrientabilityModule, +) +from mantra.dataloaders import SimplicialDataLoader +from mantra.simplicial import SimplicialDataset +from mantra.transforms import ( + OrientableToClassSimplicialComplexTransform, + DimTwoHodgeLaplacianSimplicialComplexTransform, + DimOneHodgeLaplacianDownSimplicialComplexTransform, + DimOneHodgeLaplacianUpSimplicialComplexTransform, + DimZeroHodgeLaplacianSimplicialComplexTransform, + SimplicialComplexOnesTransform, +) +from mantra.transforms import SimplicialComplexTransform +from mantra.utils import transfer_simplicial_complex_batch_to_device +from models.simplicial_complexes.SCNN import SCNNNetwork + + +class SCNNNModule(BaseOrientabilityModule): + def __init__( + self, + rank, + in_channels, + hidden_channels, + out_channels, + conv_order_down, + conv_order_up, + n_layers=3, + learning_rate=0.01, + ): + super().__init__() + self.rank = rank + self.learning_rate = learning_rate + self.base_model = SCNNNetwork( + rank=rank, + in_channels=in_channels, + hidden_channels=hidden_channels, + out_channels=out_channels, + conv_order_down=conv_order_down, + conv_order_up=conv_order_up, + n_layers=n_layers, + ) + + def forward(self, x, laplacian_down, laplacian_up, signal_belongings): + x = self.base_model(x, laplacian_down, laplacian_up, signal_belongings) + return x + + def general_step( + self, batch, batch_idx, step: Literal["train", "test", "validation"] + ): + s_complexes, signal_belongings, batch_len = batch + x = s_complexes.signals[self.rank] + if self.rank == 0: + laplacian_down = None + laplacian_up = s_complexes.neighborhood_matrices[f"0_laplacian"] + elif self.rank == 1: + laplacian_down = s_complexes.neighborhood_matrices[ + f"1_laplacian_down" + ] + laplacian_up = s_complexes.neighborhood_matrices[f"1_laplacian_up"] + elif self.rank == 2: + laplacian_down = s_complexes.neighborhood_matrices[f"2_laplacian"] + laplacian_up = None + else: + raise ValueError("rank must be 0, 1 or 2.") + y = s_complexes.other_features["y"].float() + signal_belongings = signal_belongings[self.rank] + x_hat = self(x, laplacian_down, laplacian_up, signal_belongings) + # Squeeze x_hat to match the shape of y + x_hat = x_hat.squeeze() + loss = nn.functional.binary_cross_entropy_with_logits(x_hat, y) + self.log( + f"{step}_loss", + loss, + prog_bar=True, + batch_size=batch_len, + on_step=False, + on_epoch=True, + ) + self.log_accuracies(x_hat, y, batch_len, step) + return loss + + def transfer_batch_to_device(self, batch, device, dataloader_idx): + return transfer_simplicial_complex_batch_to_device( + batch, device, dataloader_idx + ) + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.base_model.parameters(), lr=self.learning_rate + ) + return optimizer + + +def load_dataset_with_transformations(): + tr = transforms.Compose( + [ + SimplicialComplexTransform(), + SimplicialComplexOnesTransform(ones_length=10), + DimZeroHodgeLaplacianSimplicialComplexTransform(), + DimOneHodgeLaplacianUpSimplicialComplexTransform(), + DimOneHodgeLaplacianDownSimplicialComplexTransform(), + DimTwoHodgeLaplacianSimplicialComplexTransform(), + OrientableToClassSimplicialComplexTransform(), + ] + ) + dataset = SimplicialDataset(root="./data", transform=tr) + return dataset + + +def single_experiment_orientability_scnn(): + dataset = load_dataset_with_transformations() + # =============================== + # Training parameters + # =============================== + rank = 1 # We work with edge features + conv_order_down = 2 # TODO: No idea of what this parameter does + conv_order_up = 2 # TODO: No idea of what this parameter does + hidden_channels = 20 + out_channels = 1 # num classes + num_layers = 5 + batch_size = 128 + max_epochs = 100 + learning_rate = 0.01 + # =============================== + # Checks and dependent parameters + # =============================== + # Check the rank has an appropriate value. + assert 0 <= rank <= 2, "rank must be 0, 1 or 2." + # select the simplex level + if rank == 0: + conv_order_down = 0 + # configure parameters + in_channels = dataset[0].x[rank].shape[1] + # =============================== + # Model and dataloader creation + # =============================== + model = SCNNNModule( + rank=rank, + in_channels=in_channels, + hidden_channels=hidden_channels, + out_channels=out_channels, + conv_order_down=conv_order_down, + conv_order_up=conv_order_up, + n_layers=num_layers, + learning_rate=learning_rate, + ) + train_ds = Subset(dataset, dataset.train_orientability_indices) + test_ds = Subset(dataset, dataset.test_orientability_indices) + train_dl = SimplicialDataLoader( + train_ds, batch_size=batch_size, shuffle=True + ) + test_dl = SimplicialDataLoader( + test_ds, batch_size=batch_size, shuffle=False + ) + # Use CPU acceleration: SCCNN does not support GPU acceleration because it creates matrices not placed in the + # device of the network. + trainer = L.Trainer( + max_epochs=max_epochs, accelerator="cpu", log_every_n_steps=1 + ) + trainer.fit(model, train_dl, test_dl) diff --git a/experiments/orientability/simplicial_complexes/__init__.py b/experiments/orientability/simplicial_complexes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/k_simplex2vec.py b/k_simplex2vec.py index 6397675..8e15498 100644 --- a/k_simplex2vec.py +++ b/k_simplex2vec.py @@ -1,10 +1,7 @@ -import gudhi import numpy as np -from numpy import matlib -import random -import gensim -from gensim.models import Word2Vec import scipy.sparse +from gensim.models import Word2Vec + # Some function that will be useful for the rest of the code diff --git a/main.py b/main.py new file mode 100644 index 0000000..7660773 --- /dev/null +++ b/main.py @@ -0,0 +1,14 @@ +from experiments.orientability.graphs.GATSimplex2Vec import ( + single_experiment_orientability_gat_simplex2vec, +) +from experiments.orientability.graphs.GCN import ( + single_experiment_orientability_gnn, +) +from experiments.orientability.simplicial_complexes.SCNN import ( + single_experiment_orientability_scnn, +) + +if __name__ == "__main__": + single_experiment_orientability_gnn() + single_experiment_orientability_gat_simplex2vec() + single_experiment_orientability_scnn() diff --git a/mantra/convert.py b/mantra/convert.py index 0542273..19e3320 100644 --- a/mantra/convert.py +++ b/mantra/convert.py @@ -14,10 +14,11 @@ import argparse import json import re -import pydantic from typing import List, Optional, Dict -import pandas as pd + import numpy as np +import pandas as pd +import pydantic class Triangulation(pydantic.BaseModel): @@ -186,7 +187,6 @@ def process_manifolds( filename_homology: str | None = None, filename_type: str | None = None, ) -> List[Dict]: - homology_groups, types = {}, {} # Parse triangulations @@ -247,3 +247,19 @@ def process_manifolds( # f.write(result) # else: # print(result) + + +def process_train_test_split_orientability( + filename_train_test_split_orientability: str, +): + with open(filename_train_test_split_orientability, "r") as f: + indices_raw = f.readlines() + for indices_line in indices_raw: + type_indices, indices = indices_line.split(":") + if type_indices == "Train indices": + train_indices = [int(idx) for idx in indices.strip().split(" ")] + elif type_indices == "Test indices": + test_indices = [int(idx) for idx in indices.strip().split(" ")] + else: + raise NotImplementedError("Unknown type of indices") + return np.array(train_indices), np.array(test_indices) diff --git a/mantra/dataloaders.py b/mantra/dataloaders.py index bfc4eed..85d818a 100644 --- a/mantra/dataloaders.py +++ b/mantra/dataloaders.py @@ -3,7 +3,6 @@ import numpy as np import scipy import torch.utils.data - from torch.utils.data.dataloader import DataLoader diff --git a/mantra/download_data.py b/mantra/download_data.py index 19567ca..a8c1352 100644 --- a/mantra/download_data.py +++ b/mantra/download_data.py @@ -6,7 +6,6 @@ def download_data(): - response = requests.get( "https://www3.math.tu-berlin.de/IfM/Nachrufe/Frank_Lutz/stellar/2_manifolds_all.txt" ) diff --git a/mantra/generation.py b/mantra/generation.py new file mode 100644 index 0000000..4fe0461 --- /dev/null +++ b/mantra/generation.py @@ -0,0 +1,42 @@ +import numpy as np +from sklearn.model_selection import train_test_split + +from mantra.convert import process_manifolds + + +def generate_random_split( + all_dataset_triangulations_path, + all_dataset_homology_path, + all_dataset_type_path, + test_size=0.2, + output_filename="./data/train_test_split_orientability.txt", +): + processed_manifolds = process_manifolds( + all_dataset_triangulations_path, + all_dataset_homology_path, + all_dataset_type_path, + ) + indices_dataset = np.arange(len(processed_manifolds)) + orientability_labels = np.array( + [int(manifold["orientable"]) for manifold in processed_manifolds] + ) + X_train, X_test = train_test_split( + indices_dataset, + test_size=test_size, + shuffle=True, + stratify=orientability_labels, + ) + # Create a txt file with the indices of the train and test set + with open(output_filename, "w") as f: + f.write("Train indices: " + " ".join(map(str, X_train)) + "\n") + f.write("Test indices: " + " ".join(map(str, X_test)) + "\n") + + +if __name__ == "__main__": + generate_random_split( + all_dataset_triangulations_path="../data/simplicial_v1.0.0/raw/2_manifolds_all.txt", + all_dataset_homology_path="../data/simplicial_v1.0.0/raw/2_manifolds_all_hom.txt", + all_dataset_type_path="../data/simplicial_v1.0.0/raw/2_manifolds_all_type.txt", + test_size=0.2, + output_filename="../data/train_test_split_orientability.txt", + ) diff --git a/mantra/simplicial.py b/mantra/simplicial.py index a2c415e..38d7c2d 100644 --- a/mantra/simplicial.py +++ b/mantra/simplicial.py @@ -4,21 +4,18 @@ conjunction to dataloaders. """ +import torch from torch_geometric.data import InMemoryDataset, download_url, Data -from mantra.convert import process_manifolds + +from mantra.convert import ( + process_manifolds, + process_train_test_split_orientability, +) class SimplicialDataset(InMemoryDataset): available_versions = ["1.0.0"] - @staticmethod - def _get_raw_dataset_root_link(version: str): - match version: - case "1.0.0": - return "https://www3.math.tu-berlin.de/IfM/Nachrufe/Frank_Lutz/stellar" - case _: - raise ValueError(f"Version {version} not available") - def __init__( self, root, @@ -33,6 +30,8 @@ def __init__( root += f"/simplicial_v{version}" super().__init__(root, transform, pre_transform, pre_filter) self.load(self.processed_paths[0]) + self.train_orientability_indices = torch.load(self.processed_paths[1]) + self.test_orientability_indices = torch.load(self.processed_paths[2]) @property def raw_file_names(self): @@ -40,19 +39,38 @@ def raw_file_names(self): f"{self.manifold}_manifolds_all.txt", f"{self.manifold}_manifolds_all_type.txt", f"{self.manifold}_manifolds_all_hom.txt", + f"{self.manifold}_manifolds_all_train_test_split_orientability.txt", ] @property def processed_file_names(self): - return ["data.pt"] + return [ + "data.pt", + "train_orientability_indices.pt", + "test_orientability_indices.pt", + ] + + def _get_download_links(self, version: str): + match version: + case "1.0.0": + root_manifolds = "https://www3.math.tu-berlin.de/IfM/Nachrufe/Frank_Lutz/stellar" + manifolds_files = [ + f"{root_manifolds}/{name}" + for name in self.raw_file_names + if name + != f"{self.manifold}_manifolds_all_train_test_split_orientability.txt" + ] + orientability_indices_file = [ + "https://rubenbb.com/assets/MANTRA/v1.0.0/2_manifolds_all_train_test_split_orientability.txt" + ] + return manifolds_files + orientability_indices_file + case _: + raise ValueError(f"Version {version} not available") def download(self): - root_link = self._get_raw_dataset_root_link(self.version) - for name in self.raw_file_names: - download_url( - f"{root_link}/{name}", - self.raw_dir, - ) + download_links = self._get_download_links(self.version) + for download_link in download_links: + download_url(download_link, self.raw_dir) def process(self): triangulations = process_manifolds( @@ -69,4 +87,11 @@ def process(self): if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] + orien_train_indices, orien_test_indices = ( + process_train_test_split_orientability( + f"{self.raw_dir}/{self.manifold}_manifolds_all_train_test_split_orientability.txt" + ) + ) self.save(data_list, self.processed_paths[0]) + torch.save(orien_train_indices, self.processed_paths[1]) + torch.save(orien_test_indices, self.processed_paths[2]) diff --git a/mantra/transforms.py b/mantra/transforms.py index c8d2dc3..9be1435 100644 --- a/mantra/transforms.py +++ b/mantra/transforms.py @@ -1,8 +1,11 @@ +import gudhi import numpy as np -from torch_geometric.utils import degree -from toponetx.classes import SimplicialComplex import torch +from toponetx.classes import SimplicialComplex +from torch_geometric.transforms import ToUndirected +from torch_geometric.utils import degree +import k_simplex2vec as ks2v from mantra.utils import ( create_signals_on_data_if_needed, append_signals, @@ -11,9 +14,60 @@ ) +class SetNumNodesTransform(object): + def __call__(self, data): + data.num_nodes = data.n_vertices + return data + + +class Simplex2VecTransform(object): + def __call__(self, data): + st = gudhi.SimplexTree() + + ei = [ + [edge[0], edge[1]] + for edge in data.edge_index.T.tolist() + if edge[0] < edge[1] + ] + data.edge_index = torch.tensor(ei).T + # Say hi to bad programming + for edge in ei: + st.insert(edge) + st.expansion(3) + + p1 = ks2v.assemble(cplx=st, k=1, scheme="uniform", laziness=None) + P1 = p1.toarray() + + Simplices = list() + for simplex in st.get_filtration(): + if simplex[1] != np.inf: + Simplices.append(simplex[0]) + else: + break + + ## Perform random walks on the edges + L = 20 + N = 40 + Walks = ks2v.RandomWalks(walk_length=L, number_walks=N, P=P1, seed=3) + # to save the walks in a text file + ks2v.save_random_walks(Walks, "RandomWalks_Edges.txt") + + ## Embed the edges + Emb = ks2v.Embedding( + Walks=Walks, + emb_dim=20, + epochs=5, + filename="k-simplex2vec_Edge_embedding.model", + ) + data.edge_attr = torch.tensor(Emb.wv.vectors) + toundirected = ToUndirected() + data = toundirected(data) + return data + + class OrientableToClassTransform(object): def __call__(self, data): - data.y = data.orientable.long() + data.y = data.orientable return data @@ -50,6 +104,18 @@ def __call__(self, data): return data +class SimplicialComplexOnesTransform(object): + def __init__(self, ones_length=10): + self.ones_length = ones_length + + def __call__(self, data): + data = create_signals_on_data_if_needed(data) + for dim in range(len(data.sc.shape)): + ones_signals = torch.ones(data.sc.shape[dim], self.ones_length) + data = append_signals(data, dim, ones_signals) + return data + + class SimplicialComplexEdgeCoadjacencyDegreeTransform(object): def __call__(self, data): data = create_signals_on_data_if_needed(data) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/graphs/GAT.py b/models/graphs/GAT.py new file mode 100644 index 0000000..ea59a01 --- /dev/null +++ b/models/graphs/GAT.py @@ -0,0 +1,41 @@ +import torch.nn.functional as F +from torch import nn +from torch_geometric.nn import GATConv, global_mean_pool + + +class GATNetwork(nn.Module): + def __init__( + self, + hidden_channels, + num_node_features, + out_channels, + num_heads, + num_hidden_layers, + ): + super().__init__() + hidden_channels_per_head = hidden_channels // num_heads + self.gat_input = GATConv( + num_node_features, hidden_channels_per_head, heads=num_heads + ) + self.hidden_layers = nn.ModuleList( + [ + GATConv( + hidden_channels, hidden_channels_per_head, heads=num_heads + ) + for _ in range(num_hidden_layers) + ] + ) + self.final_linear = nn.Linear(hidden_channels, out_channels) + + def forward(self, x, edge_index, batch): + # 1. Obtain node embeddings + x = self.gat_input(x, edge_index) + for layer in self.hidden_layers: + x = layer(x, edge_index) + # 2. Readout layer + x = global_mean_pool(x, batch) # [batch_size, hidden_channels] + + # 3. Apply a final classifier + x = F.dropout(x, p=0.5, training=self.training) + x = self.final_linear(x) + return x diff --git a/models/graphs/GCN.py b/models/graphs/GCN.py new file mode 100644 index 0000000..95ac44e --- /dev/null +++ b/models/graphs/GCN.py @@ -0,0 +1,35 @@ +import torch.nn.functional as F +from torch import nn +from torch_geometric.nn import GCNConv +from torch_geometric.nn import global_mean_pool + + +class GCNetwork(nn.Module): + def __init__( + self, + hidden_channels, + num_node_features, + out_channels, + num_hidden_layers, + ): + super().__init__() + self.conv_input = GCNConv(num_node_features, hidden_channels) + self.hidden_layers = nn.ModuleList( + [ + GCNConv(hidden_channels, hidden_channels) + for _ in range(num_hidden_layers) + ] + ) + self.final_linear = nn.Linear(hidden_channels, out_channels) + + def forward(self, x, edge_index, batch): + # 1. Obtain node embeddings + x = self.conv_input(x, edge_index) + for layer in self.hidden_layers: + x = layer(x, edge_index) + # 2. Readout layer + x = global_mean_pool(x, batch) # [batch_size, hidden_channels] + # 3. Apply a final classifier + x = F.dropout(x, p=0.5, training=self.training) + x = self.final_linear(x) + return x diff --git a/models/graphs/__init__.py b/models/graphs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/simplicial_complexes/SCNN.py b/models/simplicial_complexes/SCNN.py new file mode 100644 index 0000000..3700cca --- /dev/null +++ b/models/simplicial_complexes/SCNN.py @@ -0,0 +1,34 @@ +import torch +from topomodelx.nn.simplicial.scnn import SCNN +from torch import nn +from torch_geometric.nn import pool + + +class SCNNNetwork(nn.Module): + def __init__( + self, + rank, + in_channels, + hidden_channels, + out_channels, + conv_order_down, + conv_order_up, + n_layers=3, + ): + super().__init__() + self.rank = rank + self.base_model = SCNN( + in_channels=in_channels, + hidden_channels=hidden_channels, + conv_order_down=conv_order_down, + conv_order_up=conv_order_up, + n_layers=n_layers, + ) + self.liner_readout = torch.nn.Linear(hidden_channels, out_channels) + + def forward(self, x, laplacian_down, laplacian_up, signal_belongings): + x = self.base_model(x, laplacian_down, laplacian_up) + x = self.liner_readout(x) + x_mean = pool.global_mean_pool(x, signal_belongings) + x_mean[torch.isnan(x_mean)] = 0 + return x_mean diff --git a/models/simplicial_complexes/__init__.py b/models/simplicial_complexes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 0112e56..1bea603 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,20 +6,20 @@ build-backend = "hatchling.build" name = "mantra" version = "0.0.1" dependencies = [ - "gudhi", - "lightning" + "gudhi", + "lightning" ] requires-python = ">=3.8" authors = [ - {name = "Bastian Rieck", email = "bastian.rieck@helmholtz-munich.de"}, + { name = "Bastian Rieck", email = "bastian.rieck@helmholtz-munich.de" }, ] description = "Manifold Triangulations" readme = "README.md" -license = {file = "LICENSE.md"} +license = { file = "LICENSE.md" } classifiers = [ - "Development Status :: 4 - Beta", - "Programming Language :: Python", - "License :: OSI Approved :: BSD-3-Clause", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "License :: OSI Approved :: BSD-3-Clause", ] [project.urls] diff --git a/tests/test_homology.py b/tests/test_homology.py index db263c1..654aec8 100644 --- a/tests/test_homology.py +++ b/tests/test_homology.py @@ -5,12 +5,12 @@ from the triangulation and its homology is calculated. """ -import gudhi as gd - import argparse import itertools import json +import gudhi as gd + def build_simplex_tree(top_level_simplices): simplices = set([tuple(s) for s in top_level_simplices]) diff --git a/validation/validate_homology.py b/validation/validate_homology.py index adb09b0..6ad23b8 100644 --- a/validation/validate_homology.py +++ b/validation/validate_homology.py @@ -5,12 +5,12 @@ homology is calculated. """ -import gudhi as gd - import argparse import itertools import json +import gudhi as gd + def build_simplex_tree(top_level_simplices): simplices = set([tuple(s) for s in top_level_simplices])