diff --git a/examples/notebooks/ogbn_mag_e2e.ipynb b/examples/notebooks/ogbn_mag_e2e.ipynb
index c59e1803..9e8c5335 100644
--- a/examples/notebooks/ogbn_mag_e2e.ipynb
+++ b/examples/notebooks/ogbn_mag_e2e.ipynb
@@ -456,7 +456,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "Running on TPU ['10.113.155.218:8470']\n",
+ "Running on TPU ['10.116.185.2:8470']\n",
 "Using TPUStrategy\n",
 "Found 8 replicas in sync\n"
 ]
@@ -701,7 +701,9 @@
 "source": [
 "### Initialization of Hidden States\n",
 "\n",
- "The hidden states on nodes are created by mapping a dict of (preprocessed) features to fixed-size hidden states for nodes. Similarly to feature preprocessing, the `tfgnn.keras.layers.MapFeatures` layer lets you specify such a transformation as a callback function that transforms feature dicts, with GraphTensor mechanics taken off your shoulders."
+ "The hidden states on nodes are created by mapping a dict of (preprocessed) features to fixed-size hidden states for nodes. It often makes sense to send input features through a small encoder network, like the `Dense` layer applied below to the `\"feat\"` of paper nodes.\n",
+ "\n",
+ "Similarly to feature preprocessing, the `tfgnn.keras.layers.MapFeatures` layer lets you specify such a transformation as a callback function that transforms feature dicts, with GraphTensor mechanics taken off your shoulders."
 ]
 },
 {
@@ -721,7 +723,7 @@
 "  if node_set_name == \"institution\":\n",
 "    return tf.keras.layers.Embedding(6_500, 16)(node_set[\"hashed_id\"])\n",
 "  if node_set_name == \"paper\":\n",
- "    return tf.keras.layers.Dense(node_state_dim)(node_set[\"feat\"])\n",
+ "    return tf.keras.layers.Dense(node_state_dim, \"relu\")(node_set[\"feat\"])\n",
 "  if node_set_name == \"author\":\n",
 "    return node_set[\"empty_state\"]\n",
 "  raise KeyError(f\"Unexpected node_set_name='{node_set_name}'\")"
@@ -923,7 +925,7 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "Ay2hhL3d0dZz",
- "outputId": "6f0d261d-7bea-4b73-ba6d-f5fccbe2b249"
+ "outputId": "09663b5e-8a98-4753-f900-c24e56f054c1"
 },
 "outputs": [
 {
@@ -931,25 +933,25 @@
 "output_type": "stream",
 "text": [
 "Epoch 1/10\n",
- "4918/4918 [==============================] - 138s 28ms/step - loss: 2.6087 - sparse_categorical_accuracy: 0.3227 - sparse_categorical_crossentropy: 2.7150 - val_loss: 2.0895 - val_sparse_categorical_accuracy: 0.4181 - val_sparse_categorical_crossentropy: 2.1522\n",
+ "4918/4918 [==============================] - 142s 29ms/step - loss: 2.6329 - sparse_categorical_accuracy: 0.3213 - sparse_categorical_crossentropy: 2.7456 - val_loss: 2.1140 - val_sparse_categorical_accuracy: 0.4131 - val_sparse_categorical_crossentropy: 2.1837\n",
 "Epoch 2/10\n",
- "4918/4918 [==============================] - 90s 18ms/step - loss: 2.1015 - sparse_categorical_accuracy: 0.4214 - sparse_categorical_crossentropy: 2.1580 - val_loss: 1.9503 - val_sparse_categorical_accuracy: 0.4458 - val_sparse_categorical_crossentropy: 1.9923\n",
+ "4918/4918 [==============================] - 91s 18ms/step - loss: 2.1113 - sparse_categorical_accuracy: 0.4201 - sparse_categorical_crossentropy: 2.1739 - val_loss: 1.9439 - val_sparse_categorical_accuracy: 0.4548 - val_sparse_categorical_crossentropy: 1.9911\n",
 "Epoch 3/10\n",
- "4918/4918 [==============================] - 90s 18ms/step - loss: 1.9663 - sparse_categorical_accuracy: 0.4530 - sparse_categorical_crossentropy: 2.0059 - val_loss: 1.8771 - val_sparse_categorical_accuracy: 0.4692 - val_sparse_categorical_crossentropy: 1.9089\n",
+ "4918/4918 [==============================] - 90s 18ms/step - loss: 1.9727 - sparse_categorical_accuracy: 0.4516 - sparse_categorical_crossentropy: 2.0184 - val_loss: 1.8663 - val_sparse_categorical_accuracy: 0.4672 - val_sparse_categorical_crossentropy: 1.9032\n",
 "Epoch 4/10\n",
- "4918/4918 [==============================] - 92s 19ms/step - loss: 1.8823 - sparse_categorical_accuracy: 0.4733 - sparse_categorical_crossentropy: 1.9131 - val_loss: 1.8625 - val_sparse_categorical_accuracy: 0.4673 - val_sparse_categorical_crossentropy: 1.8919\n",
+ "4918/4918 [==============================] - 89s 18ms/step - loss: 1.8827 - sparse_categorical_accuracy: 0.4718 - sparse_categorical_crossentropy: 1.9195 - val_loss: 1.8593 - val_sparse_categorical_accuracy: 0.4698 - val_sparse_categorical_crossentropy: 1.8943\n",
 "Epoch 5/10\n",
- "4918/4918 [==============================] - 91s 18ms/step - loss: 1.8097 - sparse_categorical_accuracy: 0.4893 - sparse_categorical_crossentropy: 1.8360 - val_loss: 1.8237 - val_sparse_categorical_accuracy: 0.4772 - val_sparse_categorical_crossentropy: 1.8517\n",
+ "4918/4918 [==============================] - 90s 18ms/step - loss: 1.8079 - sparse_categorical_accuracy: 0.4894 - sparse_categorical_crossentropy: 1.8400 - val_loss: 1.7997 - val_sparse_categorical_accuracy: 0.4880 - val_sparse_categorical_crossentropy: 1.8320\n",
 "Epoch 6/10\n",
- "4918/4918 [==============================] - 89s 18ms/step - loss: 1.7476 - sparse_categorical_accuracy: 0.5030 - sparse_categorical_crossentropy: 1.7720 - val_loss: 1.8024 - val_sparse_categorical_accuracy: 0.4844 - val_sparse_categorical_crossentropy: 1.8317\n",
+ "4918/4918 [==============================] - 90s 18ms/step - loss: 1.7434 - sparse_categorical_accuracy: 0.5032 - sparse_categorical_crossentropy: 1.7732 - val_loss: 1.7836 - val_sparse_categorical_accuracy: 0.4879 - val_sparse_categorical_crossentropy: 1.8171\n",
 "Epoch 7/10\n",
- "4918/4918 [==============================] - 91s 19ms/step - loss: 1.6904 - sparse_categorical_accuracy: 0.5155 - sparse_categorical_crossentropy: 1.7141 - val_loss: 1.7533 - val_sparse_categorical_accuracy: 0.4953 - val_sparse_categorical_crossentropy: 1.7822\n",
+ "4918/4918 [==============================] - 89s 18ms/step - loss: 1.6894 - sparse_categorical_accuracy: 0.5161 - sparse_categorical_crossentropy: 1.7182 - val_loss: 1.7512 - val_sparse_categorical_accuracy: 0.4984 - val_sparse_categorical_crossentropy: 1.7851\n",
 "Epoch 8/10\n",
- "4918/4918 [==============================] - 91s 18ms/step - loss: 1.6432 - sparse_categorical_accuracy: 0.5268 - sparse_categorical_crossentropy: 1.6663 - val_loss: 1.7408 - val_sparse_categorical_accuracy: 0.4973 - val_sparse_categorical_crossentropy: 1.7709\n",
+ "4918/4918 [==============================] - 91s 18ms/step - loss: 1.6422 - sparse_categorical_accuracy: 0.5261 - sparse_categorical_crossentropy: 1.6702 - val_loss: 1.7340 - val_sparse_categorical_accuracy: 0.5009 - val_sparse_categorical_crossentropy: 1.7686\n",
 "Epoch 9/10\n",
- "4918/4918 [==============================] - 89s 18ms/step - loss: 1.6137 - sparse_categorical_accuracy: 0.5331 - sparse_categorical_crossentropy: 1.6364 - val_loss: 1.7367 - val_sparse_categorical_accuracy: 0.4992 - val_sparse_categorical_crossentropy: 1.7675\n",
+ "4918/4918 [==============================] - 90s 18ms/step - loss: 1.6122 - sparse_categorical_accuracy: 0.5329 - sparse_categorical_crossentropy: 1.6396 - val_loss: 1.7371 - val_sparse_categorical_accuracy: 0.5003 - val_sparse_categorical_crossentropy: 1.7728\n",
 "Epoch 10/10\n",
- "4918/4918 [==============================] - 91s 19ms/step - loss: 1.6014 - sparse_categorical_accuracy: 0.5355 - sparse_categorical_crossentropy: 1.6238 - val_loss: 1.7309 - val_sparse_categorical_accuracy: 0.4997 - val_sparse_categorical_crossentropy: 1.7614\n"
+ "4918/4918 [==============================] - 89s 18ms/step - loss: 1.5958 - sparse_categorical_accuracy: 0.5365 - sparse_categorical_crossentropy: 1.6227 - val_loss: 1.7306 - val_sparse_categorical_accuracy: 0.5013 - val_sparse_categorical_crossentropy: 1.7659\n"
 ]
 },
 {
@@ -962,7 +964,7 @@
 {
 "data": {
 "text/plain": [
- "RunResult(preprocess_model=\u003ckeras.engine.functional.Functional object at 0x7b8b4c62beb0\u003e, base_model=\u003ckeras.engine.sequential.Sequential object at 0x7b8b500d9240\u003e, trained_model=\u003ckeras.engine.functional.Functional object at 0x7b8b4fc4d840\u003e)"
+ "RunResult(preprocess_model=\u003ckeras.engine.functional.Functional object at 0x7c1309331660\u003e, base_model=\u003ckeras.engine.sequential.Sequential object at 0x7c124546f670\u003e, trained_model=\u003ckeras.engine.functional.Functional object at 0x7c12454f41f0\u003e)"
 ]
 },
 "execution_count": 17,
@@ -1011,23 +1013,23 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "ki33s9EpsQnF",
- "outputId": "2f45a98b-70d0-4ebc-d29f-741f30ae80f4"
+ "outputId": "b87d9ded-70f8-4abb-f9d1-fb5a547d59ff"
 },
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "The predicted class for input 0 is 9 with predicted probability 0.346\n",
- "The predicted class for input 1 is 281 with predicted probability 0.5623\n",
- "The predicted class for input 2 is 189 with predicted probability 0.2929\n",
- "The predicted class for input 3 is 158 with predicted probability 0.9645\n",
- "The predicted class for input 4 is 200 with predicted probability 0.1749\n",
- "The predicted class for input 5 is 247 with predicted probability 0.9088\n",
- "The predicted class for input 6 is 209 with predicted probability 0.5486\n",
- "The predicted class for input 7 is 189 with predicted probability 0.5403\n",
- "The predicted class for input 8 is 192 with predicted probability 0.5332\n",
- "The predicted class for input 9 is 311 with predicted probability 0.7223\n"
+ "The predicted class for input 0 is 9 with predicted probability 0.4965\n",
+ "The predicted class for input 1 is 189 with predicted probability 0.2402\n",
+ "The predicted class for input 2 is 189 with predicted probability 0.4836\n",
+ "The predicted class for input 3 is 158 with predicted probability 0.958\n",
+ "The predicted class for input 4 is 341 with predicted probability 0.2332\n",
+ "The predicted class for input 5 is 189 with predicted probability 0.5669\n",
+ "The predicted class for input 6 is 209 with predicted probability 0.3472\n",
+ "The predicted class for input 7 is 247 with predicted probability 0.7285\n",
+ "The predicted class for input 8 is 89 with predicted probability 0.4504\n",
+ "The predicted class for input 9 is 311 with predicted probability 0.8283\n"
 ]
 }
 ],
diff --git a/tensorflow_gnn/runner/examples/ogbn/mag/train.py b/tensorflow_gnn/runner/examples/ogbn/mag/train.py
index c60f03fd..c262c5f9 100644
--- a/tensorflow_gnn/runner/examples/ogbn/mag/train.py
+++ b/tensorflow_gnn/runner/examples/ogbn/mag/train.py
@@ -337,7 +337,7 @@ def set_paper_node_state(node_set: tfgnn.NodeSet):
   else:
     logging.info("Applying dense layer %d to paper.", _PAPER_DIM.value)
     embedding_list.append(
-        tf.keras.layers.Dense(_PAPER_DIM.value)(node_set["feat"])
+        tf.keras.layers.Dense(_PAPER_DIM.value, "relu")(node_set["feat"])
    )
   # Masked label
   if _MASKED_LABELS.value:
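Note (not part of the diff): a minimal sketch of how the updated state-initialization callback plugs into `tfgnn.keras.layers.MapFeatures`, for context on the ReLU change above. The function name `set_initial_node_state`, the wrapper variable name, and the `node_state_dim` value are illustrative assumptions; the notebook defines its own in earlier cells.

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

node_state_dim = 128  # assumed width; the notebook sets its own value earlier


def set_initial_node_state(node_set, *, node_set_name):
  # Map each node set's (preprocessed) features to a fixed-size hidden state.
  if node_set_name == "institution":
    return tf.keras.layers.Embedding(6_500, 16)(node_set["hashed_id"])
  if node_set_name == "paper":
    # The change in this diff: encode "feat" with a Dense layer plus ReLU.
    return tf.keras.layers.Dense(node_state_dim, "relu")(node_set["feat"])
  if node_set_name == "author":
    return node_set["empty_state"]
  raise KeyError(f"Unexpected node_set_name='{node_set_name}'")


# MapFeatures calls the function once per node set and takes care of the
# GraphTensor mechanics; the returned tensor becomes the node set's hidden state.
init_states = tfgnn.keras.layers.MapFeatures(
    node_sets_fn=set_initial_node_state)
```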