diff --git a/examples/python/transformers/HuggingFace_in_Spark_NLP_CamemBertForZeroShotClassification.ipynb b/examples/python/transformers/HuggingFace_in_Spark_NLP_CamemBertForZeroShotClassification.ipynb new file mode 100644 index 00000000000000..29754edba40b38 --- /dev/null +++ b/examples/python/transformers/HuggingFace_in_Spark_NLP_CamemBertForZeroShotClassification.ipynb @@ -0,0 +1,2979 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Ly1Q4SJ4B-_k" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace_in_Spark_NLP_CamemBertForZeroShotClassification.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WoRNBgVaB-_o" + }, + "source": [ + "## Import CamemBertForZeroShotClassification models from HuggingFace πŸ€— into Spark NLP πŸš€\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- This feature is only in `Spark NLP 4.4.1` and after. So please make sure you have upgraded to the latest Spark NLP release\n", + "- You can import CamemBERT models trained/fine-tuned for sequence classification via `CamembertForSequenceClassification` or `TFCamembertForSequenceClassification`. 
These models are usually under `Sequence Classification` category and have `camembert` in their labels\n", + "- Reference: [TFCamembertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.TFCamembertForSequenceClassification)\n", + "- Some [example models](https://huggingface.co/models?other=camembert&pipeline_tag=token-classification)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Zx5fPyP8B-_p" + }, + "source": [ + "## Export and Save HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rs1VirKQB-_p" + }, + "source": [ + "- Let's install `HuggingFace` and `TensorFlow`. You don't need `TensorFlow` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- CamembertTokenizer requires the `SentencePiece` library, so we install that as well" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "esxJoWfQB-_q", + "outputId": "37e75e62-bc25-4ff9-ceca-11e263720d2e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m588.3/588.3 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m89.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m57.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m111.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m439.2/439.2 kB\u001b[0m \u001b[31m48.9 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m111.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m781.3/781.3 kB\u001b[0m \u001b[31m73.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "cudf-cu12 24.4.1 requires protobuf<5,>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-iam 2.15.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-language 2.13.4 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-pubsub 2.22.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-resource-manager 1.12.4 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "googleapis-common-protos 1.63.2 requires protobuf!=3.20.0,!=3.20.1,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0.dev0,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "grpc-google-iam-v1 0.13.1 requires protobuf!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "pandas-gbq 0.19.2 requires google-auth-oauthlib>=0.7.0, but you have google-auth-oauthlib 0.4.6 which is incompatible.\n", + "tensorflow-datasets 4.9.6 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + 
"tensorflow-metadata 1.15.0 requires protobuf<4.21,>=3.20.3; python_version < \"3.11\", but you have protobuf 3.19.6 which is incompatible.\n", + "tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.11.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q transformers tensorflow==2.11.0 sentencepiece" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B8X4Hk2hB-_r" + }, + "source": [ + "- HuggingFace comes with a native `saved_model` feature inside `save_pretrained` function for TensorFlow based models. We will use that to save it as TF `SavedModel`.\n", + "- We'll use [tblard/tf-allocine](https://huggingface.co/tblard/tf-allocine) model from HuggingFace as an example\n", + "- In addition to `TFCamembertForSequenceClassification` we also need to save the `CamembertTokenizer`. This is the same for every model, these are assets needed for tokenization inside Spark NLP." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 309, + "referenced_widgets": [ + "d3da72562dcc4e9c94358ba582c33a0c", + "89c1e714cb14442fb01379befe80c3b7", + "0b5b53b7e07c4953809e70aa6460b48d", + "8f87d812959a44428e254cfd464f8416", + "207ab3bfc47e4c92a9a22eda29697a5d", + "547119d6ee0f482c9c28239749c0db49", + "44f07151e3d14ff2afc48487c88a3de2", + "fb023dc778604d18b2ae71e6f134fb35", + "75584e408d5a45918592959d1166cd61", + "6cf9c6a811d74dacbceaad5705d5c71a", + "d68dd62aa2874a288586d5e9f1b05ab8", + "850ad48422454088bb8ec39ab92dc783", + "5e78097f12b84b829d7e6c935839c131", + "94c128fadd7f4b9e95f336aad0a8d47d", + "0287e67435584efdb1fdcd50ca71e7f1", + "adb519dd31f84449a24377400799661b", + "30b687c0c6be4e31bba624e771a0d884", + "6b00c16ba3634a2388373e987abe2c2a", + "9eeb299678f248829a0eb28c9a912602", + "3195a3ef78654ee794622cf5cf8936ec", + "28b489e3099d4818bf4fda60decfc5ec", + "f20623e2768b41f18007e0e522abd51c" + ] + }, + "id": 
"ovk9IONTB-_r", + "outputId": "fab2ab05-8676-4b8e-baad-749bb1c39451" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d3da72562dcc4e9c94358ba582c33a0c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/929 [00:00=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 16.1.0 which is incompatible.\n", + "google-cloud-iam 2.15.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-language 2.13.4 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-pubsub 2.22.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-resource-manager 1.12.4 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-colab 1.0.0 requires requests==2.31.0, but you have requests 2.32.3 which is incompatible.\n", + "googleapis-common-protos 1.63.2 requires 
protobuf!=3.20.0,!=3.20.1,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0.dev0,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "grpc-google-iam-v1 0.13.1 requires protobuf!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 16.1.0 which is incompatible.\n", + "pandas-gbq 0.19.2 requires google-auth-oauthlib>=0.7.0, but you have google-auth-oauthlib 0.4.6 which is incompatible.\n", + "tensorflow-datasets 4.9.6 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "tensorflow-metadata 1.15.0 requires protobuf<4.21,>=3.20.3; python_version < \"3.11\", but you have protobuf 3.19.6 which is incompatible.\n", + "tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.11.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q --upgrade transformers[onnx] optimum sentencepiece tensorflow==2.11.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KwJwuD8zOXM0" + }, + "source": [ + "- HuggingFace has an extension called Optimum which offers specialized model inference, including ONNX. We can use this to import and export ONNX models with `from_pretrained` and `save_pretrained`.\n", + "- We'll use the [mtheo/camembert-base-xnli](https://huggingface.co/mtheo/camembert-base-xnli) model from HuggingFace as an example and load it as an `ORTModelForSequenceClassification`, representing an ONNX model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 420, + "referenced_widgets": [ + "eeeeff26c6d9466faa41b39c5a4fa669", + "df01abea998d42eaaa701105782abba0", + "b9d9b8a937964b85a66e8353c4d44f8d", + "9cbc3634d1d746b18c09b2ca6fe61a30", + "d2c6d0d6130e49f98360959d38f74d95", + "11b50210f33e4a8997b8c760d8e4ffe0", + "0ff93a6722634ec2b466371a2c968f9d", + "7657f8b6f33540d6b49aaf15bb73614a", + "f2ca64e852a94f9ba7dc4bf773a695ab", + "e1e015f4af6b47ec8b59807f35b2db07", + "5b2d19d102144e6f87b71be01c483b68", + "55c2a35733684494bdfa1d4fd053711c", + "3b2a3da584844cb3b83a193024c9ad40", + "d4a3ffb0fb924e8eb9eec47241abc317", + "6de7a0c6f13e4d8bb2c6af5d71be674e", + "b6402bb4b4884f79891914a56390b3e9", + "11f8cdf8c0c44daeab6303f50c7db818", + "acc61b054f424ec5848cea43e6a69672", + "337d30d86ffb4d068bc27a96b464d678", + "21f2c0ea2f7a441ca8fd88dd4a56f338", + "8be26127b62348dab6093cbd885a103a", + "99b4797f0808425b8dc0419ada9512fa", + "5e6f4369c79e4ff8aa5fa080d3658488", + "df594133572b4844a83e93d0c8db40c0", + "8f49629b99a84d5cb310c6530efc1012", + "49d0187f5da14cc99694bc08899e16a9", + "4b49e41be2b043a79d11aa73110b6258", + "a9c757b180fc4bfc93be39dd8d3f9cdd", + "3672f6662ce844c98f6d574ad37fbfaf", + "5f1c186f31d140aaa29e72299cdde746", + "504e7a1e464640f0beca9058a4011e1d", + "6a726bdd6135416e91637a034c5623fb", + "47a352883ad64b5ab3880c7dc23a1681", + "424a9faacc9b485aae5c472d68a55d66", + "c25961ebc54e48109438e724711c9b33", + "c7c6ed46459341caaebe21e339f2d6a7", + "fbb749170f894741a6eb818b19d576c3", + "112531db46384f73af826f020f3df253", + "7a04e2480edf4b9b81322c45c0210e55", + "9173d5126c794468a40f83a2744808d6", + "dff293d56f7b484fb54f0c7d2ddd1643", + "08b3f87119504658b9a6deea86adce96", + "099c6c294cfe49779b2e46932748cd7e", + "8112a8a0c34a474bb0229c91b0ae88e0", + "fb8ab0c7d7af4fdc9a6f99a06c53b681", + "11593c36fc804d4db38f1b156c2d9456", + "2269c361542a4ad6840513d5ba598ed5", + "54bd32b592f842d6ab4cbafe7486714b", 
+ "40dc70033f81402db224255a72a90ede", + "bb494ef4a5f046d09a8b4cb52baac036", + "9f45321331ca42cd85ff84f0a7aa0fe8", + "9412a08772a941acabd75ded6573120d", + "b381561928f44dce9da1d692ae01e9f1", + "2ea468d131be464a8a74f9ec88f9eb4c", + "8b79db60b5494b54964acb5048532d34" + ] + }, + "id": "xm2lRUOxOXM0", + "outputId": "f4d1715c-85dd-434c-d4de-acbb8647f584" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "eeeeff26c6d9466faa41b39c5a4fa669", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/929 [00:00 False\n" + ] + } + ], + "source": [ + "from optimum.onnxruntime import ORTModelForSequenceClassification\n", + "import tensorflow as tf\n", + "\n", + "MODEL_NAME = 'mtheo/camembert-base-xnli'\n", + "ONNX_MODEL = f\"onnx_models/{MODEL_NAME}\"\n", + "\n", + "ort_model = ORTModelForSequenceClassification.from_pretrained(MODEL_NAME, export=True)\n", + "\n", + "# Save the ONNX model\n", + "ort_model.save_pretrained(ONNX_MODEL)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "42GXDuPCOXM1" + }, + "source": [ + "Let's have a look inside these two directories and see what we are dealing with:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NTxJiuBoOXM1", + 
"outputId": "1744db38-d85f-455b-e6a7-b9ada6f75c44" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total 434784\n", + "-rw-r--r-- 1 root root 933 Jul 15 17:02 config.json\n", + "-rw-r--r-- 1 root root 442778957 Jul 15 17:02 model.onnx\n", + "-rw-r--r-- 1 root root 1038 Jul 15 17:02 special_tokens_map.json\n", + "-rw-r--r-- 1 root root 1589 Jul 15 17:02 tokenizer_config.json\n", + "-rw-r--r-- 1 root root 2418946 Jul 15 17:02 tokenizer.json\n" + ] + } + ], + "source": [ + "!ls -l {ONNX_MODEL}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VU-_uvb9TUL6" + }, + "source": [ + "We are using based model for the tokenizer because the model `mtheo/camembert-base-xnli` does not have sentencepiece" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "klcWcbmhTFhC", + "outputId": "82141c67-fed0-471c-9d09-b620cc4498d5" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenizer loaded successfully!\n", + "Tokenizer saved successfully!\n" + ] + } + ], + "source": [ + "from transformers import TFCamembertForSequenceClassification, CamembertTokenizer\n", + "import tensorflow as tf\n", + "\n", + "# Try to load the tokenizer\n", + "try:\n", + " tokenizer = CamembertTokenizer.from_pretrained('camembert-base')\n", + " print(\"Tokenizer loaded successfully!\")\n", + "except OSError as e:\n", + " print(f\"Error loading tokenizer: {e}\")\n", + "\n", + "# Try to save the tokenizer\n", + "try:\n", + " tokenizer.save_pretrained(ONNX_MODEL)\n", + " print(\"Tokenizer saved successfully!\")\n", + "except Exception as e:\n", + " print(f\"Error saving tokenizer: {e}\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rUab53MhT9dj", + "outputId": "69ca57f6-d080-418d-f8a5-9174c0f130d0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total 435580\n", + "-rw-r--r-- 1 root root 22 Jul 15 17:03 added_tokens.json\n", + "-rw-r--r-- 1 root root 933 Jul 15 17:02 config.json\n", + "-rw-r--r-- 1 root root 442778957 Jul 15 17:02 model.onnx\n", + "-rw-r--r-- 1 root root 810912 Jul 15 17:03 sentencepiece.bpe.model\n", + "-rw-r--r-- 1 root root 374 Jul 15 17:03 special_tokens_map.json\n", + "-rw-r--r-- 1 root root 1783 Jul 15 17:03 tokenizer_config.json\n", + "-rw-r--r-- 1 root root 2418946 Jul 15 17:02 tokenizer.json\n" + ] + } + ], + "source": [ + "!ls -l {ONNX_MODEL}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bx24WfBQOXM1" + }, + "outputs": [], + "source": [ + "!mkdir {ONNX_MODEL}/assets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lUKXGzFYOXM1" + }, + "source": [ + "- As you can see, we need to move `sentencepiece.bpe.model` from the tokenizer to assets folder which Spark NLP will look for\n", + "- In addition to vocabs, we also need `labels` and their `ids` which is saved inside the model's config. 
We will save this inside `labels.txt`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GT-EioRiOXM1" + }, + "outputs": [], + "source": [ + "# get label2id dictionary\n", + "labels = ort_model.config.id2label\n", + "# sort the dictionary based on the id\n", + "labels = [value for key,value in sorted(labels.items(), reverse=False)]\n", + "\n", + "with open(ONNX_MODEL + '/assets/labels.txt', 'w') as f:\n", + " f.write('\\n'.join(labels))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mYD9y9qDOXM2" + }, + "outputs": [], + "source": [ + "!mv {ONNX_MODEL}/sentencepiece.bpe.model {ONNX_MODEL}/assets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oTyrVcwPOXM2" + }, + "source": [ + "Voila! We have our `sentencepiece.bpe.model` and `labels.txt` inside assets directory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XTzghoZ9OXM2", + "outputId": "ad030c33-e4ff-4c89-b695-6885cf13e268" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "onnx_models/mtheo/camembert-base-xnli:\n", + "total 434792\n", + "-rw-r--r-- 1 root root 22 Jul 15 17:03 added_tokens.json\n", + "drwxr-xr-x 2 root root 4096 Jul 15 17:03 assets\n", + "-rw-r--r-- 1 root root 933 Jul 15 17:02 config.json\n", + "-rw-r--r-- 1 root root 442778957 Jul 15 17:02 model.onnx\n", + "-rw-r--r-- 1 root root 374 Jul 15 17:03 special_tokens_map.json\n", + "-rw-r--r-- 1 root root 1783 Jul 15 17:03 tokenizer_config.json\n", + "-rw-r--r-- 1 root root 2418946 Jul 15 17:02 tokenizer.json\n", + "\n", + "onnx_models/mtheo/camembert-base-xnli/assets:\n", + "total 796\n", + "-rw-r--r-- 1 root root 32 Jul 15 17:03 labels.txt\n", + "-rw-r--r-- 1 root root 810912 Jul 15 17:03 sentencepiece.bpe.model\n" + ] + } + ], + "source": [ + "!ls -lR {ONNX_MODEL}" + ] + }, + { + "cell_type": "markdown", + "metadata": { 
+ "id": "MtZFJ3geOXM2" + }, + "source": [ + "## Import and Save CamemBertForZeroShotClassification in Spark NLP\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tKGcJKB_OXM2" + }, + "source": [ + "- Let's install and setup Spark NLP in Google Colab\n", + "- This part is pretty easy via our simple script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r2-9OsKzOXM2", + "outputId": "d7ba439c-2ff4-4569-ed5e-52b6b4f45e3d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2023-11-03 19:56:31-- http://setup.johnsnowlabs.com/colab.sh\n", + "Resolving setup.johnsnowlabs.com (setup.johnsnowlabs.com)... 51.158.130.125\n", + "Connecting to setup.johnsnowlabs.com (setup.johnsnowlabs.com)|51.158.130.125|:80... connected.\n", + "HTTP request sent, awaiting response... 302 Moved Temporarily\n", + "Location: https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh [following]\n", + "--2023-11-03 19:56:31-- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1191 (1.2K) [text/plain]\n", + "Saving to: β€˜STDOUT’\n", + "\n", + "- 100%[===================>] 1.16K --.-KB/s in 0s \n", + "\n", + "2023-11-03 19:56:31 (92.1 MB/s) - written to stdout [1191/1191]\n", + "\n", + "Installing PySpark 3.2.3 and Spark NLP 5.1.4\n", + "setup Colab for PySpark 3.2.3 and Spark NLP 5.1.4\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.5/281.5 MB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m540.7/540.7 kB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.7/199.7 kB\u001b[0m \u001b[31m22.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "! wget http://setup.johnsnowlabs.com/colab.sh -O - | bash" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uQrimXpqOXM2" + }, + "source": [ + "Let's start Spark with Spark NLP included via our simple `start()` function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Kl2EKK-lOXM3", + "outputId": "418a3ccd-ad76-468d-cd97-17fee4fc22aa" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " self.pid = _posixsubprocess.fork_exec(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Apache Spark version: 3.4.0\n" + ] + } + ], + "source": [ + "import sparknlp\n", + "# let's start Spark with Spark NLP\n", + "spark = sparknlp.start()\n", + "\n", + "print(\"Apache Spark version: {}\".format(spark.version))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cIoqa8jpOXM3" + }, + "source": [ + "- Let's use the `loadSavedModel` function in `CamemBertForZeroShotClassification` which allows us to load the exported ONNX model\n", + "- Most params can be set later when you are loading this model in `CamemBertForZeroShotClassification` at runtime like `setMaxSentenceLength`, so don't worry about what you are setting them now\n", + "- `loadSavedModel` accepts two params, first is the path to the exported ONNX model. The second is the SparkSession that is `spark` variable we previously started via `sparknlp.start()`\n", + "- NOTE: `loadSavedModel` accepts local paths in addition to distributed file systems such as `HDFS`, `S3`, `DBFS`, etc. This feature was introduced in Spark NLP 4.2.2 release. Keep in mind the best and recommended way to move/share/reuse Spark NLP models is to use `write.save` so you can use `.load()` from any file systems natively." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "K6W_LZSLOXM3" + }, + "outputs": [], + "source": [ + "from sparknlp.annotator import *\n", + "\n", + "zero_shot_classifier = CamemBertForZeroShotClassification.loadSavedModel(\n", + " f\"{ONNX_MODEL}\",\n", + " spark\n", + " )\\\n", + " .setInputCols([\"document\",'token'])\\\n", + " .setOutputCol(\"class\")\\\n", + " .setCaseSensitive(True)\\\n", + " .setMaxSentenceLength(128)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "U--84qnxOXM3" + }, + "source": [ + "- Let's save it on disk so it is easier to be moved around and also be used later via `.load` function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FeL-PyNLOXM3" + }, + "outputs": [], + "source": [ + "zero_shot_classifier.write().overwrite().save(\"./{}_spark_nlp_onnx\".format(ONNX_MODEL))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KiX_wtJzOXM3" + }, + "source": [ + "Let's clean up stuff we don't need anymore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-QDg7Y6bOXM3" + }, + "outputs": [], + "source": [ + "!rm -rf {ONNX_MODEL}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Bfk3hT24OXM3" + }, + "source": [ + "Awesome 😎 !\n", + "\n", + "This is your CamemBertForZeroShotClassification model from HuggingFace πŸ€— loaded and saved by Spark NLP πŸš€" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BQgl3FIvOXM4", + "outputId": "99ac2903-b98e-4d7f-96f9-7aa0c7073afb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total 433272\n", + "-rw-r--r-- 1 root root 442846656 Jul 15 17:05 camembert_classification_onnx\n", + "-rw-r--r-- 1 root root 810912 Jul 15 17:05 camembert_spp\n", + "drwxr-xr-x 3 root root 4096 Jul 15 17:05 fields\n", + "drwxr-xr-x 2 root root 
4096 Jul 15 17:05 metadata\n" + ] + } + ], + "source": [ + "! ls -l {ONNX_MODEL}_spark_nlp_onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Fz3MtMcUOXM4" + }, + "source": [ + "Now let's see how we can use it on other machines, clusters, or any place you wish to use your new and shiny CamemBertForZeroShotClassification model 😊" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IEbYe8JfOXM4" + }, + "outputs": [], + "source": [ + "zero_shot_classifier_loaded = CamemBertForZeroShotClassification.load(\"./{}_spark_nlp_onnx\".format(ONNX_MODEL))\\\n", + " .setInputCols([\"document\",'token'])\\\n", + " .setOutputCol(\"class\") \\\n", + " .setCandidateLabels([\"sport\", \"politique\", \"science\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OiGTom7aOXM4" + }, + "source": [ + "You can see what labels were used to train this model via `getClasses` function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "33b5pv9KOXM4", + "outputId": "619fa545-11e9-48b0-c20a-9fc793bbb7d0" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['neutral', 'contradiction', 'entailment']" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# .getClasses was introduced in spark-nlp==3.4.0\n", + "zero_shot_classifier_loaded.getClasses()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i9DBBy2QOXM4" + }, + "source": [ + "This is how you can use your loaded classifier model in Spark NLP πŸš€ pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pojxPjIfOXM4", + "outputId": "00179235-95f2-4ff1-a8c9-835803f7ba6a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------------------+-----------+\n", + "| 
text| result|\n", + "+--------------------+-----------+\n", + "|Alad'2 est claire...|[politique]|\n", + "|Je m'attendais Γ  ...|[politique]|\n", + "+--------------------+-----------+\n", + "\n" + ] + } + ], + "source": [ + "from pyspark.ml import Pipeline\n", + "\n", + "from sparknlp.base import *\n", + "from sparknlp.annotator import *\n", + "\n", + "document_assembler = DocumentAssembler() \\\n", + " .setInputCol('text') \\\n", + " .setOutputCol('document')\n", + "\n", + "tokenizer = Tokenizer() \\\n", + " .setInputCols(['document']) \\\n", + " .setOutputCol('token')\n", + "\n", + "pipeline = Pipeline(stages=[\n", + " document_assembler,\n", + " tokenizer,\n", + " zero_shot_classifier_loaded\n", + "])\n", + "\n", + "# couple of simple examples\n", + "example = spark.createDataFrame([[\"Alad'2 est clairement le meilleur film de l'annΓ©e 2018.\"], [\"Je m'attendais Γ  mieux de la part de Franck Dubosc !\"]]).toDF(\"text\")\n", + "\n", + "result = pipeline.fit(example).transform(example)\n", + "\n", + "# result is a DataFrame\n", + "result.select(\"text\", \"class.result\").show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MlzWRhQyOXM5" + }, + "source": [ + "That's it! 
You can now go wild and use hundreds of `CamemBertForZeroShotClassification` models from HuggingFace πŸ€— in Spark NLP πŸš€\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "L4", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "08b3f87119504658b9a6deea86adce96": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "099c6c294cfe49779b2e46932748cd7e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, 
+ "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0ff93a6722634ec2b466371a2c968f9d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "112531db46384f73af826f020f3df253": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + 
"overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "11593c36fc804d4db38f1b156c2d9456": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bb494ef4a5f046d09a8b4cb52baac036", + "placeholder": "​", + "style": "IPY_MODEL_9f45321331ca42cd85ff84f0a7aa0fe8", + "value": "special_tokens_map.json: 100%" + } + }, + "11b50210f33e4a8997b8c760d8e4ffe0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + 
"top": null, + "visibility": null, + "width": null + } + }, + "11f8cdf8c0c44daeab6303f50c7db818": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "21f2c0ea2f7a441ca8fd88dd4a56f338": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "2269c361542a4ad6840513d5ba598ed5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9412a08772a941acabd75ded6573120d", + "max": 354, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b381561928f44dce9da1d692ae01e9f1", + "value": 354 + } + }, + "2ea468d131be464a8a74f9ec88f9eb4c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "337d30d86ffb4d068bc27a96b464d678": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": 
"1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3672f6662ce844c98f6d574ad37fbfaf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3b2a3da584844cb3b83a193024c9ad40": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_11f8cdf8c0c44daeab6303f50c7db818", + "placeholder": "​", + "style": "IPY_MODEL_acc61b054f424ec5848cea43e6a69672", + "value": "model.safetensors: 100%" + } + }, + "40dc70033f81402db224255a72a90ede": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "424a9faacc9b485aae5c472d68a55d66": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c25961ebc54e48109438e724711c9b33", + "IPY_MODEL_c7c6ed46459341caaebe21e339f2d6a7", + 
"IPY_MODEL_fbb749170f894741a6eb818b19d576c3" + ], + "layout": "IPY_MODEL_112531db46384f73af826f020f3df253" + } + }, + "47a352883ad64b5ab3880c7dc23a1681": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "49d0187f5da14cc99694bc08899e16a9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6a726bdd6135416e91637a034c5623fb", + "placeholder": "​", + "style": "IPY_MODEL_47a352883ad64b5ab3880c7dc23a1681", + "value": " 516/516 [00:00<00:00, 44.6kB/s]" + } + }, + "4b49e41be2b043a79d11aa73110b6258": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + 
"grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "504e7a1e464640f0beca9058a4011e1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "54bd32b592f842d6ab4cbafe7486714b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2ea468d131be464a8a74f9ec88f9eb4c", + "placeholder": "​", + "style": "IPY_MODEL_8b79db60b5494b54964acb5048532d34", + "value": " 354/354 [00:00<00:00, 28.4kB/s]" + } + }, + "55c2a35733684494bdfa1d4fd053711c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3b2a3da584844cb3b83a193024c9ad40", + "IPY_MODEL_d4a3ffb0fb924e8eb9eec47241abc317", + "IPY_MODEL_6de7a0c6f13e4d8bb2c6af5d71be674e" + ], + "layout": "IPY_MODEL_b6402bb4b4884f79891914a56390b3e9" + } + }, + "5b2d19d102144e6f87b71be01c483b68": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5e6f4369c79e4ff8aa5fa080d3658488": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_df594133572b4844a83e93d0c8db40c0", + "IPY_MODEL_8f49629b99a84d5cb310c6530efc1012", + "IPY_MODEL_49d0187f5da14cc99694bc08899e16a9" + ], + "layout": "IPY_MODEL_4b49e41be2b043a79d11aa73110b6258" + } + }, + "5f1c186f31d140aaa29e72299cdde746": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, 
+ "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6a726bdd6135416e91637a034c5623fb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"6de7a0c6f13e4d8bb2c6af5d71be674e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8be26127b62348dab6093cbd885a103a", + "placeholder": "​", + "style": "IPY_MODEL_99b4797f0808425b8dc0419ada9512fa", + "value": " 443M/443M [00:27<00:00, 14.7MB/s]" + } + }, + "7657f8b6f33540d6b49aaf15bb73614a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7a04e2480edf4b9b81322c45c0210e55": { + "model_module": 
"@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8112a8a0c34a474bb0229c91b0ae88e0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8b79db60b5494b54964acb5048532d34": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8be26127b62348dab6093cbd885a103a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8f49629b99a84d5cb310c6530efc1012": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5f1c186f31d140aaa29e72299cdde746", + "max": 516, + "min": 0, + 
"orientation": "horizontal", + "style": "IPY_MODEL_504e7a1e464640f0beca9058a4011e1d", + "value": 516 + } + }, + "9173d5126c794468a40f83a2744808d6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9412a08772a941acabd75ded6573120d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "99b4797f0808425b8dc0419ada9512fa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9cbc3634d1d746b18c09b2ca6fe61a30": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e1e015f4af6b47ec8b59807f35b2db07", + "placeholder": "​", + "style": "IPY_MODEL_5b2d19d102144e6f87b71be01c483b68", + "value": " 929/929 [00:00<00:00, 81.4kB/s]" + } + }, + "9f45321331ca42cd85ff84f0a7aa0fe8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a9c757b180fc4bfc93be39dd8d3f9cdd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, 
+ "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "acc61b054f424ec5848cea43e6a69672": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b381561928f44dce9da1d692ae01e9f1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b6402bb4b4884f79891914a56390b3e9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": 
"1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b9d9b8a937964b85a66e8353c4d44f8d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7657f8b6f33540d6b49aaf15bb73614a", + "max": 929, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f2ca64e852a94f9ba7dc4bf773a695ab", + "value": 929 + } + }, + "bb494ef4a5f046d09a8b4cb52baac036": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c25961ebc54e48109438e724711c9b33": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7a04e2480edf4b9b81322c45c0210e55", + "placeholder": "​", + "style": "IPY_MODEL_9173d5126c794468a40f83a2744808d6", + "value": "tokenizer.json: 100%" + } + }, + "c7c6ed46459341caaebe21e339f2d6a7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", 
+ "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dff293d56f7b484fb54f0c7d2ddd1643", + "max": 2420898, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_08b3f87119504658b9a6deea86adce96", + "value": 2420898 + } + }, + "d2c6d0d6130e49f98360959d38f74d95": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d4a3ffb0fb924e8eb9eec47241abc317": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": 
"", + "description_tooltip": null, + "layout": "IPY_MODEL_337d30d86ffb4d068bc27a96b464d678", + "max": 442525380, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_21f2c0ea2f7a441ca8fd88dd4a56f338", + "value": 442525380 + } + }, + "df01abea998d42eaaa701105782abba0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_11b50210f33e4a8997b8c760d8e4ffe0", + "placeholder": "​", + "style": "IPY_MODEL_0ff93a6722634ec2b466371a2c968f9d", + "value": "config.json: 100%" + } + }, + "df594133572b4844a83e93d0c8db40c0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9c757b180fc4bfc93be39dd8d3f9cdd", + "placeholder": "​", + "style": "IPY_MODEL_3672f6662ce844c98f6d574ad37fbfaf", + "value": "tokenizer_config.json: 100%" + } + }, + "dff293d56f7b484fb54f0c7d2ddd1643": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + 
"align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e1e015f4af6b47ec8b59807f35b2db07": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": 
null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "eeeeff26c6d9466faa41b39c5a4fa669": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_df01abea998d42eaaa701105782abba0", + "IPY_MODEL_b9d9b8a937964b85a66e8353c4d44f8d", + "IPY_MODEL_9cbc3634d1d746b18c09b2ca6fe61a30" + ], + "layout": "IPY_MODEL_d2c6d0d6130e49f98360959d38f74d95" + } + }, + "f2ca64e852a94f9ba7dc4bf773a695ab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fb8ab0c7d7af4fdc9a6f99a06c53b681": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_11593c36fc804d4db38f1b156c2d9456", + "IPY_MODEL_2269c361542a4ad6840513d5ba598ed5", + "IPY_MODEL_54bd32b592f842d6ab4cbafe7486714b" + ], + "layout": "IPY_MODEL_40dc70033f81402db224255a72a90ede" + } + }, + "fbb749170f894741a6eb818b19d576c3": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_099c6c294cfe49779b2e46932748cd7e", + "placeholder": "​", + "style": "IPY_MODEL_8112a8a0c34a474bb0229c91b0ae88e0", + "value": " 2.42M/2.42M [00:01<00:00, 1.23MB/s]" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/python/sparknlp/annotator/classifier_dl/__init__.py b/python/sparknlp/annotator/classifier_dl/__init__.py index ed3e7160892349..b01c166c40999e 100644 --- a/python/sparknlp/annotator/classifier_dl/__init__.py +++ b/python/sparknlp/annotator/classifier_dl/__init__.py @@ -52,4 +52,5 @@ from sparknlp.annotator.classifier_dl.mpnet_for_sequence_classification import * from sparknlp.annotator.classifier_dl.mpnet_for_question_answering import * from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import * -from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import * \ No newline at end of file +from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import * +from sparknlp.annotator.classifier_dl.camembert_for_zero_shot_classification import * diff --git a/python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py b/python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py new file mode 100644 index 00000000000000..7b16c4475e5511 --- /dev/null +++ b/python/sparknlp/annotator/classifier_dl/camembert_for_zero_shot_classification.py @@ -0,0 +1,202 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for CamemBertForZeroShotClassification.""" + +from sparknlp.common import * + + +class CamemBertForZeroShotClassification(AnnotatorModel, + HasCaseSensitiveProperties, + HasBatchedAnnotate, + HasClassifierActivationProperties, + HasCandidateLabelsProperties, + HasEngine, + HasMaxSentenceLengthLimit): + """CamemBertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural language + inference) tasks. Equivalent of `CamemBertForSequenceClassification` models, but these models don't require a hardcoded + number of potential classes, they can be chosen at runtime. It usually means it's slower but it is much more + flexible. + Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis + pair and passed to the pretrained model. + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + >>> sequenceClassifier = CamemBertForZeroShotClassification.pretrained() \\ + ... .setInputCols(["token", "document"]) \\ + ... .setOutputCol("label") + The default model is ``"camembert_zero_shot_classifier_xnli_onnx"``, if no name is + provided. + For available pretrained models please see the `Models Hub + `__. + To see which models are compatible and how to import them see + `Import Transformers into Spark NLP πŸš€ + `_. 
+ ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT, TOKEN`` ``CATEGORY`` + ====================== ====================== + Parameters + ---------- + batchSize + Batch size. Large values allows faster processing but requires more + memory, by default 8 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default + True + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + maxSentenceLength + Max sentence length to process, by default 128 + coalesceSentences + Instead of 1 class per sentence (if inputCols is `sentence`) output 1 + class per document by averaging probabilities in all sentences, by + default False + activation + Whether to calculate logits via Softmax or Sigmoid, by default + `"softmax"`. + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("document") + >>> tokenizer = Tokenizer() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("token") + >>> sequenceClassifier = CamemBertForZeroShotClassification.pretrained() \\ + ... .setInputCols(["token", "document"]) \\ + ... .setOutputCol("multi_class") \\ + ... .setCaseSensitive(True) \\ + ... .setCandidateLabels(["sport", "politique", "science"]) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... tokenizer, + ... sequenceClassifier + ... 
]) + >>> data = spark.createDataFrame([["L'Γ©quipe de France joue aujourd'hui au Parc des Princes"]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.select("multi_class.result").show(truncate=False) + +------+ + |result| + +------+ + |[sport]| + +------+ + """ + name = "CamemBertForZeroShotClassification" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.TOKEN] + + outputAnnotatorType = AnnotatorType.CATEGORY + + configProtoBytes = Param(Params._dummy(), + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()", + TypeConverters.toListInt) + + coalesceSentences = Param(Params._dummy(), "coalesceSentences", + "Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.", + TypeConverters.toBoolean) + + def getClasses(self): + """ + Returns labels used to train this model + """ + return self._call_java("getClasses") + + def setConfigProtoBytes(self, b): + """Sets configProto from tensorflow, serialized into byte array. + + Parameters + ---------- + b : List[int] + ConfigProto from tensorflow, serialized into byte array + """ + return self._set(configProtoBytes=b) + + def setCoalesceSentences(self, value): + """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 + class per document by averaging probabilities in all sentences, by default False. + + Due to max sequence length limit in almost all transformer models such as BERT + (512 tokens), this parameter helps feeding all the sentences into the model and + averaging all the probabilities for the entire document instead of probabilities + per sentence. 
+ + Parameters + ---------- + value : bool + If the output of all sentences will be averaged to one output + """ + return self._set(coalesceSentences=value) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.CamemBertForZeroShotClassification", + java_model=None): + super(CamemBertForZeroShotClassification, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + batchSize=8, + maxSentenceLength=128, + caseSensitive=True, + coalesceSentences=False, + activation="softmax" + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + CamemBertForZeroShotClassification + The restored model + """ + from sparknlp.internal import _CamemBertForZeroShotClassificationLoader + jModel = _CamemBertForZeroShotClassificationLoader(folder, spark_session._jsparkSession)._java_obj + return CamemBertForZeroShotClassification(java_model=jModel) + + @staticmethod + def pretrained(name="camembert_zero_shot_classifier_xnli_onnx", lang="fr", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default + "camembert_zero_shot_classifier_xnli_onnx" + lang : str, optional + Language of the pretrained model, by default "fr" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. 
+ + Returns + ------- + CamemBertForZeroShotClassification + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(CamemBertForZeroShotClassification, name, lang, remote_loc) diff --git a/python/sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py b/python/sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py index 51b04f47d39ee8..6e8364cf8caaf4 100644 --- a/python/sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py +++ b/python/sparknlp/annotator/classifier_dl/deberta_for_zero_shot_classification.py @@ -21,7 +21,8 @@ class DeBertaForZeroShotClassification(AnnotatorModel, HasBatchedAnnotate, HasClassifierActivationProperties, HasCandidateLabelsProperties, - HasEngine): + HasEngine, + HasMaxSentenceLengthLimit): """DeBertaForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural language inference) tasks. Equivalent of `DeBertaForSequenceClassification` models, but these models don't require a hardcoded number of potential classes, they can be chosen at runtime. It usually means it's slower but it is much more @@ -101,11 +102,6 @@ class per document by averaging probabilities in all sentences, by outputAnnotatorType = AnnotatorType.CATEGORY - maxSentenceLength = Param(Params._dummy(), - "maxSentenceLength", - "Max sentence length to process", - typeConverter=TypeConverters.toInt) - configProtoBytes = Param(Params._dummy(), "configProtoBytes", "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()", @@ -130,15 +126,6 @@ def setConfigProtoBytes(self, b): """ return self._set(configProtoBytes=b) - def setMaxSentenceLength(self, value): - """Sets max sentence length to process, by default 128. 
- Parameters - ---------- - value : int - Max sentence length to process - """ - return self._set(maxSentenceLength=value) - def setCoalesceSentences(self, value): """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences. Due to max sequence length limit in almost all transformer models such as DeBerta diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index adf19667279fbe..c8732ef3ecb4e5 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -860,6 +860,13 @@ def __init__(self, path, jspark): jspark, ) +class _CamemBertForZeroShotClassificationLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_CamemBertForZeroShotClassificationLoader, self).__init__( + "com.johnsnowlabs.nlp.annotators.classifier.dl.CamemBertForZeroShotClassification.loadSavedModel", + path, + jspark, + ) class _RobertaQAToZeroShotNerLoader(ExtendedJavaWrapper): def __init__(self, path): diff --git a/python/test/annotator/classifier_dl/camembert_for_zero_shot_classification_test.py b/python/test/annotator/classifier_dl/camembert_for_zero_shot_classification_test.py new file mode 100644 index 00000000000000..57f1bc8256f2ce --- /dev/null +++ b/python/test/annotator/classifier_dl/camembert_for_zero_shot_classification_test.py @@ -0,0 +1,57 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.annotator.common.has_max_sentence_length_test import HasMaxSentenceLengthTests +from test.util import SparkContextForTest + + +@pytest.mark.slow +class CamemBertForZeroShotClassificationTestSpec(unittest.TestCase, HasMaxSentenceLengthTests): + def setUp(self): + self.text = "L'Γ©quipe de France joue aujourd'hui au Parc des Princes" + self.data = SparkContextForTest.spark \ + .createDataFrame([[self.text]]).toDF("text") + + self.tested_annotator = CamemBertForZeroShotClassification \ + .pretrained() \ + .setInputCols(["document", "token"]) \ + .setOutputCol("class") + + def test_run(self): + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("document") + + tokenizer = Tokenizer().setInputCols("document").setOutputCol("token") + + doc_classifier = self.tested_annotator + + pipeline = Pipeline(stages=[ + document_assembler, + tokenizer, + doc_classifier + ]) + + model = pipeline.fit(self.data) + model.transform(self.data).show() + + light_pipeline = LightPipeline(model) + annotations_result = light_pipeline.fullAnnotate(self.text) + print(annotations_result) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala index aa2eac4270f4c9..3b2d7ef2c7663d 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala @@ -16,15 +16,16 @@ package com.johnsnowlabs.ml.ai -import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.{OnnxTensor, OrtEnvironment} import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import 
com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.BasicTokenizer import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} -import org.tensorflow.ndarray.buffer.LongDataBuffer +import org.tensorflow.ndarray.buffer.{IntDataBuffer, LongDataBuffer} import org.slf4j.{Logger, LoggerFactory} import scala.collection.JavaConverters._ @@ -97,7 +98,19 @@ private[johnsnowlabs] class CamemBertClassification( def tokenizeSeqString( candidateLabels: Seq[String], maxSeqLength: Int, - caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = ??? + caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = { + val basicTokenizer = new BasicTokenizer(caseSensitive) + val encoder = + new SentencepieceEncoder(spp, caseSensitive, sentencePieceDelimiterId, pieceIdOffset = 1) + + val labelsToSentences = candidateLabels.map { s => Sentence(s, 0, s.length - 1, 0) } + + labelsToSentences.map(label => { + val tokens = basicTokenizer.tokenize(label) + val wordpieceTokens = tokens.flatMap(token => encoder.encode(token)).take(maxSeqLength) + WordpieceTokenizedSentence(wordpieceTokens) + }) + } def tokenizeDocument( docs: Seq[Annotation], @@ -142,6 +155,7 @@ private[johnsnowlabs] class CamemBertClassification( private def getRawScoresWithTF(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = { val tensors = new TensorResources() +// val (tokenBuffers, maskBuffers) = initializeTFLongTensorResources(batch, tensors, maxSentenceLength) val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val batchLength = batch.length @@ -192,17 +206,8 @@ private[johnsnowlabs] class CamemBertClassification( } private def getRawScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = { - - // [nb of encoded 
sentences , maxSentenceLength] val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) - - val tokenTensors = - OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) - val maskTensors = - OnnxTensor.createTensor( - env, - batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) - + val (tokenTensors, maskTensors) = initializeOnnxTensorResources(batch, env) val inputs = Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava @@ -260,13 +265,114 @@ private[johnsnowlabs] class CamemBertClassification( batch: Seq[Array[Int]], entailmentId: Int, contradictionId: Int, - activation: String): Array[Array[Float]] = ??? + activation: String): Array[Array[Float]] = { + + val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max + val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength)) + val batchLength = paddedBatch.length + + val rawScores = detectedEngine match { + case TensorFlow.name => computeZeroShotLogitsWithTF(paddedBatch, maxSentenceLength) + case ONNX.name => computeZeroShotLogitsWithONNX(paddedBatch) + } + + val dim = rawScores.length / batchLength + rawScores + .grouped(dim) + .toArray + } + + def computeZeroShotLogitsWithONNX(batch: Seq[Array[Int]]): Array[Float] = { + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + val (tokenTensors, maskTensors) = initializeOnnxTensorResources(batch, env) + val inputs = + Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava + + try { + val results = runner.run(inputs) + try { + val embeddings = results + .get("logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + + embeddings + } finally if (results != null) results.close() + } + } + + def computeZeroShotLogitsWithTF( + batch: Seq[Array[Int]], + maxSentenceLength: Int): Array[Float] = { + val tensors = new TensorResources() + val (tokenBuffers, maskBuffers, 
segmentBuffers) = + initializeTFIntTensorResources(batch, tensors, maxSentenceLength) + // [nb of encoded sentences , maxSentenceLength] + val shape = Array(batch.length.toLong, maxSentenceLength) + + batch.zipWithIndex + .foreach { case (sentence, idx) => + val offset = idx * maxSentenceLength + tokenBuffers.offset(offset).write(sentence) + maskBuffers + .offset(offset) + .write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1)) + segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0)) + } + + val runner = tensorflowWrapper.get + .getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false) + .runner + + val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers) + val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers) + val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers) + + runner + .feed( + _tfCamemBertSignatures.getOrElse( + ModelSignatureConstants.InputIds.key, + "missing_input_id_key"), + tokenTensors) + .feed( + _tfCamemBertSignatures + .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), + maskTensors) + .feed( + _tfCamemBertSignatures + .getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"), + segmentTensors) + .fetch(_tfCamemBertSignatures + .getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key")) + + val outs = runner.run().asScala + val rawScores = TensorResources.extractFloats(outs.head) + + outs.foreach(_.close()) + tensors.clearSession(outs) + tensors.clearTensors() + + rawScores + } + + private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = { + if (arr.length >= maxLength) { + arr + } else { + arr ++ Array.fill(maxLength - arr.length)(0) + } + } def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { val batchLength = batch.length + val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val (startLogits, 
endLogits) = detectedEngine match { case ONNX.name => computeLogitsWithOnnx(batch) - case _ => computeLogitsWithTF(batch) + case TensorFlow.name => computeLogitsWithTF(batch, maxSentenceLength) } val endDim = endLogits.length / batchLength @@ -280,14 +386,12 @@ private[johnsnowlabs] class CamemBertClassification( (startScores, endScores) } - private def computeLogitsWithTF(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = { + private def computeLogitsWithTF( + batch: Seq[Array[Int]], + maxSentenceLength: Int): (Array[Float], Array[Float]) = { val tensors = new TensorResources() - - val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max - val batchLength = batch.length - - val tokenBuffers: LongDataBuffer = tensors.createLongBuffer(batchLength * maxSentenceLength) - val maskBuffers: LongDataBuffer = tensors.createLongBuffer(batchLength * maxSentenceLength) + val (tokenBuffers, maskBuffers) = + initializeTFLongTensorResources(batch, tensors, maxSentenceLength) // [nb of encoded sentences , maxSentenceLength] val shape = Array(batch.length.toLong, maxSentenceLength) @@ -335,17 +439,32 @@ private[johnsnowlabs] class CamemBertClassification( (startLogits, endLogits) } - private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = { - // [nb of encoded sentences] - val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + private def initializeTFLongTensorResources( + batch: Seq[Array[Int]], + tensors: TensorResources, + maxSentenceLength: Int): (LongDataBuffer, LongDataBuffer) = { + val batchLength = batch.length + val dim = batchLength * maxSentenceLength + val tokenBuffers: LongDataBuffer = tensors.createLongBuffer(dim) + val maskBuffers: LongDataBuffer = tensors.createLongBuffer(dim) + (tokenBuffers, maskBuffers) + } - val tokenTensors = - OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) - val maskTensors = - OnnxTensor.createTensor( - env, - batch.map(sentence => 
sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + private def initializeTFIntTensorResources( + batch: Seq[Array[Int]], + tensors: TensorResources, + maxSentenceLength: Int): (IntDataBuffer, IntDataBuffer, IntDataBuffer) = { + val batchLength = batch.length + val dim = batchLength * maxSentenceLength + val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(dim) + val maskBuffers: IntDataBuffer = tensors.createIntBuffer(dim) + val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(dim) + (tokenBuffers, maskBuffers, segmentBuffers) + } + private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = { + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + val (tokenTensors, maskTensors) = initializeOnnxTensorResources(batch, env) val inputs = Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava @@ -380,6 +499,17 @@ private[johnsnowlabs] class CamemBertClassification( } } + private def initializeOnnxTensorResources(batch: Seq[Array[Int]], env: OrtEnvironment) = { + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + (tokenTensors, maskTensors) + } + def findIndexedToken( tokenizedSentences: Seq[TokenizedSentence], sentence: (WordpieceTokenizedSentence, Int), diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala index 1b46ec8330bc48..36cd023551e263 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala @@ -552,6 +552,9 @@ package object annotator { object CamemBertEmbeddings extends ReadablePretrainedCamemBertModel with ReadCamemBertDLModel + type CamemBertForZeroShotClassification = + com.johnsnowlabs.nlp.annotators.classifier.dl.CamemBertForZeroShotClassification + + type 
SpanBertCorefModel = com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel object SpanBertCorefModel diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassification.scala new file mode 100644 index 00000000000000..4a5bcde0e87ef1 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassification.scala @@ -0,0 +1,415 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.johnsnowlabs.nlp.annotators.classifier.dl + +import com.johnsnowlabs.ml.ai.CamemBertClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.tensorflow.{ + ReadTensorflowModel, + TensorflowWrapper, + WriteTensorflowModel +} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadSentencePieceAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} +import com.johnsnowlabs.nlp.annotators.common.{SentenceSplit, TokenizedWithSentence} +import com.johnsnowlabs.nlp.serialization.MapFeature +import com.johnsnowlabs.nlp.{ + Annotation, + AnnotatorModel, + AnnotatorType, + HasBatchedAnnotate, + HasCandidateLabelsProperties, + HasCaseSensitiveProperties, + HasClassifierActivationProperties, + HasEngine, + HasPretrained, + ParamsAndFeaturesReadable +} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession + +class CamemBertForZeroShotClassification(override val uid: String) + extends AnnotatorModel[CamemBertForZeroShotClassification] + with HasBatchedAnnotate[CamemBertForZeroShotClassification] + with WriteTensorflowModel + with WriteOnnxModel + with WriteSentencePieceModel + with HasCaseSensitiveProperties + with HasClassifierActivationProperties + with HasEngine + with HasCandidateLabelsProperties { + + /** Annotator reference id. 
Used to identify elements in metadata or to refer to this annotator + * type + */ + + def this() = this(Identifiable.randomUID("CamemBertForZeroShotClassification")) + + /** Input Annotator Types: DOCUMENT, TOKEN + * + * @group anno + */ + override val inputAnnotatorTypes: Array[String] = + Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN) + + /** Output Annotator Types: CATEGORY + * + * @group anno + */ + override val outputAnnotatorType: AnnotatorType = AnnotatorType.CATEGORY + + /** Labels used to decode predicted IDs back to string tags + * + * @group param + */ + val labels: MapFeature[String, Int] = new MapFeature(this, "labels").setProtected() + + /** @group setParam */ + def setLabels(value: Map[String, Int]): this.type = set(labels, value) + + /** Returns labels used to train this model */ + def getClasses: Array[String] = { + $$(labels).keys.toArray + } + + /** Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document + * by averaging probabilities in all sentences (Default: `false`). + * + * Due to max sequence length limit in almost all transformer models such as CamemBERT (512 + * tokens), this parameter helps feeding all the sentences into the model and averaging all the + * probabilities for the entire document instead of probabilities per sentence. + * + * @group param + */ + val coalesceSentences = new BooleanParam( + this, + "coalesceSentences", + "If set to true the output of all sentences will be averaged to one output instead of one output per sentence. Defaults to false.") + + /** @group setParam */ + def setCoalesceSentences(value: Boolean): this.type = set(coalesceSentences, value) + + /** @group getParam */ + def getCoalesceSentences: Boolean = $(coalesceSentences) + + /** ConfigProto from tensorflow, serialized into byte array. 
Get with + * `config_proto.SerializeToString()` + * + * @group param + */ + val configProtoBytes = new IntArrayParam( + this, + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()") + + /** @group setParam */ + def setConfigProtoBytes(bytes: Array[Int]): CamemBertForZeroShotClassification.this.type = + set(this.configProtoBytes, bytes) + + /** @group getParam */ + def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte)) + + /** Max sentence length to process (Default: `128`) + * + * @group param + */ + val maxSentenceLength = + new IntParam(this, "maxSentenceLength", "Max sentence length to process") + + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = { + require( + value <= 512, + "CamemBERT models do not support sequences longer than 512 because of trainable positional embeddings.") + require(value >= 1, "The maxSentenceLength must be at least 1") + set(maxSentenceLength, value) + this + } + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + + /** It contains TF model signatures for the loaded saved model + * + * @group param + */ + val signatures = + new MapFeature[String, String](model = this, name = "signatures").setProtected() + + /** @group setParam */ + def setSignatures(value: Map[String, String]): this.type = { + set(signatures, value) + this + } + + /** @group getParam */ + def getSignatures: Option[Map[String, String]] = get(this.signatures) + + private var _model: Option[Broadcast[CamemBertClassification]] = None + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], + spp: SentencePieceWrapper): CamemBertForZeroShotClassification = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new CamemBertClassification( + tensorflowWrapper, + onnxWrapper, + spp, + 
configProtoBytes = None, + tags = $$(labels), + signatures = getSignatures, + threshold = $(threshold)))) + } + + this + } + + /** Whether to lowercase tokens or not (Default: `true`). + * + * @group setParam + */ + override def setCaseSensitive(value: Boolean): this.type = { + set(this.caseSensitive, value) + } + + /** @group getParam */ + def getModelIfNotSet: CamemBertClassification = _model.get.value + + setDefault( + batchSize -> 8, + maxSentenceLength -> 128, + caseSensitive -> true, + coalesceSentences -> false) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations in batches that correspond to inputAnnotationCols generated by previous + * annotators if any + * @return + * any number of annotations processed for every batch of input annotations. Not necessary + * one to one relationship + * + * IMPORTANT: !MUST! return sequences of equal lengths !! IMPORTANT: !MUST! return sentences + * that belong to the same original row !! 
(challenging) + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + batchedAnnotations.map(annotations => { + val sentences = SentenceSplit.unpack(annotations).toArray + val tokenizedSentences = TokenizedWithSentence.unpack(annotations).toArray + + if (tokenizedSentences.nonEmpty) { + getModelIfNotSet.predictSequenceWithZeroShot( + tokenizedSentences, + sentences, + $(candidateLabels), + $(entailmentIdParam), + $(contradictionIdParam), + $(batchSize), + $(maxSentenceLength), + $(caseSensitive), + $(coalesceSentences), + $$(labels), + getActivation) + + } else { + Seq.empty[Annotation] + } + }) + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + val suffix = "_camembert_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + CamemBertForSequenceClassification.tfFile) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + CamemBertForSequenceClassification.onnxFile) + } + + writeSentencePieceModel( + path, + spark, + getModelIfNotSet.spp, + "_camembert", + CamemBertForSequenceClassification.sppFile) + + } +} + +trait ReadPretrainedCamemBertForZeroShotClassification + extends ParamsAndFeaturesReadable[CamemBertForZeroShotClassification] + with HasPretrained[CamemBertForZeroShotClassification] { + override val defaultModelName: Some[String] = Some("camembert_zero_shot_classifier_xnli_onnx") + override val defaultLang: String = "fr" + + override def pretrained(): CamemBertForZeroShotClassification = super.pretrained() + + override def pretrained(name: String): CamemBertForZeroShotClassification = + super.pretrained(name) + + override def pretrained(name: String, lang: String): CamemBertForZeroShotClassification = + super.pretrained(name, lang) + + override def pretrained( + name: String, + lang: String, + remoteLoc: 
String): CamemBertForZeroShotClassification = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadCamemBertForZeroShotClassification + extends ReadTensorflowModel + with ReadOnnxModel + with ReadSentencePieceModel { + this: ParamsAndFeaturesReadable[CamemBertForZeroShotClassification] => + + override val tfFile: String = "camembert_classification_tensorflow" + override val onnxFile: String = "camembert_classification_onnx" + override val sppFile: String = "camembert_spp" + + def readModel( + instance: CamemBertForZeroShotClassification, + path: String, + spark: SparkSession): Unit = { + + val spp = readSentencePieceModel(path, spark, "_camembert_spp", sppFile) + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_camembert_classification_tf") + instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "camembert_zero_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new Exception(notSupportedEngineError) + } + + } + + addReader(readModel) + + def loadSavedModel( + modelPath: String, + spark: SparkSession): CamemBertForZeroShotClassification = { + + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) + + val spModel = loadSentencePieceAsset(localModelPath, "sentencepiece.bpe.model") + val labels = loadTextAsset(localModelPath, "labels.txt").zipWithIndex.toMap + + val entailmentIds = labels.filter(x => x._1.toLowerCase().startsWith("entail")).values.toArray + val contradictionIds = + labels.filter(x => x._1.toLowerCase().startsWith("contradict")).values.toArray + + require( + entailmentIds.length == 1 && contradictionIds.length == 1, + s"""This annotator supports classifiers trained on NLI datasets. 
You must have at least 2 and at most 3 labels in your dataset: + + example with 3 labels: 'contradict', 'neutral', 'entailment' + example with 2 labels: 'contradict', 'entailment' + + You can modify assets/labels.txt file to match the above format. + + Current labels: ${labels.keys.mkString(", ")} + """) + + val annotatorModel = new CamemBertForZeroShotClassification() + .setLabels(labels) + .setCandidateLabels(labels.keys.toArray) + + /* set the entailment id */ + annotatorModel.set(annotatorModel.entailmentIdParam, entailmentIds.head) + /* set the contradiction id */ + annotatorModel.set(annotatorModel.contradictionIdParam, contradictionIds.head) + /* set the engine */ + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case TensorFlow.name => + val (wrapper, signatures) = + TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) + + val _signatures = signatures match { + case Some(s) => s + case None => throw new Exception("Cannot load signature definitions from model!") + } + + /** the order of setSignatures is important if we use getSignatures inside + * setModelIfNotSet + */ + annotatorModel + .setSignatures(_signatures) + .setModelIfNotSet(spark, Some(wrapper), None, spModel) + case ONNX.name => + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) + annotatorModel.setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } + +} + +/** This is the companion object of [[CamemBertForZeroShotClassification]]. Please refer to that + * class for the documentation. 
+ */ + +object CamemBertForZeroShotClassification + extends ReadPretrainedCamemBertForZeroShotClassification + with ReadCamemBertForZeroShotClassification diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index 8ed41de985baa9..c4be567a26683b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -687,7 +687,8 @@ object PythonResourceDownloader { "AutoGGUFModel" -> AutoGGUFModel, "AlbertForZeroShotClassification" -> AlbertForZeroShotClassification, "MxbaiEmbeddings" -> MxbaiEmbeddings, - "SnowFlakeEmbeddings" -> SnowFlakeEmbeddings + "SnowFlakeEmbeddings" -> SnowFlakeEmbeddings, + "CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification ) // List pairs of types such as the one with key type can load a pretrained model from the value type diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala index 4f53b37db985c2..992708e86c0992 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala @@ -35,7 +35,8 @@ case class ResourceMetadata( isZipped: Boolean = false, category: Option[String] = Some(ResourceType.NOT_DEFINED.toString), checksum: String = "", - annotator: Option[String] = None) + annotator: Option[String] = None, + engine: Option[String] = None) extends Ordered[ResourceMetadata] { lazy val key: String = { diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassificationTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassificationTestSpec.scala new file mode 100644 index 00000000000000..af9f114920a1ec --- /dev/null +++ 
b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForZeroShotClassificationTestSpec.scala @@ -0,0 +1,99 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.classifier.dl + +import com.johnsnowlabs.nlp.annotators.Tokenizer +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.SlowTest +import org.apache.spark.ml.Pipeline +import org.apache.spark.sql.functions.explode +import org.scalatest.flatspec.AnyFlatSpec + +class CamemBertForZeroShotClassificationTestSpec extends AnyFlatSpec { + + "CamemBertForZeroShotClassification" should "correctly load custom ONNX model" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + + val dataDf = Seq("L'Γ©quipe de France joue aujourd'hui au Parc des Princes").toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val tokenizer = new Tokenizer() + .setInputCols(Array("document")) + .setOutputCol("token") + + val zeroShotClassifier = CamemBertForZeroShotClassification + .pretrained() + .setOutputCol("multi_class") + .setCaseSensitive(true) + .setCoalesceSentences(true) + .setCandidateLabels(Array("sport", "politique", "science")) + + val pipeline = new Pipeline().setStages(Array(document, tokenizer, zeroShotClassifier)) + + val pipelineModel = pipeline.fit(dataDf) + val 
pipelineDF = pipelineModel.transform(dataDf) + + pipelineDF.select("multi_class").show(false) + val totalDocs = pipelineDF.select(explode($"document.result")).count.toInt + val totalLabels = pipelineDF.select(explode($"multi_class.result")).count.toInt + + println(s"total tokens: $totalDocs") + println(s"total labels: $totalLabels") + + assert(totalDocs == totalLabels) + } + + it should "correctly load custom Tensorflow model" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + + val dataDf = Seq("L'Γ©quipe de France joue aujourd'hui au Parc des Princes").toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val tokenizer = new Tokenizer() + .setInputCols(Array("document")) + .setOutputCol("token") + + val zeroShotClassifier = CamemBertForZeroShotClassification + .pretrained("camembert-zero-shot-classifier-xnli-tf") + .setOutputCol("multi_class") + .setCaseSensitive(true) + .setCoalesceSentences(true) + .setCandidateLabels(Array("sport", "politique", "science")) + + val pipeline = new Pipeline().setStages(Array(document, tokenizer, zeroShotClassifier)) + + val pipelineModel = pipeline.fit(dataDf) + val pipelineDF = pipelineModel.transform(dataDf) + + pipelineDF.select("multi_class").show(false) + val totalDocs = pipelineDF.select(explode($"document.result")).count.toInt + val totalLabels = pipelineDF.select(explode($"multi_class.result")).count.toInt + + println(s"total tokens: $totalDocs") + println(s"total labels: $totalLabels") + + assert(totalDocs == totalLabels) + } + +}