From b3d3bc849b4068aa074b0f42be104d0188b4d1a2 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 16 May 2024 16:14:10 +0200 Subject: [PATCH] add PaliGemma fine-tuning notebook --- ...etune-paligemma-on-detection-dataset.ipynb | 2460 +++++++++++++++++ 1 file changed, 2460 insertions(+) create mode 100644 notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb diff --git a/notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb b/notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb new file mode 100644 index 0000000..5a52d97 --- /dev/null +++ b/notebooks/how-to-finetune-paligemma-on-detection-dataset.ipynb @@ -0,0 +1,2460 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "[![Roboflow Notebooks](https://media.roboflow.com/notebooks/template/bannertest2-2.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672932710194)](https://github.com/roboflow/notebooks)\n", + "\n", + "# Fine-tune PaliGemma on Object Detection Dataset\n", + "\n", + "---\n", + "\n", + "[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)\n", + "\n", + "PaliGemma is an open vision-language model (VLM) inspired by PaLI-3, built with\n", + "open components, such as\n", + "the [SigLIP vision model](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb)\n", + "and\n", + "the [Gemma language model](https://ai.google.dev/gemma).\n", + "PaliGemma is designed as a versatile model for transfer to a wide range of\n", + "vision-language tasks such as image and short video caption, visual question\n", + "answering, text reading, object detection and object segmentation. 
Together with\n", + "the pretrained and transfer checkpoints at multiple resolutions, we provide a\n", + "checkpoint transferred to a mixture of tasks that can be used for off-the-shelf\n", + "exploration.\n", + "\n", + "This notebook is an extension of the [official notebook](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb) prepared by Google Research.\n", + "\n", + "![PaliGemma model](https://storage.cloud.google.com/com-roboflow-marketing/notebooks/examples/paligemma.png)\n", + "\n", + "To make it runnable on a T4 colab runtime with 16GB HBM and 12GB RAM, we opt to only finetune the attention layers of the language model and freeze the other parameters.\n", + "\n", + " * Download and parse Roboflow Universe dataset.\n", + " * Install deps, download model checkpoint and training data.\n", + " * Load the model onto GPU devices.\n", + " * Prepare the input to the model for training and inference.\n", + " * Finetune the model and inspect output in validation split." + ], + "metadata": { + "id": "4LqvmtZPzyY1" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Download and parse Roboflow Universe dataset\n", + "\n", + "**NOTE:**\n", + "PaliGemma requires the dataset to be in the appropriate format. Let's start by parsing the dataset from the standard YOLO format to JSONL, which is compatible with the PaliGemma training pipeline." 
+ ], + "metadata": { + "id": "FMlw3ru1YvLg" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -q roboflow supervision" + ], + "metadata": { + "id": "Wtvz4QZ9YuG8", + "outputId": "c652c668-7beb-40cd-9683-c80306c372ae", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/74.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.9/74.9 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.0/111.0 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.3/158.3 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m178.7/178.7 kB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.1/49.1 MB\u001b[0m \u001b[31m30.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.5/54.5 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from google.colab import userdata\n", + "from roboflow import Roboflow\n", + "\n", + "ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')\n", + "\n", + "rf = 
def ensure_directory_exists(directory: str) -> None:
    """Create `directory` (including missing parents) if it is not there yet."""
    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(directory, exist_ok=True)


def _iter_window(dataset, start_idx, end_idx):
    """Yield the dataset items whose position lies in [start_idx, end_idx).

    `start_idx=None` is treated as 0 and `end_idx=None` as "until the end".
    """
    first = 0 if start_idx is None else start_idx
    for index, item in enumerate(dataset):
        if end_idx is not None and index >= end_idx:
            break
        if index >= first:
            yield item


def _to_loc_bins(xyxy, width: int, height: int) -> List[int]:
    """Convert a pixel-space (x1, y1, x2, y2) box to integer location bins.

    Coordinates are normalised by the image size and scaled by 1024, which is
    the inverse of the `int(token) / 1024` decoding done when a response is
    parsed. Results are clipped to 0..1023 so that a box touching the right or
    bottom border does not produce the out-of-range bin 1024.
    """
    scale = np.array([width, height, width, height], dtype=float)
    # np.asarray(..., dtype=float) makes a fresh array, so the caller's
    # annotations are never modified and integer inputs are handled too.
    bins = np.asarray(xyxy, dtype=float) / scale * 1024
    return [int(value) for value in np.clip(bins, 0, 1023).astype(int)]


def extract_entries(
    dataset: "sv.DetectionDataset",
    classes: List[str],
    start_idx: Optional[int] = 0,
    end_idx: Optional[int] = None
) -> List[Dict[str, str]]:
    """Build one PaliGemma JSONL record per image of a detection dataset.

    Args:
        dataset: Iterable yielding `(image_path, image, annotations)` where
            `image` is an array whose first two dimensions are height and
            width, and `annotations` exposes `xyxy` and `class_id`.
        classes: Class names; `classes[class_id]` is used as the label text.
        start_idx: Position of the first item to convert (None means 0).
        end_idx: Position one past the last item to convert (None means all).

    Returns:
        A list of dicts with the keys:
          * "prefix": "detect " followed by the distinct class names present
            in the image, joined with " ; " in order of first appearance.
          * "suffix": one "<locY1><locX1><locY2><locX2> name" component per
            box, joined with " ; ".
          * "image": the base name of the image file.
    """
    entries = []

    for image_path, image, annotations in _iter_window(dataset, start_idx, end_idx):
        # Only height and width are needed; slicing also accepts 2-D images.
        height, width = image.shape[:2]

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # prompt is identical between runs (iteration order of a set is not).
        class_names = list(dict.fromkeys(
            classes[class_id] for class_id in annotations.class_id))
        prefix = "detect " + " ; ".join(class_names)

        suffix_components = []
        for xyxy, class_id in zip(annotations.xyxy, annotations.class_id):
            x1, y1, x2, y2 = _to_loc_bins(xyxy, width, height)
            # Location tokens are written in y1, x1, y2, x2 order, matching
            # the order in which from_pali_gemma reads them back.
            suffix_components.append(
                f"<loc{y1:04d}><loc{x1:04d}><loc{y2:04d}><loc{x2:04d}> "
                f"{classes[class_id]}")

        entries.append({
            "prefix": prefix,
            "suffix": " ; ".join(suffix_components),
            "image": os.path.basename(image_path),
        })

    return entries


def copy_images(
    dataset: "sv.DetectionDataset",
    target_dir: str,
    start_idx: Optional[int] = 0,
    end_idx: Optional[int] = None
) -> None:
    """Copy the image files of the selected dataset items into `target_dir`.

    Files keep their base name; `target_dir` must already exist.
    """
    for image_path, _image, _annotations in _iter_window(dataset, start_idx, end_idx):
        target_path = os.path.join(target_dir, os.path.basename(image_path))
        shutil.copy(image_path, target_path)


def save_entries(entries: List[Dict[str, str]], output_path: str) -> None:
    """Write `entries` to `output_path` as JSON Lines (one object per line)."""
    with open(output_path, 'w', encoding='utf-8') as file:
        for entry in entries:
            file.write(json.dumps(entry) + '\n')


def parse_dataset(
    dataset: "sv.DetectionDataset",
    target_dir: str,
    target_file_path: str,
    classes: List[str],
    start_idx: Optional[int] = 0,
    end_idx: Optional[int] = None
) -> None:
    """Convert a detection dataset into the layout used for fine-tuning.

    Creates `target_dir`, copies the selected images into it and writes the
    matching JSONL annotation records to `target_file_path`.
    """
    ensure_directory_exists(directory=target_dir)
    copy_images(
        dataset=dataset,
        target_dir=target_dir,
        start_idx=start_idx,
        end_idx=end_idx)
    entries = extract_entries(
        dataset=dataset,
        classes=classes,
        start_idx=start_idx,
        end_idx=end_idx)
    save_entries(entries=entries, output_path=target_file_path)
the object classes in the original dataset. If they are of low quality, rename them." + ], + "metadata": { + "id": "V6fnTTkN6rPx" + } + }, + { + "cell_type": "code", + "source": [ + "CLASSES = ['fracture']\n", + "DATA_DIR = \"fracture\"" + ], + "metadata": { + "id": "B6RGe_y8ZRPg" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "parse_dataset(train_ds, DATA_DIR, f\"{DATA_DIR}/data_train.jsonl\", CLASSES)\n", + "parse_dataset(valid_ds, DATA_DIR, f\"{DATA_DIR}/data_val.jsonl\", CLASSES)" + ], + "metadata": { + "id": "sncufv4lZaLa" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!head -n 5 {DATA_DIR}/data_train.jsonl" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WLhSenP5AtQe", + "outputId": "412bedbf-df9f-4866-9001-dca13e9f096e" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\"prefix\": \"detect fracture\", \"suffix\": \" fracture\", \"image\": \"rot_0_7471_png_jpg.rf.30ec1d3771a6b126e7d5f14ad0b3073b.jpg\"}\n", + "{\"prefix\": \"detect fracture\", \"suffix\": \" fracture\", \"image\": \"flip_0_5824_png_jpg.rf.abe91e2cd085f0d47e35ef7021ff8549.jpg\"}\n", + "{\"prefix\": \"detect fracture\", \"suffix\": \" fracture\", \"image\": \"all_0_8542_png_jpg.rf.6bcad49d206468d7720d727caca95724.jpg\"}\n", + "{\"prefix\": \"detect fracture\", \"suffix\": \" fracture\", \"image\": \"bri_0_592_png_jpg.rf.8d8630701ed43bb703fdad74c8765b26.jpg\"}\n", + "{\"prefix\": \"detect fracture\", \"suffix\": \" fracture ; fracture ; fracture\", \"image\": \"all_0_2435_png_jpg.rf.e602a2f82935f4fdba988b280ab11b7e.jpg\"}\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!head -n 5 {DATA_DIR}/data_val.jsonl" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YwHY21ABA0WG", + "outputId": "4e925557-3886-44d5-eede-e2ab0c975951" + }, + "execution_count": 10, + 
def from_pali_gemma(
    response: str,
    resolution_wh: Tuple[int, int],
    classes: Optional[List[str]] = None
) -> sv.Detections:
    """Parse a PaliGemma detection string into `sv.Detections`.

    The response is expected to be a sequence of components of the form
    "<locY1><locX1><locY2><locX2> name", separated by " ; ". Each location
    token carries a four-digit value that is divided by 1024 and multiplied
    by the image size to obtain pixel coordinates. An optional run of 16
    "<segNNN>" tokens after the location tokens is accepted and ignored.

    Args:
        response: Text in the format described above.
        resolution_wh: `(width, height)` of the image the boxes refer to.
        classes: Optional list of class names. When given, each parsed name is
            mapped to its position in this list to fill `class_id`.

    Returns:
        `sv.Detections` with `xyxy` in pixel coordinates, `class_id` (or None
        when `classes` is None) and the parsed names under
        `data['class_name']`. Parsing stops at the first component that does
        not match, so a response without detections yields an empty result.

    Raises:
        ValueError: If `classes` is given and a parsed name is not in it.
    """
    _SEGMENT_DETECT_RE = re.compile(
        r'(.*?)' +
        r'<loc(\d{4})>' * 4 + r'\s*' +
        '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
        r'\s*([^;<>]+)? ?(?:; )?',
    )

    width, height = resolution_wh
    xyxy_list = []
    class_name_list = []

    while response:
        m = _SEGMENT_DETECT_RE.match(response)
        if not m:
            break

        # Groups: leading text, 4 location values, 16 optional segmentation
        # values, label. Only the location values and the label are used.
        gs = list(m.groups())
        gs.pop(0)
        name = gs.pop()
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))

        xyxy_list.append([x1, y1, x2, y2])
        # The label group is optional, so it may be None.
        class_name_list.append((name or "").strip())
        # match() anchors at position 0, so everything up to m.end() has been
        # consumed (leading text included).
        response = response[m.end():]

    # reshape keeps the (n, 4) layout even when nothing was detected.
    xyxy = np.array(xyxy_list).reshape(-1, 4)
    class_name = np.array(class_name_list)

    if classes is None:
        class_id = None
    else:
        class_id = np.array(
            [classes.index(name) for name in class_name], dtype=int)

    return sv.Detections(
        xyxy=xyxy,
        class_id=class_id,
        data={'class_name': class_name}
    )
{ + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "DfxKb3F839Ks", + "outputId": "b292131e-cb1a-4600-dde9-e68034a3a651", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for ml_collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "# @title Fetch big_vision code and install dependencies.\n", + "import os\n", + "import sys\n", + "\n", + "# TPUs with\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " raise \"It seems you are using Colab with remote TPUs which is not supported.\"\n", + "\n", + "# Fetch big_vision repository if python doesn't know about it and install\n", + "# dependencies needed for this notebook.\n", + "if not os.path.exists(\"big_vision_repo\"):\n", + " !git clone --quiet --branch=main --depth=1 \\\n", + " https://github.com/google-research/big_vision big_vision_repo\n", + "\n", + "# Append big_vision code to python import path\n", + "if \"big_vision_repo\" not in sys.path:\n", + " sys.path.append(\"big_vision_repo\")\n", + "\n", + "# Install missing dependencies. Assume jax~=0.4.25 with GPU available.\n", + "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "azmRZvgGyhAb" + }, + "source": [ + "### Configure your API key to access Kaggle\n", + "\n", + "To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.\n", + "\n", + "1. 
To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.\n", + "1. In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n", + "\n", + "To be able to download, you will also need to acknowledge the Terms and Conditions of the PaliGemma on:\n", + "\n", + "* https://www.kaggle.com/models/google/paligemma/\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "zGLIp1Cx3_CX" + }, + "outputs": [], + "source": [ + "import os\n", + "from google.colab import userdata\n", + "\n", + "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", + "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n", + "\n", + "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", + "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gQNOTfF24AV4", + "outputId": "96cb1910-0ef9-4a0e-a338-22210e9ec658" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading the checkpoint from Kaggle, this could take a few minutes....\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading from https://www.kaggle.com/api/v1/models/google/paligemma/jax/paligemma-3b-pt-224/1/download/paligemma-3b-pt-224.f16.npz...\n", + "100%|██████████| 5.45G/5.45G [02:33<00:00, 38.0MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model path: /root/.cache/kagglehub/models/google/paligemma/jax/paligemma-3b-pt-224/1/./paligemma-3b-pt-224.f16.npz\n", + "Downloading the model 
# @title Download checkpoint, tokenizer and dataset to local filesystem.
#
# Defines MODEL_PATH and TOKENIZER_PATH for the cells below. Each download is
# skipped when the corresponding file is already present at the given path.
import os
import kagglehub

MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  # Note: kaggle archive contains the same checkpoint in multiple formats.
  # Download only the float16 model.
  # MODEL_PATH is rebound to the value returned by kagglehub; in the recorded
  # run that is a path under /root/.cache/kagglehub, not the "./" path above.
  MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  # IPython substitutes {TOKENIZER_PATH} with the Python value before running
  # the shell command.
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

# The dataset download below is disabled: DATA_DIR is already set earlier in
# this notebook to the directory holding the converted detection dataset.
# DATA_DIR="./longcap100"
# if not os.path.exists(DATA_DIR):
#   print("Downloading the dataset...")
#   !gsutil -m -q cp -n -r gs://longcap100/ .
#   print(f"Data path: {DATA_DIR}")
# @title Construct model and load params into RAM.

# Model definition: vocabulary size for the language model and the settings
# of the image encoder.
model_config = ml_collections.FrozenConfigDict(dict(
    llm=dict(vocab_size=257_152),
    img=dict(
        variant="So400m/14",
        pool_type="none",
        scan=True,
        dtype_mm="float16",
    ),
))
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# `decode` samples outputs from the model; the device list and the
# end-of-sequence token id are bound once here.
decode_fn = predict_fns.get_all(model)["decode"]
decode = functools.partial(
    decode_fn,
    devices=jax.devices(),
    eos_token=tokenizer.eos_id(),
)
"# \"llm\": {\"vocab_size\": 257_152},\n", + "# \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n", + "# })\n", + "# model = paligemma.Model(**model_config)\n", + "# tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n", + "\n", + "# # Load params - this can take up to 1 minute in T4 colabs.\n", + "# params = paligemma.load(None, '/content/fine-tuned-paligemma-3b-pt-224.f16.npz', model_config)\n", + "\n", + "# # Define `decode` function to sample outputs from the model.\n", + "# decode_fn = predict_fns.get_all(model)['decode']\n", + "# decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())" + ], + "metadata": { + "id": "2LNRDMMwXFJ9" + }, + "execution_count": 18, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RWOdf_fw2SAO", + "outputId": "f666f219-5fc8-4cf7-edc3-c9c9e0235e15" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " == Model params == \n", + "img/Transformer/encoder_norm/bias (1152,) float16\n", + "img/Transformer/encoder_norm/scale (1152,) float16\n", + "img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16\n", + 
"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16\n", + "img/embedding/bias (1152,) float16\n", + "img/embedding/kernel (14, 14, 3, 1152) float16\n", + "img/head/bias (2048,) float16\n", + "img/head/kernel (1152, 2048) float16\n", + "img/pos_embedding (1, 256, 1152) float16\n", + "llm/embedder/input_embedding (257152, 2048) float16\n", + "llm/final_norm/scale (2048,) float16\n", + "llm/layers/attn/attn_vec_einsum/w (18, 8, 256, 2048) float32\n", + "llm/layers/attn/kv_einsum/w (18, 2, 1, 2048, 256) float32\n", + "llm/layers/attn/q_einsum/w (18, 8, 2048, 256) float32\n", + "llm/layers/mlp/gating_einsum (18, 2, 2048, 16384) float16\n", + "llm/layers/mlp/linear (18, 16384, 2048) float16\n", + "llm/layers/pre_attention_norm/scale (18, 2048) float16\n", + "llm/layers/pre_ffw_norm/scale (18, 2048) float16\n" + ] + } + ], + "source": [ + "# @title Move params to GPU/TPU memory.\n", + "#\n", + "# To keep HBM usage low and fit in a T4 GPU (16GB HBM) we opt to only finetune\n", + "# a part of the parameters. 
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  """Decide from a parameter's path whether it is fine-tuned.

  Args:
    name: Slash-separated parameter path, e.g. "llm/layers/attn/q_einsum/w".
    param: The parameter value; not inspected.

  Returns:
    True for parameters under "llm/layers/attn/", False for every other
    parameter under "llm/" or "img/".

  Raises:
    ValueError: If `name` starts with neither "llm/" nor "img/".
  """
  if name.startswith("llm/layers/attn/"):
    return True
  # Everything else in the language model, and the whole image encoder,
  # stays frozen.
  if name.startswith(("llm/", "img/")):
    return False
  raise ValueError(f"Unexpected param name {name}")
def preprocess_image(image, size=224):
  """Turn an image into a square float array for the model.

  Args:
    image: Anything `np.asarray` accepts. A 2-D (single channel) input is
      replicated into three channels; channels after the third (e.g. alpha)
      are dropped.
    size: Side length, in pixels, of the square output.

  Returns:
    A NumPy array of shape (size, size, 3) computed as
    `resized / 127.5 - 1.0`, i.e. [-1, 1] for inputs in [0, 255].
  """
  # Model has been trained to handle images of different aspects ratios
  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
  # options are helpful to improve quality in some tasks.
  pixels = np.asarray(image)
  if pixels.ndim == 2:
    # No channel axis: repeat the single plane three times.
    pixels = np.stack([pixels, pixels, pixels], axis=-1)
  pixels = pixels[..., :3]  # Keep the first three channels only.
  assert pixels.shape[-1] == 3

  resized = tf.image.resize(
      tf.constant(pixels), (size, size), method='bilinear', antialias=True)
  return resized.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]
