From 25e4586ad4587d47380e1ca8ae224dc24b9957dc Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 16 May 2024 16:53:00 +0200 Subject: [PATCH] update PaliGemma fine-tuning notebook --- ...etune-paligemma-on-detection-dataset.ipynb | 231 ++++++++++++------ 1 file changed, 158 insertions(+), 73 deletions(-) diff --git a/notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb b/notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb index 5a52d97..87dcd74 100644 --- a/notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb +++ b/notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb @@ -58,24 +58,24 @@ ], "metadata": { "id": "Wtvz4QZ9YuG8", - "outputId": "c652c668-7beb-40cd-9683-c80306c372ae", + "outputId": "52eb718b-71b6-4a68-b921-034bc4c31657", "colab": { "base_uri": "https://localhost:8080/" } }, - "execution_count": 1, + "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/74.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.9/74.9 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.0/111.0 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.3/158.3 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/74.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.9/74.9 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/111.0 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.0/111.0 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.3/158.3 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m178.7/178.7 kB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.1/49.1 MB\u001b[0m \u001b[31m30.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.5/54.5 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.1/49.1 MB\u001b[0m \u001b[31m29.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.5/54.5 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } @@ -96,12 +96,12 @@ ], "metadata": { "id": "TGDFTYVnY4zn", - "outputId": "406c3ecb-a2f1-41b7-b2d1-79517518bcfe", + "outputId": "b654575f-0d3e-4ec2-ac2e-1dfddf7d9a39", "colab": { "base_uri": "https://localhost:8080/" } }, - "execution_count": 2, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -116,7 +116,7 @@ "output_type": "stream", "name": "stderr", "text": [ - "Downloading Dataset Version Zip in fracture-detection-3 to yolov8:: 100%|██████████| 25468/25468 [00:01<00:00, 16833.18it/s]" + "Downloading Dataset Version Zip in fracture-detection-3 to yolov8:: 100%|██████████| 25468/25468 [00:00<00:00, 29882.82it/s]" ] }, { @@ -131,7 +131,7 @@ "name": "stderr", "text": [ "\n", - "Extracting Dataset Version Zip to fracture-detection-3 in yolov8:: 100%|██████████| 2082/2082 [00:00<00:00, 9227.03it/s]\n" + "Extracting Dataset Version Zip to fracture-detection-3 in yolov8:: 100%|██████████| 2082/2082 [00:00<00:00, 8391.99it/s]\n" ] } ] @@ -233,7 +233,7 @@ "metadata": { "id": "reRShie2ZFcH" }, - "execution_count": 3, + "execution_count": null, "outputs": [] }, { @@ -256,7 +256,7 @@ "metadata": { "id": "QnGTgGY0ZLxA" }, - "execution_count": 4, + "execution_count": null, "outputs": [] }, { @@ -266,12 +266,12 @@ ], "metadata": { "id": "N8xQtqC3ZOkJ", - "outputId": "749ffb4b-914c-49c9-bcdd-aae5f3225344", + "outputId": "71416628-a42b-448f-83b8-76aaf27a4504", "colab": { "base_uri": "https://localhost:8080/" } }, - "execution_count": 5, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -295,9 +295,9 @@ "base_uri": "https://localhost:8080/" }, "id": "ojgIqcZv6oPq", - "outputId": "9bb5bdd2-2aa5-4c11-8eb2-5f316b967778" + "outputId": "2d8276fc-b143-46f2-fee7-00fc7ac4969d" }, - "execution_count": 6, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -329,7 +329,7 @@ "metadata": { "id": "B6RGe_y8ZRPg" }, - "execution_count": 7, + "execution_count": null, "outputs": [] }, { @@ -341,7 +341,7 @@ "metadata": { "id": "sncufv4lZaLa" }, - "execution_count": 8, + "execution_count": null, "outputs": [] }, { @@ -354,9 +354,9 @@ "base_uri": "https://localhost:8080/" }, "id": "WLhSenP5AtQe", - "outputId": "412bedbf-df9f-4866-9001-dca13e9f096e" + "outputId": "370aa8f9-d227-4f40-c07e-0d1dfc4b9314" }, - "execution_count": 9, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -381,9 +381,9 @@ "base_uri": "https://localhost:8080/" }, "id": "YwHY21ABA0WG", - "outputId": "4e925557-3886-44d5-eede-e2ab0c975951" + "outputId": "da408895-5dda-4669-f98e-7c0e3bc5c8fc" }, - "execution_count": 10, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -469,7 +469,7 @@ "metadata": { "id": "gkakRrzkJgdq" }, - "execution_count": 11, + "execution_count": null, "outputs": [] }, { @@ -487,9 +487,9 @@ "height": 657 }, "id": "K9pjdoSmYvqG", - "outputId": "07554a2c-9481-468d-8fcd-1196ff1bd476" + "outputId": "e6044758-fb06-4254-8ec8-caa2a0f5b016" }, - "execution_count": 12, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -515,10 +515,10 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "id": "DfxKb3F839Ks", - "outputId": "b292131e-cb1a-4600-dde9-e68034a3a651", + "outputId": "aeb78031-aa54-4d39-a014-69818c2ea9cf", "colab": { "base_uri": "https://localhost:8080/" } @@ -528,9 +528,9 @@ "output_type": "stream", "name": "stdout", "text": [ - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for ml_collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } @@ -579,7 +579,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "id": "zGLIp1Cx3_CX" }, @@ -597,13 +597,13 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gQNOTfF24AV4", - "outputId": "96cb1910-0ef9-4a0e-a338-22210e9ec658" + "outputId": "7f121e65-db2c-4805-e532-24c9b4e4a638" }, "outputs": [ { @@ -618,7 +618,7 @@ "name": "stderr", "text": [ "Downloading from https://www.kaggle.com/api/v1/models/google/paligemma/jax/paligemma-3b-pt-224/1/download/paligemma-3b-pt-224.f16.npz...\n", - "100%|██████████| 5.45G/5.45G [02:33<00:00, 38.0MB/s]\n" + "100%|██████████| 5.45G/5.45G [01:43<00:00, 56.3MB/s]\n" ] }, { @@ -628,7 +628,7 @@ "Model path: /root/.cache/kagglehub/models/google/paligemma/jax/paligemma-3b-pt-224/1/./paligemma-3b-pt-224.f16.npz\n", "Downloading the model tokenizer...\n", "Copying gs://big_vision/paligemma_tokenizer.model...\n", - "/ [1 files][ 4.1 MiB/ 4.1 MiB] \n", + "- [1 files][ 4.1 MiB/ 4.1 MiB] \n", "Operation completed over 1 objects/4.1 MiB. \n", "Tokenizer path: ./paligemma_tokenizer.model\n" ] @@ -652,13 +652,7 @@ "if not os.path.exists(TOKENIZER_PATH):\n", " print(\"Downloading the model tokenizer...\")\n", " !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n", - " print(f\"Tokenizer path: {TOKENIZER_PATH}\")\n", - "\n", - "# DATA_DIR=\"./longcap100\"\n", - "# if not os.path.exists(DATA_DIR):\n", - "# print(\"Downloading the dataset...\")\n", - "# !gsutil -m -q cp -n -r gs://longcap100/ .\n", - "# print(f\"Data path: {DATA_DIR}\")" + " print(f\"Tokenizer path: {TOKENIZER_PATH}\")" ] }, { @@ -672,13 +666,13 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dTfe2k8J4Bw0", - "outputId": "76b14d9e-b154-4825-aa6e-c52b27fbffcf" + "outputId": "553973db-81e9-4219-94a1-2741d58efffb" }, "outputs": [ { @@ -731,7 +725,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": { "id": "1aghcULcEdtv" }, @@ -775,18 +769,18 @@ "metadata": { "id": "2LNRDMMwXFJ9" }, - "execution_count": 18, + "execution_count": null, "outputs": [] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RWOdf_fw2SAO", - "outputId": "f666f219-5fc8-4cf7-edc3-c9c9e0235e15" + "outputId": "6b73dd60-4a87-4849-de4f-e74b042e1ada" }, "outputs": [ { @@ -887,7 +881,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { "id": "8SRW0NuU4UcW" }, @@ -945,7 +939,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": { "id": "whzWOojGOtzi" }, @@ -1005,14 +999,14 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 298 }, "id": "BzJfb5t0nsLq", - "outputId": "fa2b30bc-6fcc-4367-f178-1a4d72155112" + "outputId": "55c17b4f-a713-46ff-cf89-13953e180765" }, "outputs": [ { @@ -1031,43 +1025,43 @@ "text/html": [ "\n", "
\n", - " \n", - "

