TPU support in Colab code: switch to the "TPU v2" runtime on a TPU VM
and retire workarounds for the older "TPU (deprecated)" runtime that
uses a separate TPU Node.

Along the way, changing runtimes upgrades from TF2.12 to TF2.15.

PiperOrigin-RevId: 621107610
arnoegw authored and tensorflower-gardener committed Apr 2, 2024
1 parent 3c65d57 commit 4749d94
Showing 2 changed files with 279 additions and 237 deletions.
136 changes: 78 additions & 58 deletions examples/notebooks/ogbn_mag_e2e.ipynb
@@ -101,15 +101,25 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
+ "executionInfo": {
+ "elapsed": 10039,
+ "status": "ok",
+ "timestamp": 1711472628551,
+ "user": {
+ "displayName": "",
+ "userId": ""
+ },
+ "user_tz": -60
+ },
"id": "oA4_zh0EyNHv",
- "outputId": "4e3e16b7-64dd-4516-99da-8cea252750d8"
+ "outputId": "8b415cad-86b7-4169-cc9f-2dd9f6b02f2c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Running TF-GNN 1.0.2 under TensorFlow 2.12.0.\n"
+ "Running TF-GNN 1.0.2 under TensorFlow 2.15.0.\n"
]
}
],
@@ -435,10 +445,9 @@
"source": [
"## Distributed Training\n",
"\n",
- "\n",
- "\n",
"We use TensorFlow's [Distribution Strategy](https://www.tensorflow.org/guide/distributed_training) API to write a model that can run on multiple TPUs, multiple GPUs, or maybe just locally on CPU.\n",
- "\n"
+ "\n",
+ "For CloudTPU, the following code assumes the Colab runtime type \"TPU v2\", that is, a TPU VM. Do not use the runtime type \"TPU (deprecated)\", which uses a TPU Node on a separate VM."
]
},
{
@@ -448,31 +457,34 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
+ "executionInfo": {
+ "elapsed": 26820,
+ "status": "ok",
+ "timestamp": 1711472717800,
+ "user": {
+ "displayName": "",
+ "userId": ""
+ },
+ "user_tz": -60
+ },
"id": "2oBuJEZ3izQm",
- "outputId": "680d981b-9ee6-4ffe-d696-c70110edadca"
+ "outputId": "db98cb52-837a-4552-cf63-606cb88ffa25"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Running on TPU ['10.116.185.2:8470']\n",
"Using TPUStrategy\n",
"Found 8 replicas in sync\n"
]
}
],
"source": [
- "try:\n",
- " tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
- " print(\"Running on TPU \", tpu_resolver.cluster_spec().as_dict()[\"worker\"])\n",
- "except:\n",
- " tpu_resolver = None\n",
- "\n",
- "if tpu_resolver:\n",
- " print(\"Using TPUStrategy\")\n",
+ "if tf.config.list_physical_devices(\"TPU\"):\n",
+ " print(f\"Using TPUStrategy\")\n",
" min_nodes_per_component = {\"paper\": 1}\n",
- " strategy = runner.TPUStrategy()\n",
+ " strategy = runner.TPUStrategy(\"local\")\n",
" train_padding = runner.FitOrSkipPadding(example_input_graph_spec, train_ds_provider, min_nodes_per_component)\n",
" valid_padding = runner.TightPadding(example_input_graph_spec, valid_ds_provider, min_nodes_per_component)\n",
"elif tf.config.list_physical_devices(\"GPU\"):\n",
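The replacement logic above drops the `TPUClusterResolver` probe: on a "TPU v2" runtime the TPU chips are local to the VM, so `tf.config.list_physical_devices("TPU")` is enough to detect them, and `runner.TPUStrategy("local")` attaches without a cluster address. The detection order (TPU, then GPU, then default) can be sketched in plain Python; the device set below stands in for the actual `tf.config` query, and the returned strategy names are illustrative, not calls into TensorFlow:

```python
def choose_strategy(device_types):
    """Pick a strategy name using the notebook's detection order.

    `device_types` stands in for the set of device types reported by
    tf.config.list_physical_devices(); the returned strings mirror (but
    do not invoke) the corresponding strategy constructors.
    """
    if "TPU" in device_types:
        # TPU VM runtime: chips are local, no TPUClusterResolver needed.
        return 'runner.TPUStrategy("local")'
    if "GPU" in device_types:
        return "tf.distribute.MirroredStrategy()"
    # CPU-only runtime: fall back to the default (no-op) strategy.
    return "tf.distribute.get_strategy()"
```

The key change from the old "TPU (deprecated)" runtime is that no remote worker address (like the `10.116.185.2:8470` seen in the old output) is involved anymore.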
@@ -846,8 +858,8 @@
"global_batch_size = 128\n",
"epochs = 10\n",
"initial_learning_rate = 0.001\n",
- "if tpu_resolver:\n",
- " # Training on TPU takes ~90 secs / epoch, so we train for the entire epoch.\n",
+ "if tf.config.list_physical_devices(\"TPU\"):\n",
+ " # Training on TPU takes ~130 secs / epoch, so we train for the entire epoch.\n",
" epoch_divisor = 1\n",
"else:\n",
" # Training on GPU / CPU is slower, so we train for 1/100th of a true epoch.\n",
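The `epoch_divisor` above trades epoch fidelity for wall-clock time: with the TPU v2 runtime (~130 s per epoch) a full epoch is affordable, while on GPU/CPU only 1/100th of one is run. A minimal sketch of how the divisor shrinks the number of steps per reported "epoch" (the dataset size and helper name below are hypothetical, not taken from the notebook):

```python
def steps_per_epoch(num_examples, global_batch_size, epoch_divisor):
    """Number of optimizer steps per reported epoch.

    epoch_divisor=1 trains on the whole dataset each epoch;
    epoch_divisor=100 trains on roughly 1/100th of it.
    """
    full_epoch_steps = num_examples // global_batch_size
    return max(1, full_epoch_steps // epoch_divisor)

# With a hypothetical 640,000 training examples and the notebook's
# global_batch_size of 128:
tpu_steps = steps_per_epoch(640_000, 128, epoch_divisor=1)    # 5000 steps
cpu_steps = steps_per_epoch(640_000, 128, epoch_divisor=100)  # 50 steps
```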
@@ -882,9 +894,7 @@
"source": [
"## Export options for inference\n",
"\n",
- "For inference, a SavedModel must be exported by the runner at the end of training. C++ inference environments like TF Serving do not support input of extension types like GraphTensor, so the `KerasModelExporter` exports the model with a SavedModel Signature that accepts a batch of serialized tf.Examples and preprocesses them like training did.\n",
- "\n",
- "Note: After connecting this Colab to a TPU worker, explicit device placements are necessary to do the test on the colab host (which has the `/tmp/gnn_model` directory)."
+ "For inference, a SavedModel must be exported by the runner at the end of training. C++ inference environments like TF Serving do not support input of extension types like GraphTensor, so the `KerasModelExporter` exports the model with a SavedModel Signature that accepts a batch of serialized tf.Examples and preprocesses them like training did."
]
},
{
@@ -895,9 +905,7 @@
},
"outputs": [],
"source": [
- "save_options = tf.saved_model.SaveOptions(experimental_io_device=\"/job:localhost\")\n",
- "model_exporter = runner.KerasModelExporter(output_names=\"paper_venue_logits\",\n",
- " options=save_options)"
+ "model_exporter = runner.KerasModelExporter(output_names=\"paper_venue_logits\")"
]
},
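`KerasModelExporter` now exports without `SaveOptions`, but the exported signature's contract is unchanged: TF Serving clients send a batch of opaque serialized tf.Example bytes, and the SavedModel itself parses and preprocesses them. The shape of that contract can be sketched without TensorFlow, with JSON standing in for the tf.Example wire format (all names here are hypothetical stand-ins):

```python
import json

def serialize_example(features):
    """Stand-in for tf.Example serialization: clients ship opaque bytes,
    not structured tensors, so extension types never cross the wire."""
    return json.dumps(features, sort_keys=True).encode("utf-8")

def serving_fn(serialized_batch):
    """Stand-in for the exported signature: decode, 'preprocess', score.

    Here the 'logits' are just the sum of feature values; the real
    signature parses GraphTensor pieces and runs the trained GNN.
    """
    logits = []
    for raw in serialized_batch:
        features = json.loads(raw.decode("utf-8"))
        logits.append(sum(features.values()))
    return logits

batch = [serialize_example({"a": 1.0, "b": 2.0}),
         serialize_example({"a": 0.5, "b": 0.25})]
# serving_fn(batch) -> [3.0, 0.75]
```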
{
@@ -924,47 +932,50 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
+ "executionInfo": {
+ "elapsed": 427499,
+ "status": "ok",
+ "timestamp": 1711474246342,
+ "user": {
+ "displayName": "",
+ "userId": ""
+ },
+ "user_tz": -60
+ },
"id": "Ay2hhL3d0dZz",
- "outputId": "09663b5e-8a98-4753-f900-c24e56f054c1"
+ "outputId": "70fa9f6c-a2c5-4bfa-a0ef-c8ad295c50cb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
- "4918/4918 [==============================] - 142s 29ms/step - loss: 2.6329 - sparse_categorical_accuracy: 0.3213 - sparse_categorical_crossentropy: 2.7456 - val_loss: 2.1140 - val_sparse_categorical_accuracy: 0.4131 - val_sparse_categorical_crossentropy: 2.1837\n",
+ "4918/4918 [==============================] - 170s 34ms/step - loss: 2.6135 - sparse_categorical_accuracy: 0.3245 - sparse_categorical_crossentropy: 2.7248 - val_loss: 2.0817 - val_sparse_categorical_accuracy: 0.4239 - val_sparse_categorical_crossentropy: 2.1490\n",
"Epoch 2/10\n",
- "4918/4918 [==============================] - 91s 18ms/step - loss: 2.1113 - sparse_categorical_accuracy: 0.4201 - sparse_categorical_crossentropy: 2.1739 - val_loss: 1.9439 - val_sparse_categorical_accuracy: 0.4548 - val_sparse_categorical_crossentropy: 1.9911\n",
+ "4918/4918 [==============================] - 129s 26ms/step - loss: 2.1057 - sparse_categorical_accuracy: 0.4225 - sparse_categorical_crossentropy: 2.1676 - val_loss: 2.0053 - val_sparse_categorical_accuracy: 0.4330 - val_sparse_categorical_crossentropy: 2.0561\n",
"Epoch 3/10\n",
- "4918/4918 [==============================] - 90s 18ms/step - loss: 1.9727 - sparse_categorical_accuracy: 0.4516 - sparse_categorical_crossentropy: 2.0184 - val_loss: 1.8663 - val_sparse_categorical_accuracy: 0.4672 - val_sparse_categorical_crossentropy: 1.9032\n",
+ "4918/4918 [==============================] - 128s 26ms/step - loss: 1.9673 - sparse_categorical_accuracy: 0.4541 - sparse_categorical_crossentropy: 2.0124 - val_loss: 1.8902 - val_sparse_categorical_accuracy: 0.4703 - val_sparse_categorical_crossentropy: 1.9283\n",
"Epoch 4/10\n",
- "4918/4918 [==============================] - 89s 18ms/step - loss: 1.8827 - sparse_categorical_accuracy: 0.4718 - sparse_categorical_crossentropy: 1.9195 - val_loss: 1.8593 - val_sparse_categorical_accuracy: 0.4698 - val_sparse_categorical_crossentropy: 1.8943\n",
+ "4918/4918 [==============================] - 130s 26ms/step - loss: 1.8787 - sparse_categorical_accuracy: 0.4740 - sparse_categorical_crossentropy: 1.9149 - val_loss: 1.8447 - val_sparse_categorical_accuracy: 0.4803 - val_sparse_categorical_crossentropy: 1.8784\n",
"Epoch 5/10\n",
- "4918/4918 [==============================] - 90s 18ms/step - loss: 1.8079 - sparse_categorical_accuracy: 0.4894 - sparse_categorical_crossentropy: 1.8400 - val_loss: 1.7997 - val_sparse_categorical_accuracy: 0.4880 - val_sparse_categorical_crossentropy: 1.8320\n",
+ "4918/4918 [==============================] - 129s 26ms/step - loss: 1.8062 - sparse_categorical_accuracy: 0.4904 - sparse_categorical_crossentropy: 1.8378 - val_loss: 1.8227 - val_sparse_categorical_accuracy: 0.4787 - val_sparse_categorical_crossentropy: 1.8559\n",
"Epoch 6/10\n",
- "4918/4918 [==============================] - 90s 18ms/step - loss: 1.7434 - sparse_categorical_accuracy: 0.5032 - sparse_categorical_crossentropy: 1.7732 - val_loss: 1.7836 - val_sparse_categorical_accuracy: 0.4879 - val_sparse_categorical_crossentropy: 1.8171\n",
+ "4918/4918 [==============================] - 130s 26ms/step - loss: 1.7416 - sparse_categorical_accuracy: 0.5043 - sparse_categorical_crossentropy: 1.7708 - val_loss: 1.7801 - val_sparse_categorical_accuracy: 0.4919 - val_sparse_categorical_crossentropy: 1.8128\n",
"Epoch 7/10\n",
- "4918/4918 [==============================] - 89s 18ms/step - loss: 1.6894 - sparse_categorical_accuracy: 0.5161 - sparse_categorical_crossentropy: 1.7182 - val_loss: 1.7512 - val_sparse_categorical_accuracy: 0.4984 - val_sparse_categorical_crossentropy: 1.7851\n",
+ "4918/4918 [==============================] - 133s 27ms/step - loss: 1.6856 - sparse_categorical_accuracy: 0.5167 - sparse_categorical_crossentropy: 1.7136 - val_loss: 1.7456 - val_sparse_categorical_accuracy: 0.4999 - val_sparse_categorical_crossentropy: 1.7787\n",
"Epoch 8/10\n",
- "4918/4918 [==============================] - 91s 18ms/step - loss: 1.6422 - sparse_categorical_accuracy: 0.5261 - sparse_categorical_crossentropy: 1.6702 - val_loss: 1.7340 - val_sparse_categorical_accuracy: 0.5009 - val_sparse_categorical_crossentropy: 1.7686\n",
+ "4918/4918 [==============================] - 130s 26ms/step - loss: 1.6424 - sparse_categorical_accuracy: 0.5263 - sparse_categorical_crossentropy: 1.6700 - val_loss: 1.7497 - val_sparse_categorical_accuracy: 0.4955 - val_sparse_categorical_crossentropy: 1.7849\n",
"Epoch 9/10\n",
- "4918/4918 [==============================] - 90s 18ms/step - loss: 1.6122 - sparse_categorical_accuracy: 0.5329 - sparse_categorical_crossentropy: 1.6396 - val_loss: 1.7371 - val_sparse_categorical_accuracy: 0.5003 - val_sparse_categorical_crossentropy: 1.7728\n",
+ "4918/4918 [==============================] - 131s 27ms/step - loss: 1.6112 - sparse_categorical_accuracy: 0.5332 - sparse_categorical_crossentropy: 1.6382 - val_loss: 1.7343 - val_sparse_categorical_accuracy: 0.5013 - val_sparse_categorical_crossentropy: 1.7693\n",
"Epoch 10/10\n",
- "4918/4918 [==============================] - 89s 18ms/step - loss: 1.5958 - sparse_categorical_accuracy: 0.5365 - sparse_categorical_crossentropy: 1.6227 - val_loss: 1.7306 - val_sparse_categorical_accuracy: 0.5013 - val_sparse_categorical_crossentropy: 1.7659\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "WARNING:absl:Found untraced functions such as _update_step_xla while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
+ "4918/4918 [==============================] - 132s 27ms/step - loss: 1.5959 - sparse_categorical_accuracy: 0.5372 - sparse_categorical_crossentropy: 1.6224 - val_loss: 1.7417 - val_sparse_categorical_accuracy: 0.4991 - val_sparse_categorical_crossentropy: 1.7773\n"
]
},
{
"data": {
"text/plain": [
- "RunResult(preprocess_model=\u003ckeras.engine.functional.Functional object at 0x7c1309331660\u003e, base_model=\u003ckeras.engine.sequential.Sequential object at 0x7c124546f670\u003e, trained_model=\u003ckeras.engine.functional.Functional object at 0x7c12454f41f0\u003e)"
+ "RunResult(preprocess_model=\u003ckeras.src.engine.functional.Functional object at 0x7fed485eee90\u003e, base_model=\u003ckeras.src.engine.sequential.Sequential object at 0x7febfa0b3280\u003e, trained_model=\u003ckeras.src.engine.functional.Functional object at 0x7fec8811a9e0\u003e)"
]
},
"execution_count": 17,
@@ -997,7 +1008,7 @@
},
"source": [
"## Inference using Exported Model\n",
- "At the end of training, a SavedModel is exported by the Runner for inference. For demonstration, let's call the exported model on the validation dataset from above, but without labels. We load it as a SavedModel, like TF Serving would. Analogous to the SaveOptions above, LoadOptions with a device placement are necessary when connecting this Colab to a TPU worker.\n",
+ "At the end of training, a SavedModel is exported by the Runner for inference. For demonstration, let's call the exported model on the validation dataset from above, but without labels. We load it as a SavedModel, like TF Serving would.\n",
"\n",
"NOTE: TF Serving usually expects examples in form of serialized strings, therefore we explicitly convert the graph tensors to serialized string format and pass it to the loaded model.\n",
"\n",
@@ -1012,32 +1023,40 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
+ "executionInfo": {
+ "elapsed": 51166,
+ "status": "ok",
+ "timestamp": 1711474297507,
+ "user": {
+ "displayName": "",
+ "userId": ""
+ },
+ "user_tz": -60
+ },
"id": "ki33s9EpsQnF",
- "outputId": "b87d9ded-70f8-4abb-f9d1-fb5a547d59ff"
+ "outputId": "8e6a7ba6-514e-4dda-f96f-deea16d185b1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "The predicted class for input 0 is 9 with predicted probability 0.4965\n",
- "The predicted class for input 1 is 189 with predicted probability 0.2402\n",
- "The predicted class for input 2 is 189 with predicted probability 0.4836\n",
- "The predicted class for input 3 is 158 with predicted probability 0.958\n",
- "The predicted class for input 4 is 341 with predicted probability 0.2332\n",
- "The predicted class for input 5 is 189 with predicted probability 0.5669\n",
- "The predicted class for input 6 is 209 with predicted probability 0.3472\n",
- "The predicted class for input 7 is 247 with predicted probability 0.7285\n",
- "The predicted class for input 8 is 89 with predicted probability 0.4504\n",
- "The predicted class for input 9 is 311 with predicted probability 0.8283\n"
+ "The predicted class for input 0 is 9 with predicted probability 0.3137\n",
+ "The predicted class for input 1 is 281 with predicted probability 0.2777\n",
+ "The predicted class for input 2 is 189 with predicted probability 0.4749\n",
+ "The predicted class for input 3 is 158 with predicted probability 0.9535\n",
+ "The predicted class for input 4 is 82 with predicted probability 0.3277\n",
+ "The predicted class for input 5 is 247 with predicted probability 0.299\n",
+ "The predicted class for input 6 is 209 with predicted probability 0.4056\n",
+ "The predicted class for input 7 is 247 with predicted probability 0.593\n",
+ "The predicted class for input 8 is 192 with predicted probability 0.5478\n",
+ "The predicted class for input 9 is 311 with predicted probability 0.7335\n"
]
}
],
"source": [
"# Load model.\n",
- "load_options = tf.saved_model.LoadOptions(experimental_io_device=\"/job:localhost\")\n",
- "saved_model = tf.saved_model.load(os.path.join(trainer.model_dir, \"export\"),\n",
- " options=load_options)\n",
+ "saved_model = tf.saved_model.load(os.path.join(trainer.model_dir, \"export\"))\n",
"signature_fn = saved_model.signatures[\n",
" tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
"\n",
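The "predicted class ... with predicted probability" lines in the cell output come from post-processing the signature's logits: per input, apply softmax and report the argmax class with its probability. That step can be shown in pure Python (a sketch; the notebook presumably does the equivalent with TensorFlow ops on the signature output):

```python
import math

def top_prediction(logits):
    """Return (predicted class, softmax probability) for one example."""
    m = max(logits)                              # stabilize the exponentials
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    best = max(range(len(logits)), key=lambda i: logits[i])
    return best, exps[best] / total

cls, prob = top_prediction([0.1, 2.3, 0.4])
print(f"The predicted class for input 0 is {cls} "
      f"with predicted probability {round(prob, 4)}")
```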
@@ -1104,6 +1123,7 @@
"ScitaPqhKtuW"
],
"name": "Solving OGBN-MAG end-to-end with TF-GNN",
+ "gpuType": "V28",
"provenance": []
},
"kernelspec": {