+= [1] * len(suffix) # 1 to use causal attention for suffix.\n", + " mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.\n", + "\n", + " mask_input = [1] * len(tokens) # 1 if its a token, 0 if padding.\n", + " if seqlen:\n", + " padding = [0] * max(0, seqlen - len(tokens))\n", + " tokens = tokens[:seqlen] + padding\n", + " mask_ar = mask_ar[:seqlen] + padding\n", + " mask_loss = mask_loss[:seqlen] + padding\n", + " mask_input = mask_input[:seqlen] + padding\n", + "\n", + " return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n", + "\n", + "def postprocess_tokens(tokens):\n", + " tokens = tokens.tolist() # np.array to list[int]\n", + " try: # Remove tokens at and after EOS if any.\n", + " eos_pos = tokens.index(tokenizer.eos_id())\n", + " tokens = tokens[:eos_pos]\n", + " except ValueError:\n", + " pass\n", + " return tokenizer.decode(tokens)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "id": "whzWOojGOtzi" + }, + "outputs": [], + "source": [ + "# @title Function to iterate over train and validation examples.\n", + "SEQLEN = 128\n", + "\n", + "# TODO: Consider data iterators skipping big_vision and tf.data?\n", + "train_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(DATA_DIR, \"data_train.jsonl\"),\n", + " fopen_keys={\"image\": DATA_DIR})\n", + "\n", + "val_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(DATA_DIR, \"data_val.jsonl\"),\n", + " fopen_keys={\"image\": DATA_DIR})\n", + "\n", + "\n", + "def train_data_iterator():\n", + " \"\"\"Never ending iterator over training examples.\"\"\"\n", + " # Shuffle examples and repeat so one can train for many epochs.\n", + " dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n", + " for example in dataset.as_numpy_iterator():\n", + " image = Image.open(io.BytesIO(example[\"image\"]))\n", + " image = preprocess_image(image)\n", + "\n", + " # prefix = \"caption en\" # Could also be a different prefix 
def validation_data_iterator():
  """Yield every validation example once, in dataset order, as model inputs.

  Each record's image bytes are opened and run through `preprocess_image`,
  and its lower-cased prefix is tokenized to `SEQLEN` entries. No suffix is
  passed, so the loss mask returned by `preprocess_tokens` is discarded.

  Yields:
    Dicts of NumPy arrays with the keys "image", "text", "mask_ar" and
    "mask_input".
  """
  records = val_dataset.get_tfdata(ordered=True).as_numpy_iterator()
  for record in records:
    image = preprocess_image(Image.open(io.BytesIO(record["image"])))

    # prefix = "caption en"  # Could also be a different prefix per example.
    prompt = record["prefix"].decode().lower()
    tokens, mask_ar, _, mask_input = preprocess_tokens(prompt, seqlen=SEQLEN)

    features = {
        "image": image,
        "text": tokens,
        "mask_ar": mask_ar,
        "mask_input": mask_input,
    }
    yield {key: np.asarray(value) for key, value in features.items()}
