Skip to content

Commit

Permalink
allsettransformer once more
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed Nov 7, 2023
1 parent b745e70 commit 751aef8
Showing 1 changed file with 20 additions and 32 deletions.
52 changes: 20 additions & 32 deletions tutorials/hypergraph/allset_transformer_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-01T16:14:51.222779223Z",
Expand Down Expand Up @@ -110,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-01T16:14:51.959770754Z",
Expand Down Expand Up @@ -145,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-01T16:14:53.022151550Z",
Expand Down Expand Up @@ -189,7 +189,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-01T16:14:53.022151550Z",
Expand Down Expand Up @@ -230,7 +230,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -278,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -311,12 +311,12 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class Network(torch.nn.Module):\n",
" \"\"\"Network class that initializes the base model and readout layer.\n",
" \"\"\"Network class that initializes the AllSet model and readout layer.\n",
"\n",
" Base model parameters:\n",
" ----------\n",
Expand Down Expand Up @@ -374,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -395,7 +395,6 @@
" in_channels=in_channels,\n",
" hidden_channels=hidden_channels,\n",
" out_channels=out_channels,\n",
" heads=heads,\n",
" n_layers=n_layers,\n",
" mlp_num_layers=mlp_num_layers,\n",
" task_level=task_level,\n",
Expand All @@ -414,7 +413,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -430,7 +429,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-01T16:14:59.046068930Z",
Expand All @@ -442,9 +441,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_802086/276484184.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"/tmp/ipykernel_850617/276484184.py:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" x_0s = torch.tensor(x_0s)\n",
"/tmp/ipykernel_802086/276484184.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"/tmp/ipykernel_850617/276484184.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" torch.tensor(y, dtype=torch.long).to(device),\n"
]
}
Expand All @@ -468,25 +467,21 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 5 \n",
"Train_loss: 1.5679, acc: 0.9929\n",
"Val_loss: 0.9291, Val_acc: 0.7360\n",
"Test_loss: 0.8751, Test_acc: 0.7670\n",
"Train_loss: 1.7016, acc: 0.9714\n",
"Val_loss: 1.1253, Val_acc: 0.6820\n",
"Test_loss: 1.0647, Test_acc: 0.6940\n",
"Epoch: 10 \n",
"Train_loss: 0.8660, acc: 1.0000\n",
"Val_loss: 1.0993, Val_acc: 0.7380\n",
"Test_loss: 1.0621, Test_acc: 0.7370\n",
"Epoch: 15 \n",
"Train_loss: 0.5851, acc: 1.0000\n",
"Val_loss: 1.7357, Val_acc: 0.6780\n",
"Test_loss: 1.5269, Test_acc: 0.7230\n"
"Train_loss: 0.9825, acc: 1.0000\n",
"Val_loss: 0.9897, Val_acc: 0.7340\n",
"Test_loss: 0.8711, Test_acc: 0.7730\n"
]
}
],
Expand Down Expand Up @@ -532,13 +527,6 @@
" flush=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit 751aef8

Please sign in to comment.