<loc0330><loc0522><loc0474><loc0770> fracture ; <loc0578><loc0456><loc0671><loc0655> fracture

\n", + " \n", + "

<loc0368><loc0290><loc0511><loc0338> fracture

\n", "
\n", " \n", "
\n", - " \n", - "

<loc0401><loc0047><loc0491><loc0164> fracture

\n", + " \n", + "

<loc0677><loc0418><loc0735><loc0509> fracture

\n", "
\n", " \n", "
\n", - " \n", - "

<loc0564><loc0421><loc0649><loc0546> fracture

\n", + " \n", + "

<loc0198><loc0474><loc0288><loc0676> fracture

\n", "
\n", " \n", "
\n", - " \n", - "

<loc0532><loc0200><loc0658><loc0358> fracture

\n", + " \n", + "

<loc0325><loc0474><loc0588><loc0693> fracture

\n", "
\n", " \n", "
\n", - " \n", - "

<loc0263><loc0205><loc0415><loc0415> fracture

\n", + " \n", + "

<loc0493><loc0231><loc0569><loc0389> fracture

\n", "
\n", " \n", "
\n", - " \n", - "

<loc0743><loc0302><loc0792><loc0434> fracture

\n", + " \n", + "

<loc0688><loc0379><loc0788><loc0516> fracture