\n", + " \n", + "

<loc0330><loc0522><loc0474><loc0770> fracture ; <loc0578><loc0456><loc0671><loc0655> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0401><loc0047><loc0491><loc0164> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0564><loc0421><loc0649><loc0546> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0532><loc0200><loc0658><loc0358> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0263><loc0205><loc0415><loc0415> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0743><loc0302><loc0792><loc0434> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0396><loc0413><loc0460><loc0501> fracture ; <loc0454><loc0308><loc0526><loc0416> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0751><loc0753><loc0845><loc0876> fracture

\n", + "
def split_and_keep_second_part(s):
  """Return the text after the first newline in `s`.

  Args:
    s: Input string.

  Returns:
    Everything following the first "\n", or `s` unchanged when it contains no
    newline.
  """
  _, newline, remainder = s.partition('\n')
  return remainder if newline else s


def render_inline(image, resize=(128, 128)):
  """Convert image into an inline (base64 data URI) JPEG for use in html.

  Args:
    image: Array accepted by `Image.fromarray`.
    resize: Size tuple forwarded to `Image.resize`.

  Returns:
    A "data:image/jpeg;base64,..." string holding the resized image.
  """
  # `Image.resize` returns a resized copy and leaves the original untouched,
  # so the result has to be kept; otherwise the full-size image is encoded.
  image = Image.fromarray(image).resize(resize)
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
    return f"data:image/jpeg;base64,{image_b64}"
