From 7aab21a3f8a653e75369ad0093700cb4c6f11033 Mon Sep 17 00:00:00 2001 From: Matteo Bando <105106139+bandomatteo@users.noreply.github.com> Date: Fri, 3 Oct 2025 13:37:25 +0200 Subject: [PATCH] fix: run model in eval mode with inference_mode for predictions in cell 13 of 02_pytorch_classification.ipynb --- 02_pytorch_classification.ipynb | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/02_pytorch_classification.ipynb b/02_pytorch_classification.ipynb index 03064f5d..eb03eafb 100644 --- a/02_pytorch_classification.ipynb +++ b/02_pytorch_classification.ipynb @@ -82,7 +82,7 @@ "source": [ "## Where can you get help?\n", "\n", - "All of the materials for this course [live on GitHub](https://github.com/mrdbourke/pytorch-deep-learning).\n", + "All of the materials for this course [live on GitHub](https://github.com/mrdbourke/pytorch-deep-learning). \n", "\n", "And if you run into trouble, you can ask a question on the [Discussions page](https://github.com/mrdbourke/pytorch-deep-learning/discussions) there too.\n", "\n", @@ -877,7 +877,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -912,7 +912,9 @@ ], "source": [ "# Make predictions with the model\n", - "untrained_preds = model_0(X_test.to(device))\n", + "model_0.eval()\n", + "with torch.inference_mode():\n", + " untrained_preds = model_0(X_test.to(device))\n", "print(f\"Length of predictions: {len(untrained_preds)}, Shape: {untrained_preds.shape}\")\n", "print(f\"Length of test samples: {len(y_test)}, Shape: {y_test.shape}\")\n", "print(f\"\\nFirst 10 predictions:\\n{untrained_preds[:10]}\")\n",