Skip to content

Commit

Permalink
In OGBN-MAG examples, add activation="relu" to the paper feat encoder.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613989561
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Mar 11, 2024
1 parent 9e318b9 commit 2038ce7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 27 deletions.
54 changes: 28 additions & 26 deletions examples/notebooks/ogbn_mag_e2e.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Running on TPU ['10.113.155.218:8470']\n",
"Running on TPU ['10.116.185.2:8470']\n",
"Using TPUStrategy\n",
"Found 8 replicas in sync\n"
]
Expand Down Expand Up @@ -701,7 +701,9 @@
"source": [
"### Initialization of Hidden States\n",
"\n",
"The hidden states on nodes are created by mapping a dict of (preprocessed) features to fixed-size hidden states for nodes. Similarly to feature preprocessing, the `tfgnn.keras.layers.MapFeatures` layer lets you specify such a transformation as a callback function that transforms feature dicts, with GraphTensor mechanics taken off your shoulders."
"The hidden states on nodes are created by mapping a dict of (preprocessed) features to fixed-size hidden states for nodes. It often makes sense to send input features through a small encoder network, like the `Dense` layer applied below to the `\"feat\"` of paper nodes.\n",
"\n",
"Similarly to feature preprocessing, the `tfgnn.keras.layers.MapFeatures` layer lets you specify such a transformation as a callback function that transforms feature dicts, with GraphTensor mechanics taken off your shoulders."
]
},
{
Expand All @@ -721,7 +723,7 @@
" if node_set_name == \"institution\":\n",
" return tf.keras.layers.Embedding(6_500, 16)(node_set[\"hashed_id\"])\n",
" if node_set_name == \"paper\":\n",
" return tf.keras.layers.Dense(node_state_dim)(node_set[\"feat\"])\n",
" return tf.keras.layers.Dense(node_state_dim, \"relu\")(node_set[\"feat\"])\n",
" if node_set_name == \"author\":\n",
" return node_set[\"empty_state\"]\n",
" raise KeyError(f\"Unexpected node_set_name='{node_set_name}'\")"
Expand Down Expand Up @@ -923,33 +925,33 @@
"base_uri": "https://localhost:8080/"
},
"id": "Ay2hhL3d0dZz",
"outputId": "6f0d261d-7bea-4b73-ba6d-f5fccbe2b249"
"outputId": "09663b5e-8a98-4753-f900-c24e56f054c1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"4918/4918 [==============================] - 138s 28ms/step - loss: 2.6087 - sparse_categorical_accuracy: 0.3227 - sparse_categorical_crossentropy: 2.7150 - val_loss: 2.0895 - val_sparse_categorical_accuracy: 0.4181 - val_sparse_categorical_crossentropy: 2.1522\n",
"4918/4918 [==============================] - 142s 29ms/step - loss: 2.6329 - sparse_categorical_accuracy: 0.3213 - sparse_categorical_crossentropy: 2.7456 - val_loss: 2.1140 - val_sparse_categorical_accuracy: 0.4131 - val_sparse_categorical_crossentropy: 2.1837\n",
"Epoch 2/10\n",
"4918/4918 [==============================] - 90s 18ms/step - loss: 2.1015 - sparse_categorical_accuracy: 0.4214 - sparse_categorical_crossentropy: 2.1580 - val_loss: 1.9503 - val_sparse_categorical_accuracy: 0.4458 - val_sparse_categorical_crossentropy: 1.9923\n",
"4918/4918 [==============================] - 91s 18ms/step - loss: 2.1113 - sparse_categorical_accuracy: 0.4201 - sparse_categorical_crossentropy: 2.1739 - val_loss: 1.9439 - val_sparse_categorical_accuracy: 0.4548 - val_sparse_categorical_crossentropy: 1.9911\n",
"Epoch 3/10\n",
"4918/4918 [==============================] - 90s 18ms/step - loss: 1.9663 - sparse_categorical_accuracy: 0.4530 - sparse_categorical_crossentropy: 2.0059 - val_loss: 1.8771 - val_sparse_categorical_accuracy: 0.4692 - val_sparse_categorical_crossentropy: 1.9089\n",
"4918/4918 [==============================] - 90s 18ms/step - loss: 1.9727 - sparse_categorical_accuracy: 0.4516 - sparse_categorical_crossentropy: 2.0184 - val_loss: 1.8663 - val_sparse_categorical_accuracy: 0.4672 - val_sparse_categorical_crossentropy: 1.9032\n",
"Epoch 4/10\n",
"4918/4918 [==============================] - 92s 19ms/step - loss: 1.8823 - sparse_categorical_accuracy: 0.4733 - sparse_categorical_crossentropy: 1.9131 - val_loss: 1.8625 - val_sparse_categorical_accuracy: 0.4673 - val_sparse_categorical_crossentropy: 1.8919\n",
"4918/4918 [==============================] - 89s 18ms/step - loss: 1.8827 - sparse_categorical_accuracy: 0.4718 - sparse_categorical_crossentropy: 1.9195 - val_loss: 1.8593 - val_sparse_categorical_accuracy: 0.4698 - val_sparse_categorical_crossentropy: 1.8943\n",
"Epoch 5/10\n",
"4918/4918 [==============================] - 91s 18ms/step - loss: 1.8097 - sparse_categorical_accuracy: 0.4893 - sparse_categorical_crossentropy: 1.8360 - val_loss: 1.8237 - val_sparse_categorical_accuracy: 0.4772 - val_sparse_categorical_crossentropy: 1.8517\n",
"4918/4918 [==============================] - 90s 18ms/step - loss: 1.8079 - sparse_categorical_accuracy: 0.4894 - sparse_categorical_crossentropy: 1.8400 - val_loss: 1.7997 - val_sparse_categorical_accuracy: 0.4880 - val_sparse_categorical_crossentropy: 1.8320\n",
"Epoch 6/10\n",
"4918/4918 [==============================] - 89s 18ms/step - loss: 1.7476 - sparse_categorical_accuracy: 0.5030 - sparse_categorical_crossentropy: 1.7720 - val_loss: 1.8024 - val_sparse_categorical_accuracy: 0.4844 - val_sparse_categorical_crossentropy: 1.8317\n",
"4918/4918 [==============================] - 90s 18ms/step - loss: 1.7434 - sparse_categorical_accuracy: 0.5032 - sparse_categorical_crossentropy: 1.7732 - val_loss: 1.7836 - val_sparse_categorical_accuracy: 0.4879 - val_sparse_categorical_crossentropy: 1.8171\n",
"Epoch 7/10\n",
"4918/4918 [==============================] - 91s 19ms/step - loss: 1.6904 - sparse_categorical_accuracy: 0.5155 - sparse_categorical_crossentropy: 1.7141 - val_loss: 1.7533 - val_sparse_categorical_accuracy: 0.4953 - val_sparse_categorical_crossentropy: 1.7822\n",
"4918/4918 [==============================] - 89s 18ms/step - loss: 1.6894 - sparse_categorical_accuracy: 0.5161 - sparse_categorical_crossentropy: 1.7182 - val_loss: 1.7512 - val_sparse_categorical_accuracy: 0.4984 - val_sparse_categorical_crossentropy: 1.7851\n",
"Epoch 8/10\n",
"4918/4918 [==============================] - 91s 18ms/step - loss: 1.6432 - sparse_categorical_accuracy: 0.5268 - sparse_categorical_crossentropy: 1.6663 - val_loss: 1.7408 - val_sparse_categorical_accuracy: 0.4973 - val_sparse_categorical_crossentropy: 1.7709\n",
"4918/4918 [==============================] - 91s 18ms/step - loss: 1.6422 - sparse_categorical_accuracy: 0.5261 - sparse_categorical_crossentropy: 1.6702 - val_loss: 1.7340 - val_sparse_categorical_accuracy: 0.5009 - val_sparse_categorical_crossentropy: 1.7686\n",
"Epoch 9/10\n",
"4918/4918 [==============================] - 89s 18ms/step - loss: 1.6137 - sparse_categorical_accuracy: 0.5331 - sparse_categorical_crossentropy: 1.6364 - val_loss: 1.7367 - val_sparse_categorical_accuracy: 0.4992 - val_sparse_categorical_crossentropy: 1.7675\n",
"4918/4918 [==============================] - 90s 18ms/step - loss: 1.6122 - sparse_categorical_accuracy: 0.5329 - sparse_categorical_crossentropy: 1.6396 - val_loss: 1.7371 - val_sparse_categorical_accuracy: 0.5003 - val_sparse_categorical_crossentropy: 1.7728\n",
"Epoch 10/10\n",
"4918/4918 [==============================] - 91s 19ms/step - loss: 1.6014 - sparse_categorical_accuracy: 0.5355 - sparse_categorical_crossentropy: 1.6238 - val_loss: 1.7309 - val_sparse_categorical_accuracy: 0.4997 - val_sparse_categorical_crossentropy: 1.7614\n"
"4918/4918 [==============================] - 89s 18ms/step - loss: 1.5958 - sparse_categorical_accuracy: 0.5365 - sparse_categorical_crossentropy: 1.6227 - val_loss: 1.7306 - val_sparse_categorical_accuracy: 0.5013 - val_sparse_categorical_crossentropy: 1.7659\n"
]
},
{
Expand All @@ -962,7 +964,7 @@
{
"data": {
"text/plain": [
"RunResult(preprocess_model=\u003ckeras.engine.functional.Functional object at 0x7b8b4c62beb0\u003e, base_model=\u003ckeras.engine.sequential.Sequential object at 0x7b8b500d9240\u003e, trained_model=\u003ckeras.engine.functional.Functional object at 0x7b8b4fc4d840\u003e)"
"RunResult(preprocess_model=\u003ckeras.engine.functional.Functional object at 0x7c1309331660\u003e, base_model=\u003ckeras.engine.sequential.Sequential object at 0x7c124546f670\u003e, trained_model=\u003ckeras.engine.functional.Functional object at 0x7c12454f41f0\u003e)"
]
},
"execution_count": 17,
Expand Down Expand Up @@ -1011,23 +1013,23 @@
"base_uri": "https://localhost:8080/"
},
"id": "ki33s9EpsQnF",
"outputId": "2f45a98b-70d0-4ebc-d29f-741f30ae80f4"
"outputId": "b87d9ded-70f8-4abb-f9d1-fb5a547d59ff"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The predicted class for input 0 is 9 with predicted probability 0.346\n",
"The predicted class for input 1 is 281 with predicted probability 0.5623\n",
"The predicted class for input 2 is 189 with predicted probability 0.2929\n",
"The predicted class for input 3 is 158 with predicted probability 0.9645\n",
"The predicted class for input 4 is 200 with predicted probability 0.1749\n",
"The predicted class for input 5 is 247 with predicted probability 0.9088\n",
"The predicted class for input 6 is 209 with predicted probability 0.5486\n",
"The predicted class for input 7 is 189 with predicted probability 0.5403\n",
"The predicted class for input 8 is 192 with predicted probability 0.5332\n",
"The predicted class for input 9 is 311 with predicted probability 0.7223\n"
"The predicted class for input 0 is 9 with predicted probability 0.4965\n",
"The predicted class for input 1 is 189 with predicted probability 0.2402\n",
"The predicted class for input 2 is 189 with predicted probability 0.4836\n",
"The predicted class for input 3 is 158 with predicted probability 0.958\n",
"The predicted class for input 4 is 341 with predicted probability 0.2332\n",
"The predicted class for input 5 is 189 with predicted probability 0.5669\n",
"The predicted class for input 6 is 209 with predicted probability 0.3472\n",
"The predicted class for input 7 is 247 with predicted probability 0.7285\n",
"The predicted class for input 8 is 89 with predicted probability 0.4504\n",
"The predicted class for input 9 is 311 with predicted probability 0.8283\n"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_gnn/runner/examples/ogbn/mag/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def set_paper_node_state(node_set: tfgnn.NodeSet):
else:
logging.info("Applying dense layer %d to paper.", _PAPER_DIM.value)
embedding_list.append(
tf.keras.layers.Dense(_PAPER_DIM.value)(node_set["feat"])
tf.keras.layers.Dense(_PAPER_DIM.value, "relu")(node_set["feat"])
)
# Masked label
if _MASKED_LABELS.value:
Expand Down

0 comments on commit 2038ce7

Please sign in to comment.