\n", + " \n", + "

{html.escape(caption)}

\n", + "
# @title Define the training step and evaluation loop.
#
# The main update_fn using simple SGD.
#
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  """Run one SGD step on the trainable parameters and return the batch loss.

  Args:
    params: Model parameter pytree. Its buffers are donated to the jitted
      call (`donate_argnums=(0,)`), so callers should use the returned tree.
    batch: Dict with "image", "text", "mask_ar" and "mask_loss" entries.
    learning_rate: Step size applied to the gradients.

  Returns:
    `(new_params, loss)`, where `loss` is the mean over examples of the
    masked, per-token-normalized negative log-likelihood.
  """
  images = batch["image"]
  text = batch["text"]
  # The model sees text[:, :-1] and is scored on predicting text[:, 1:].
  inputs, labels = text[:, :-1], text[:, 1:]

  def batch_loss(weights):
    logits, _ = model.apply(
        {"params": weights}, images, inputs, batch["mask_ar"][:, :-1],
        train=True)
    log_probs = jax.nn.log_softmax(logits, axis=-1)

    # mask_loss[:, 1:] selects which target tokens count towards the loss
    # (e.g. prefix and padded tokens are not included).
    loss_mask = batch["mask_loss"][:, 1:]
    one_hot_labels = jax.nn.one_hot(labels, logits.shape[-1])

    # Log-probability assigned to each target token (sum across vocab_size).
    token_log_prob = jnp.sum(log_probs * one_hot_labels, axis=-1)
    # Sum across seq_len, then divide by the number of scored tokens because
    # each example has a different number of them (at least 1 to avoid /0).
    summed_nll = -jnp.sum(token_log_prob * loss_mask, axis=-1)
    per_example = summed_nll / jnp.clip(jnp.sum(loss_mask, -1), 1)

    # batch_loss: mean of per example loss.
    return jnp.mean(per_example)

  loss, grads = jax.value_and_grad(batch_loss)(params)

  def sgd_step(param, gradient, trainable):
    # Frozen parameters pass through untouched.
    return param - learning_rate * gradient if trainable else param

  # `trainable_mask` is the module-level pytree of booleans marking which
  # parameters are fine-tuned.
  new_params = jax.tree_util.tree_map(sgd_step, params, grads, trainable_mask)
  return new_params, loss


# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  """Decode model responses for the examples produced by `data_iterator`.

  Uses the module-level `params` and `decode`.

  Args:
    data_iterator: Iterator yielding example dicts (each must have "image").
    num_examples: If truthy, stop once this many predictions are collected.
    batch_size: Number of examples decoded together; a final short batch is
      padded by repeating its last example, and the padded rows are dropped
      from the results.
    seqlen: Passed to `decode` as `max_decode_len`.
    sampler: Passed to `decode` as `sampler`.

  Returns:
    List of `(image, response_text)` tuples.
  """
  outputs = []
  end_of_data = object()
  while True:
    # Construct a list of examples in the batch.
    examples = []
    exhausted = False
    for _ in range(batch_size):
      example = next(data_iterator, end_of_data)
      if example is end_of_data:
        exhausted = True
        break
      example["_mask"] = np.array(True)  # Indicates true example.
      examples.append(example)
    if exhausted and not examples:
      return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    for _ in range(-len(examples) % batch_size):
      filler = dict(examples[-1])
      filler["_mask"] = np.array(False)  # Indicates padding example.
      examples.append(filler)

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions back from the devices and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    responses = [postprocess_tokens(t) for t in tokens[mask]]  # real rows only

    # Pair each real example with its response; padding examples sit at the
    # end of `examples`, so zip drops them.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs
\n", + " \n", + "

<loc0000><loc0000><loc0900><loc1015> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1023><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0105><loc0000><loc0964><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1015><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0373><loc0238><loc0575><loc0432> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1015><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0679><loc0549><loc0735><loc0670> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1017><loc0767> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 2/64 lr: 0.00333 loss: 3.9101\n", + "step: 3/64 lr: 0.00500 loss: 3.8219\n", + "step: 4/64 lr: 0.00667 loss: 3.7913\n", + "step: 5/64 lr: 0.00833 loss: 3.9509\n", + "step: 6/64 lr: 0.01000 loss: 4.0808\n", + "step: 7/64 lr: 0.00999 loss: 3.9368\n", + "step: 8/64 lr: 0.00997 loss: 3.6630\n", + "Model predictions at step 8\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0612><loc0549><loc0870><loc0651> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1023><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0110><loc0000><loc0969><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1013><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0105><loc0425><loc0167><loc0509> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0000><loc1015><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0693><loc0513><loc0767><loc0826> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0517><loc0621><loc0583><loc0738> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 9/64 lr: 0.00994 loss: 3.7152\n", + "step: 10/64 lr: 0.00989 loss: 4.1957\n", + "step: 11/64 lr: 0.00982 loss: 3.8173\n", + "step: 12/64 lr: 0.00975 loss: 4.0603\n", + "step: 13/64 lr: 0.00966 loss: 3.6232\n", + "step: 14/64 lr: 0.00955 loss: 3.8549\n", + "step: 15/64 lr: 0.00944 loss: 3.6587\n", + "step: 16/64 lr: 0.00931 loss: 3.6792\n", + "Model predictions at step 16\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0603><loc0549><loc0841><loc0658> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0441><loc0230><loc0541><loc0344> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0477><loc0714><loc0629> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0458><loc0000><loc0531><loc0127> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0279><loc0617><loc0413> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0591><loc0036><loc0645><loc0238> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0485><loc0728><loc0541><loc0944> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0536><loc0621><loc0603><loc0738> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 17/64 lr: 0.00917 loss: 3.4137\n", + "step: 18/64 lr: 0.00901 loss: 3.8984\n", + "step: 19/64 lr: 0.00885 loss: 3.8873\n", + "step: 20/64 lr: 0.00867 loss: 3.7505\n", + "step: 21/64 lr: 0.00849 loss: 3.7748\n", + "step: 22/64 lr: 0.00829 loss: 3.6498\n", + "step: 23/64 lr: 0.00809 loss: 3.6461\n", + "step: 24/64 lr: 0.00787 loss: 3.7160\n", + "Model predictions at step 24\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0603><loc0559><loc0850><loc0653> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0767><loc0505><loc1023><loc0754> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0596><loc0000><loc0726><loc0127> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0303><loc0387><loc0409><loc0529> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0302><loc0611><loc0425> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0591><loc0000><loc0652><loc0206> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0503><loc0000><loc0561><loc0127> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0620><loc0623><loc0754> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 25/64 lr: 0.00765 loss: 3.7272\n", + "step: 26/64 lr: 0.00742 loss: 3.4325\n", + "step: 27/64 lr: 0.00719 loss: 3.7337\n", + "step: 28/64 lr: 0.00694 loss: 3.5439\n", + "step: 29/64 lr: 0.00670 loss: 3.8029\n", + "step: 30/64 lr: 0.00644 loss: 3.5595\n", + "step: 31/64 lr: 0.00619 loss: 3.4872\n", + "step: 32/64 lr: 0.00593 loss: 3.4479\n", + "Model predictions at step 32\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0575><loc0559><loc0850><loc0647> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0754><loc0485><loc1023><loc0714> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0588><loc0485><loc0677><loc0595> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0425><loc0000><loc0509><loc0100> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0525><loc0284><loc0595><loc0438> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0000><loc0629><loc0206> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0456><loc0000><loc0536><loc0174> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0483><loc0621><loc0559><loc0735> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 33/64 lr: 0.00566 loss: 3.4278\n", + "step: 34/64 lr: 0.00540 loss: 3.7668\n", + "step: 35/64 lr: 0.00513 loss: 3.5646\n", + "step: 36/64 lr: 0.00487 loss: 3.2581\n", + "step: 37/64 lr: 0.00460 loss: 3.3962\n", + "step: 38/64 lr: 0.00434 loss: 3.6364\n", + "step: 39/64 lr: 0.00407 loss: 3.3160\n", + "step: 40/64 lr: 0.00381 loss: 3.5750\n", + "Model predictions at step 40\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0589><loc0572><loc0867><loc0653> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0762><loc0670><loc1023><loc0754> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0505><loc0703><loc0601> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0432><loc0000><loc0509><loc0113> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0547><loc0302><loc0626><loc0425> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0583><loc0000><loc0664><loc0230> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0478><loc0000><loc0549><loc0097> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0509><loc0639><loc0582><loc0735> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 41/64 lr: 0.00356 loss: 3.5534\n", + "step: 42/64 lr: 0.00330 loss: 3.7601\n", + "step: 43/64 lr: 0.00306 loss: 3.4201\n", + "step: 44/64 lr: 0.00281 loss: 3.2198\n", + "step: 45/64 lr: 0.00258 loss: 3.8257\n", + "step: 46/64 lr: 0.00235 loss: 3.4511\n", + "step: 47/64 lr: 0.00213 loss: 3.4098\n", + "step: 48/64 lr: 0.00191 loss: 3.5586\n", + "Model predictions at step 48\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0575><loc0559><loc0870><loc0646> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0419><loc0230><loc0541><loc0344> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0485><loc0703><loc0591> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0430><loc0000><loc0514><loc0100> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0529><loc0290><loc0611><loc0425> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0575><loc0000><loc0658><loc0216> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0478><loc0000><loc0561><loc0097> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0496><loc0633><loc0579><loc0738> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 49/64 lr: 0.00171 loss: 3.3370\n", + "step: 50/64 lr: 0.00151 loss: 3.6179\n", + "step: 51/64 lr: 0.00133 loss: 3.8233\n", + "step: 52/64 lr: 0.00115 loss: 3.4149\n", + "step: 53/64 lr: 0.00099 loss: 3.6569\n", + "step: 54/64 lr: 0.00083 loss: 3.4282\n", + "step: 55/64 lr: 0.00069 loss: 3.3576\n", + "step: 56/64 lr: 0.00056 loss: 3.5013\n", + "Model predictions at step 56\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0583><loc0559><loc0870><loc0647> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0419><loc0230><loc0541><loc0344> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0485><loc0703><loc0601> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0430><loc0000><loc0518><loc0104> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0529><loc0290><loc0602><loc0425> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0575><loc0000><loc0652><loc0223> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0478><loc0000><loc0561><loc0113> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0496><loc0621><loc0582><loc0754> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 57/64 lr: 0.00045 loss: 3.5114\n", + "step: 58/64 lr: 0.00034 loss: 3.3903\n", + "step: 59/64 lr: 0.00025 loss: 3.3435\n", + "step: 60/64 lr: 0.00018 loss: 3.2968\n", + "step: 61/64 lr: 0.00011 loss: 3.4862\n", + "step: 62/64 lr: 0.00006 loss: 3.6225\n", + "step: 63/64 lr: 0.00003 loss: 3.4152\n", + "step: 64/64 lr: 0.00001 loss: 3.6450\n", + "Model predictions at step 64\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0583><loc0559><loc0870><loc0653> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0419><loc0230><loc0541><loc0344> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0485><loc0703><loc0601> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0430><loc0000><loc0514><loc0104> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0529><loc0290><loc0602><loc0438> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0577><loc0000><loc0652><loc0223> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0478><loc0000><loc0561><loc0127> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0496><loc0621><loc0582><loc0754> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 11min 24s, sys: 743 ms, total: 11min 24s\n", + "Wall time: 11min 27s\n" + ] + } + ], + "source": [ + "# @title Run training loop.\n", + "#\n", + "# Run a short training loop with cosine learning rate schedule.\n", + "#\n", + "# Note: the first step can be quite slow on some machines (up to several minutes)\n", + "# due to XLA compilation of the jax.jit'd function.\n", + "#\n", + "%%time\n", + "\n", + "BATCH_SIZE = 8\n", + "TRAIN_EXAMPLES = 512\n", + "# TRAIN_EXAMPLES = 256\n", + "LEARNING_RATE = 0.01\n", + "\n", + "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n", + "EVAL_STEPS = TRAIN_STEPS // 8\n", + "\n", + "train_data_it = train_data_iterator()\n", + "\n", + "sched_fn = big_vision.utils.create_learning_rate_schedule(\n", + " total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n", + " decay_type=\"cosine\", warmup_percent=0.10)\n", + "\n", + "for step in range(1, TRAIN_STEPS+1):\n", + " # Make list of N training examples.\n", + " examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n", + "\n", + " # Convert list of examples into a dict of np.arrays and load onto devices.\n", + " batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", + " batch = big_vision.utils.reshard(batch, data_sharding)\n", + "\n", + " # Training step and report training loss\n", + " learning_rate = sched_fn(step)\n", + " params, loss = update_fn(params, batch, learning_rate)\n", + "\n", + " loss = jax.device_get(loss)\n", + " print(f\"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}\")\n", + "\n", + " if step == 1 or (step % EVAL_STEPS) == 0:\n", + " print(f\"Model predictions at step {step}\")\n", + " html_out = \"\"\n", + " for image, caption in make_predictions(\n", + " validation_data_iterator(), num_examples=8, batch_size=8):\n", + " html_out += render_example(image, caption)\n", + " display(HTML(html_out))\n" + ] + }, + { + 
"cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "hgUhEKjzPdMQ", + "outputId": "184cfc14-a20c-4898-cf0e-6c36210375b3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model predictions\n", + "result render failed, result: \n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0583><loc0559><loc0870><loc0653> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0419><loc0230><loc0541><loc0344> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0485><loc0703><loc0601> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0430><loc0000><loc0514><loc0104> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0529><loc0290><loc0602><loc0438> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0577><loc0000><loc0652><loc0223> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0478><loc0000><loc0561><loc0127> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0496><loc0621><loc0582><loc0754> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0413><loc0201><loc0582><loc0270> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0166><loc0425><loc0292><loc0726> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0321><loc0000><loc0425><loc0091> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0566><loc0284><loc0616><loc0413> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0453><loc0303><loc0527><loc0397> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0441><loc0703><loc0575> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0366><loc0652><loc0513> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0485><loc0366><loc0726><loc0503> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0363><loc0453><loc0575><loc0559> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0303><loc0465><loc0376><loc0629> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0366><loc0477><loc0465><loc0714> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0321><loc0302><loc0575><loc0465> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0529><loc0503><loc0675><loc0714> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0677><loc0300><loc0772><loc0453> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0413><loc0270><loc0698><loc0477> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0596><loc0200><loc0714><loc0387> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0463><loc0719><loc0543><loc0826> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0453><loc0421><loc0617><loc0509> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0409><loc0344><loc0514><loc0509> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0637><loc0293><loc0703><loc0366> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0403><loc0230><loc0812><loc0529> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0456><loc0344><loc0549><loc0485> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0238><loc0150><loc0313><loc0313> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0263><loc0366><loc0456><loc0467> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0483><loc0000><loc0629><loc0113> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0270><loc0639><loc0357><loc0735> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0279><loc0667><loc0357><loc0944> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0575><loc0754><loc0639><loc0845> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0470><loc0309><loc0575><loc0513> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0583><loc0054><loc0652><loc0211> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0366><loc0143><loc0438><loc0509> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0525><loc0000><loc0609><loc0127> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0438><loc0000><loc0596><loc0100> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0279><loc0344><loc0392><loc0559> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0290><loc0645><loc0413> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0313><loc0677><loc0453> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0485><loc0216><loc0595><loc0438> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0303><loc0453><loc0382><loc0653> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0431><loc0127><loc0569><loc0302> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0488><loc0471><loc0644><loc0582> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0463><loc0000><loc0541><loc0168> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0425><loc0465><loc0629><loc0559> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0344><loc0000><loc0424><loc0200> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0373><loc0425><loc0456><loc0629> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0167><loc0703><loc0344> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0575><loc0629><loc0714><loc0797> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0575><loc0485><loc0687><loc0629> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0403><loc0601><loc0591><loc0710> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0251><loc0000><loc0409><loc0366> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0363><loc0290><loc0686><loc0582> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0403><loc0767><loc0477><loc0921> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0182><loc0113><loc0292><loc0303> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0620><loc0073><loc0726><loc0238> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0575><loc0425><loc0767><loc0513> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0465><loc0425><loc0568><loc0531> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0425><loc0579><loc0559><loc0646> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0559><loc0541><loc0662><loc0896> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0653><loc0754><loc0797><loc0870> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0363><loc0425><loc0453><loc0629> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0303><loc0442><loc0399><loc0575> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0579><loc0093><loc0658><loc0219> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0603><loc0074><loc0689><loc0267> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0456><loc0106><loc0527><loc0230> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0396><loc0297><loc0470><loc0474> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0435><loc0582><loc0505><loc0726> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0419><loc0667><loc0509><loc0896> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0175><loc0603><loc0344><loc0726> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0418><loc0152><loc0510><loc0271> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0425><loc0384><loc0582><loc0611> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0438><loc0000><loc0714><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0509><loc0000><loc0603><loc0194> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0419><loc0667><loc0502><loc0887> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0456><loc0403><loc0559><loc0644> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0582><loc0363><loc0653><loc0505> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0363><loc0074><loc0485><loc0279> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0000><loc0617><loc0113><loc0721> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0366><loc0297><loc0438><loc0453> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0496><loc0453><loc0591><loc0549> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0408><loc0176><loc0477><loc0293> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0366><loc0505><loc0453><loc0870> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0375><loc0366><loc0467><loc0456> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0797><loc0529><loc0921><loc0754> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0453><loc0000><loc0559><loc1023> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0165><loc0735><loc0366><loc0826> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0199><loc0366><loc0425><loc0612> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0596><loc0211><loc0703><loc0344> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0366><loc0000><loc0456><loc0206> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0373><loc0211><loc0453><loc0513> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0448><loc0100><loc0527><loc0238> fracture

\n", + "
\n", + " \n", + "
\n", + " \n", + "

<loc0493><loc0620><loc0629><loc0714> fracture

\n", + "
\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 59.1 s, sys: 202 ms, total: 59.3 s\n", + "Wall time: 1min 5s\n" + ] + } + ], + "source": [ + "# @title Evaluate the model on all examples.\n", + "#\n", + "# Runs the model over every example yielded by validation_data_iterator() and\n", + "# renders each image together with its predicted output.\n", + "%%time\n", + "\n", + "print(\"Model predictions\")\n", + "html_out = \"\"\n", + "for image, caption in make_predictions(validation_data_iterator(), batch_size=4):\n", + " html_out += render_example(image, caption)\n", + "display(HTML(html_out))\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Save fine-tuned model locally" + ], + "metadata": { + "id": "Hr1gTKP8trRb" + } + }, + { + "cell_type": "code", + "source": [ + "flat, _ = big_vision.utils.tree_flatten_with_names(params)\n", + "with open(\"/content/fine-tuned-paligemma-3b-pt-224.f16.npz\", \"wb\") as f:\n", + " np.savez(f, **{k: v for k, v in flat})" + ], + "metadata": { + "id": "zyVxKr2FOxPe" + }, + "execution_count": 26, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "machine_shape": "hm" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file