\n", "
\n", " \n", "
\n", - " \n", - "

<loc0396><loc0413><loc0460><loc0501> fracture ; <loc0454><loc0308><loc0526><loc0416> fracture

\n", + " \n", + "

<loc0373><loc0288><loc0498><loc0542> fracture

\n", "
\n", " \n", "
\n", - " \n", - "

<loc0751><loc0753><loc0845><loc0876> fracture

\n", + " \n", + "

<loc0819><loc0250><loc0878><loc0405> fracture

\n", "
\n", " " ] @@ -1119,7 +1113,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": { "id": "dwUV_imW3WQJ" }, @@ -1205,7 +1199,98 @@ }, { "cell_type": "code", - "execution_count": 24, + "source": [ + "# @title Let't check model performance without finetuning\n", + "\n", + "print(\"Model predictions\")\n", + "html_out = \"\"\n", + "for image, caption in make_predictions(validation_data_iterator(), num_examples=8, batch_size=4):\n", + " html_out += render_example(image, caption)\n", + "display(HTML(html_out))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 298 + }, + "id": "GCXYnIdm4ILQ", + "outputId": "e5099da4-b012-4faf-cc93-b8a826073974" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model predictions\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0000><loc0000><loc0907><loc1015> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1023><loc1015> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0105><loc0000><loc0968><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1015><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0073><loc0216><loc1023><loc0919> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1015><loc0964> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0089><loc0000><loc1023><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1017><loc0754> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "**NOTE:** We see that in most cases, the model defaults to returning a box containing the entire skeleton visible in the image. We need to be more precise." + ], + "metadata": { + "id": "7QLWAH8l4cJc" + } + }, + { + "cell_type": "code", + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1869,7 +1954,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -2436,7 +2521,7 @@ "metadata": { "id": "zyVxKr2FOxPe" }, - "execution_count": 26, + "execution_count": null, "outputs": [] } ],