diff --git a/examples/python/transformers/HuggingFace_in_Spark_NLP_AlbertForZeroShotClassification.ipynb b/examples/python/transformers/HuggingFace_in_Spark_NLP_AlbertForZeroShotClassification.ipynb new file mode 100644 index 00000000000000..39296e1081a74e --- /dev/null +++ b/examples/python/transformers/HuggingFace_in_Spark_NLP_AlbertForZeroShotClassification.ipynb @@ -0,0 +1,3053 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "jyP44g5ieZsq" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace_in_Spark_NLP_AlbertForZeroShotClassification.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "W8r6iDKceZs7" + }, + "source": [ + "## Import AlbertForZeroShotClassification models from HuggingFace 🤗 into Spark NLP 🚀\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- This feature is only in `Spark NLP 5.4.2` and after. So please make sure you have upgraded to the latest Spark NLP release\n", + "- You can import ALBERT models trained/fine-tuned for sequence classification via `AlbertForSequenceClassification` or `TFAlbertForSequenceClassification`. 
These models are usually under `Text Classification` category and have `albert` in their labels\n", + "- Reference: [TFAlbertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/albert#transformers.TFAlbertForSequenceClassification)\n", + "- Some [example models](https://huggingface.co/models?filter=albert&pipeline_tag=text-classification)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oh2BmRddeZtA" + }, + "source": [ + "## Export and Save HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I3OaTwVYeZtC" + }, + "source": [ + "- Let's install `HuggingFace` and `TensorFlow`. You don't need `TensorFlow` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock TensorFlow on `2.11.0` version and Transformers on `4.25.1`. This doesn't mean it won't work with the future releases, but we wanted you to know which versions have been tested successfully.\n", + "- Albert uses SentencePiece, so we will have to install that as well" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PFN7a31JeZtE", + "outputId": "2d8b80dc-91af-4dd1-fe9d-29751b994458" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.8/8.8 MB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m588.3/588.3 MB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m71.0 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m56.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m116.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m439.2/439.2 kB\u001b[0m \u001b[31m31.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m97.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m102.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m781.3/781.3 kB\u001b[0m \u001b[31m45.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "cudf-cu12 24.4.1 requires protobuf<5,>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-bigquery-connection 1.15.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-bigtable 2.25.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-functions 1.16.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-iam 2.15.2 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-language 2.13.4 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-pubsub 2.23.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-resource-manager 1.12.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-translate 3.15.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "googleapis-common-protos 1.63.2 requires protobuf!=3.20.0,!=3.20.1,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0.dev0,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "grpc-google-iam-v1 0.13.1 requires protobuf!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "pandas-gbq 
0.19.2 requires google-auth-oauthlib>=0.7.0, but you have google-auth-oauthlib 0.4.6 which is incompatible.\n", + "tensorflow-datasets 4.9.6 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "tensorflow-metadata 1.15.0 requires protobuf<4.21,>=3.20.3; python_version < \"3.11\", but you have protobuf 3.19.6 which is incompatible.\n", + "tf-keras 2.17.0 requires tensorflow<2.18,>=2.17, but you have tensorflow 2.11.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q transformers==4.39.3 tensorflow==2.11.0 sentencepiece" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AKsQCgipeZtI" + }, + "source": [ + "- HuggingFace comes with a native `saved_model` feature inside `save_pretrained` function for TensorFlow based models. We will use that to save it as TF `SavedModel`.\n", + "- We'll use [DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT) model from HuggingFace as an example\n", + "- In addition to `TFAlbertForSequenceClassification` we also need to save the `AlbertTokenizer`. This is the same for every model, these are assets needed for tokenization inside Spark NLP."
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 379, + "referenced_widgets": [ + "87015e2373034217a1b8ac76441caf49", + "411fd89d1a31427f9bf9e35b38c2f420", + "3f5840d43cc54c90ad334c1c28116e29", + "4563668896c84756b3eadd789dc759d7", + "0920e6780a0d4a1bac328451ab722dc3", + "ceb2224ee447481d87cfea97517fbca6", + "a5ed208cf9854ed5bdb6f3b36168456e", + "a15f448686e945faa99f2316dffe177c", + "2c1158f9725141df9c996da47ae030d0", + "cdcf470194a14f8f9149146c4dcb567c", + "243d96a25a5447faa1f145c7af93ae22", + "db12325d7df341c9a5d8c52b9486b6c0", + "bcb8fa3afe3c498090af08ec31d3bf49", + "9fe306b8d7014a45b5a9ddccd5c09380", + "1483578c50ad492a8b5f57927c994410", + "8f6274003b304dc2a15ecb3ebe3961b7", + "803a8344b2344a7691851f117dd4ac06", + "225bb44c41d24f3ebf6e27952a5fd61b", + "dd975e4a7cb34eca8a97b3e02f502aa5", + "d45e2ac8c5494450abeb8b5d0a6a9297", + "cba34126485e4bd984347ffe9cf0f387", + "35434d241da946f9b48c584eeaa1e61c" + ] + }, + "id": "7n9_7HNoeZtL", + "outputId": "aef3c068-340c-436a-88cb-cc2563c34a29" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "try downloading TF weights\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "87015e2373034217a1b8ac76441caf49", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/1.77k [00:00=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n", + "gcsfs 2024.6.1 requires fsspec==2024.6.1, but you have fsspec 2024.5.0 which is incompatible.\n", + "google-cloud-bigquery-connection 1.15.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-bigtable 2.25.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-functions 1.16.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-iam 2.15.2 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-language 2.13.4 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, 
but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-pubsub 2.23.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-resource-manager 1.12.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "google-cloud-translate 3.15.5 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "googleapis-common-protos 1.63.2 requires protobuf!=3.20.0,!=3.20.1,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0.dev0,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "grpc-google-iam-v1 0.13.1 requires protobuf!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.\n", + "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\n", + "pandas-gbq 0.19.2 requires google-auth-oauthlib>=0.7.0, but you have google-auth-oauthlib 0.4.6 which is incompatible.\n", + "tensorflow-datasets 4.9.6 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "tensorflow-metadata 1.15.0 requires protobuf<4.21,>=3.20.3; python_version < \"3.11\", but you have protobuf 3.19.6 which is incompatible.\n", + "tf-keras 2.17.0 requires tensorflow<2.18,>=2.17, but you have tensorflow 2.11.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q --upgrade transformers[onnx] optimum sentencepiece tensorflow==2.11.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KwJwuD8zOXM0" + }, + "source": [ + "- HuggingFace has an extension called Optimum which offers specialized model inference, including ONNX. 
We can use this to import and export ONNX models with `from_pretrained` and `save_pretrained`.\n", + "- We'll use [DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT](https://huggingface.co/DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT) model from HuggingFace as an example and load it as a `ORTModelForSequenceClassification`, representing an ONNX model." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 405, + "referenced_widgets": [ + "7fbb3e3d9c3b48428a2cc92e0f36eb47", + "3d1a3ca0fc3a486786276252536ac558", + "9813d17825064a5181ba0f9dceaceb9f", + "16589466aa89479abeaa4710fde6a73a", + "544511d40b1143df84c7b157ddc91c41", + "30769a6d996c4bcca7b30eed1503eed9", + "869468fe84414b91962ace9341b2f1a9", + "3272a0508cb34aea83948f9c23c9650a", + "2c844aa44d2b4055a714307f2080022c", + "a2d080b727634940875e96a9147f34e7", + "12da96ef71224fdab90d4b3fd6a28e42", + "5fe97bffe2744b4cbc60ce550b2503a1", + "06b5057f15a14efcbb03a4018ccc3da2", + "838508bbc957403ab7074d8210e4e6ed", + "171828a82e744618bae44d7f32eae31d", + "b2d5f2b331e54fb18428ae9682f653ca", + "9731c969cf3346629fe8b7dbda556e43", + "9140cbda612a4496ae820029b8646d67", + "13bab1fdb3c54fe1b0895220df19e829", + "10f3d7dbabcc4c8098a8c055bf136190", + "1852726ae1184e6f9f5dd6bb1d1fce2c", + "c959f6d9fc344f338241f75fe6e18c06", + "ce64ac856dc6472385eccab5342738c1", + "978eab94b9864aabb6d3f83113f85940", + "47e85899badd419ebc15a6751ce5bd43", + "edf118b9dbc94f29b090733b70c61cd4", + "2a02b4b49cab4f42b416e6831df6ea7e", + "756282ecf7494603ae07ade2d30abcd0", + "ee064242089e4a90b15d2eb4add1f145", + "e5d1a59fb70242b183ec32dddc497212", + "c78cd69d67f94d1796626c926a626448", + "61817859027e4613931ce32d1e3e494e", + "feafc3adf2d64b9d8da9b8cf1618ac9c", + "4d2a53af0aaa4501b35f7fec8ff0b273", + "d06e3fa91b214eb19ffe1071109302d3", + "ffa7c44d4ca4465bbc1c030aaff0b6fa", + "73e798ee33eb42888f2d1f0ea4e4b268", + "3db7f93e0cad41af98fd790a03304829", + 
"e6bef63825424e4ebb3350b7de6a99e0", + "c191db451caa433ea8547ab3affd76f1", + "1725349d5c2940c5a74fc8d4676d84e2", + "91e93912c6704288b1ead18fcb9aee9e", + "ca52d3613af14540828687ba3d9748ab", + "4d04cd7024a341268f759d2790836a0d", + "6e38416e898f429484b8399994a90bae", + "c921fb0d282141a3964b0f8d5c91dc4b", + "f737165ed69c4b199c9fac9bdd18faa8", + "61fbc6373e374ceca4889cfbc91cf5e1", + "0a81dfc4f5b64173b1577ece9bf2c2f7", + "b2cb083a419e4c34b64ad545d638a11b", + "af5ec264af794e9288c62127b8bb52d7", + "d95ec2867f90451195d80eb731dc651f", + "c89bdfb092574fb5b003adaf9bffe1c0", + "5019b11b53bf425190337f008132d56e", + "6d39b93d1da246a78770497eb9403b6a" + ] + }, + "id": "xm2lRUOxOXM0", + "outputId": "6c8c4b8d-1092-4842-a548-2405cf48def9" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7fbb3e3d9c3b48428a2cc92e0f36eb47", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/1.77k [00:00>> sequenceClassifier = AlbertForZeroShotClassification.pretrained() \\ + ... .setInputCols(["token", "document"]) \\ + ... .setOutputCol("label") + + The default model is ``"albert_base_zero_shot_classifier_onnx"``, if no name is + provided. + + For available pretrained models please see the `Models Hub + `__. 
+ + To see which models are compatible and how to import them see + `Import Transformers into Spark NLP 🚀 + `_. + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT, TOKEN`` ``CATEGORY`` + ====================== ====================== + + Parameters + ---------- + batchSize + Batch size. Large values allows faster processing but requires more + memory, by default 8 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default + True + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + maxSentenceLength + Max sentence length to process, by default 128 + coalesceSentences + Instead of 1 class per sentence (if inputCols is `sentence`) output 1 + class per document by averaging probabilities in all sentences, by + default False + activation + Whether to calculate logits via Softmax or Sigmoid, by default + `"softmax"`. + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("document") + >>> tokenizer = Tokenizer() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("token") + >>> sequenceClassifier = AlbertForZeroShotClassification.pretrained() \\ + ... .setInputCols(["token", "document"]) \\ + ... .setOutputCol("label") \\ + ... .setCaseSensitive(True) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... tokenizer, + ... sequenceClassifier + ... 
]) + >>> data = spark.createDataFrame([["I have a problem with my iphone that needs to be resolved asap!!"]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.select("label.result").show(truncate=False) + +---------+ + |result | + +---------+ + |[urgent] | + +---------+ + """ + name = "AlbertForZeroShotClassification" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.TOKEN] + + outputAnnotatorType = AnnotatorType.CATEGORY + + configProtoBytes = Param(Params._dummy(), + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()", + TypeConverters.toListInt) + + coalesceSentences = Param(Params._dummy(), "coalesceSentences", + "Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.", + TypeConverters.toBoolean) + + def getClasses(self): + """ + Returns labels used to train this model + """ + return self._call_java("getClasses") + + def setConfigProtoBytes(self, b): + """Sets configProto from tensorflow, serialized into byte array. + + Parameters + ---------- + b : List[int] + ConfigProto from tensorflow, serialized into byte array + """ + return self._set(configProtoBytes=b) + + def setCoalesceSentences(self, value): + """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging + probabilities in all sentences. Due to max sequence length limit in almost all transformer models such as Bart + (512 tokens), this parameter helps to feed all the sentences into the model and averaging all the probabilities + for the entire document instead of probabilities per sentence. 
(Default: false) + + Parameters + ---------- + value : bool + If the output of all sentences will be averaged to one output + """ + return self._set(coalesceSentences=value) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.AlbertForZeroShotClassification", + java_model=None): + super(AlbertForZeroShotClassification, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + batchSize=8, + maxSentenceLength=128, + caseSensitive=True, + coalesceSentences=False, + activation="softmax" + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + AlbertForZeroShotClassification + The restored model + """ + from sparknlp.internal import _AlbertForZeroShotClassificationLoader + jModel = _AlbertForZeroShotClassificationLoader(folder, spark_session._jsparkSession)._java_obj + return AlbertForZeroShotClassification(java_model=jModel) + + @staticmethod + def pretrained(name="albert_zero_shot_classifier_onnx", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default + "albert_zero_shot_classifier_onnx" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise.
+ + Returns + ------- + AlbertForZeroShotClassification + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(AlbertForZeroShotClassification, name, lang, remote_loc) \ No newline at end of file diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index dc53a009506319..50588a08fe09de 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -58,6 +58,15 @@ def __init__(self, path, jspark): ) + +class _AlbertForZeroShotClassificationLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_AlbertForZeroShotClassificationLoader, self).__init__( + "com.johnsnowlabs.nlp.annotators.classifier.dl.AlbertForZeroShotClassification.loadSavedModel", + path, + jspark, + ) + + class _BertLoader(ExtendedJavaWrapper): def __init__(self, path, jspark, use_openvino=False): super(_BertLoader, self).__init__( diff --git a/python/test/annotator/classifier_dl/albert_for_zero_shot_classification_test.py b/python/test/annotator/classifier_dl/albert_for_zero_shot_classification_test.py new file mode 100644 index 00000000000000..7afda4131c4470 --- /dev/null +++ b/python/test/annotator/classifier_dl/albert_for_zero_shot_classification_test.py @@ -0,0 +1,60 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.annotator.common.has_max_sentence_length_test import HasMaxSentenceLengthTests +from test.util import SparkContextForTest + + +@pytest.mark.slow +class AlbertForZeroShotClassificationTestSpec(unittest.TestCase, HasMaxSentenceLengthTests): + def setUp(self): + self.text = "I have a problem with my iphone that needs to be resolved asap!!" + self.data = SparkContextForTest.spark \ + .createDataFrame([[self.text]]).toDF("text") + self.candidate_labels = ["urgent", "mobile", "technology"] + + self.tested_annotator = AlbertForZeroShotClassification \ + .pretrained()\ + .setInputCols(["document", "token"]) \ + .setOutputCol("multi_class") \ + .setCandidateLabels(self.candidate_labels) + + def test_run(self): + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("document") + + tokenizer = Tokenizer().setInputCols("document").setOutputCol("token") + + doc_classifier = self.tested_annotator + + pipeline = Pipeline(stages=[ + document_assembler, + tokenizer, + doc_classifier + ]) + + model = pipeline.fit(self.data) + model.transform(self.data).show() + + light_pipeline = LightPipeline(model) + annotations_result = light_pipeline.fullAnnotate(self.text) + multi_class_result = annotations_result[0]["multi_class"][0].result + self.assertIn(multi_class_result, self.candidate_labels) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala index 748f0d4f81e771..206d6175d12206 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala @@ -25,10 +25,12 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignat import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, Openvino, 
TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.BasicTokenizer import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} import org.intel.openvino.Tensor import org.tensorflow.ndarray.buffer.IntDataBuffer import org.slf4j.{Logger, LoggerFactory} +import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -95,7 +97,19 @@ private[johnsnowlabs] class AlbertClassification( def tokenizeSeqString( candidateLabels: Seq[String], maxSeqLength: Int, - caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = ??? + caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = { + val basicTokenizer = new BasicTokenizer(caseSensitive) + val encoder = + new SentencepieceEncoder(spp, caseSensitive, sentencePieceDelimiterId, pieceIdOffset = 1) + + val labelsToSentences = candidateLabels.map { s => Sentence(s, 0, s.length - 1, 0) } + + labelsToSentences.map(label => { + val tokens = basicTokenizer.tokenize(label) + val wordpieceTokens = tokens.flatMap(token => encoder.encode(token)).take(maxSeqLength) + WordpieceTokenizedSentence(wordpieceTokens) + }) + } def tokenizeDocument( docs: Seq[Annotation], @@ -310,7 +324,30 @@ private[johnsnowlabs] class AlbertClassification( batch: Seq[Array[Int]], entailmentId: Int, contradictionId: Int, - activation: String): Array[Array[Float]] = ??? 
+ activation: String): Array[Array[Float]] = { + + val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max + val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength)) + val batchLength = paddedBatch.length + + val rawScores = detectedEngine match { + case TensorFlow.name => getRawScoresWithTF(paddedBatch, maxSentenceLength) + case ONNX.name => getRawScoresWithOnnx(paddedBatch, maxSentenceLength, sequence = true) + } + + val dim = rawScores.length / batchLength + rawScores + .grouped(dim) + .toArray + } + + private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = { + if (arr.length >= maxLength) { + arr + } else { + arr ++ Array.fill(maxLength - arr.length)(0) + } + } def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { val batchLength = batch.length diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala index 5d8197ba018e21..98387ad04fb14a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala @@ -432,6 +432,13 @@ package object annotator { extends ReadablePretrainedAlbertForTokenModel with ReadAlbertForTokenDLModel + type AlbertForZeroShotClassification = + com.johnsnowlabs.nlp.annotators.classifier.dl.AlbertForZeroShotClassification + + object AlbertForZeroShotClassification + extends ReadablePretrainedAlbertForZeroShotModel + with ReadAlbertForZeroShotDLModel + type XlnetForTokenClassification = com.johnsnowlabs.nlp.annotators.classifier.dl.XlnetForTokenClassification diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassification.scala new file mode 100644 index 00000000000000..6cb650237a604d --- /dev/null +++ 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassification.scala @@ -0,0 +1,402 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.classifier.dl + +import com.johnsnowlabs.ml.ai.AlbertClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.ml.tensorflow.{ + ReadTensorflowModel, + TensorflowWrapper, + WriteTensorflowModel +} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadSentencePieceAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.annotators.common.{SentenceSplit, TokenizedWithSentence} +import com.johnsnowlabs.nlp.serialization.MapFeature +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession + +class AlbertForZeroShotClassification(override val uid: String) + extends AnnotatorModel[AlbertForZeroShotClassification] + with HasBatchedAnnotate[AlbertForZeroShotClassification] + with WriteTensorflowModel + with WriteOnnxModel + with 
WriteSentencePieceModel + with HasCaseSensitiveProperties + with HasClassifierActivationProperties + with HasEngine + with HasCandidateLabelsProperties { + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + def this() = this(Identifiable.randomUID("ALBERT_FOR_ZERO_SHOT_CLASSIFICATION")) + + /** Input Annotator Types: DOCUMENT, TOKEN + * + * @group anno + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = + Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN) + + /** Output Annotator Types: CATEGORY + * + * @group anno + */ + override val outputAnnotatorType: AnnotatorType = AnnotatorType.CATEGORY + + /** Labels used to decode predicted IDs back to string tags + * + * @group param + */ + val labels: MapFeature[String, Int] = new MapFeature(this, "labels").setProtected() + + /** @group setParam */ + def setLabels(value: Map[String, Int]): this.type = { + if (get(labels).isEmpty) + set(labels, value) + this + } + + /** Returns labels used to train this model */ + def getClasses: Array[String] = { + $$(labels).keys.toArray + } + + /** Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document + * by averaging probabilities in all sentences (Default: `false`). + * + * Due to max sequence length limit in almost all transformer models such as DeBerta (512 + * tokens), this parameter helps feeding all the sentences into the model and averaging all the + * probabilities for the entire document instead of probabilities per sentence. + * + * @group param + */ + val coalesceSentences = new BooleanParam( + this, + "coalesceSentences", + "If sets to true the output of all sentences will be averaged to one output instead of one output per sentence. 
"ALBERT models do not support sequences longer than 512 because of trainable positional embeddings.")
Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], + spp: SentencePieceWrapper): AlbertForZeroShotClassification = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new AlbertClassification( + tensorflowWrapper, + onnxWrapper, + spp, + configProtoBytes = getConfigProtoBytes, + tags = $$(labels), + signatures = getSignatures))) + } + + this + } + + /** @group getParam */ + def getModelIfNotSet: AlbertClassification = _model.get.value + + /** Whether to lowercase tokens or not (Default: `true`). + * + * @group setParam + */ + override def setCaseSensitive(value: Boolean): this.type = { + set(this.caseSensitive, value) + } + + setDefault( + batchSize -> 8, + maxSentenceLength -> 128, + caseSensitive -> true, + coalesceSentences -> false) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations in batches that correspond to inputAnnotationCols generated by previous + * annotators if any + * @return + * any number of annotations processed for every batch of input annotations. Not necessary + * one to one relationship + * + * IMPORTANT: !MUST! return sequences of equal lengths !! IMPORTANT: !MUST! return sentences + * that belong to the same original row !! 
(challenging) + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + batchedAnnotations.map(annotations => { + val sentences = SentenceSplit.unpack(annotations).toArray + val tokenizedSentences = TokenizedWithSentence.unpack(annotations).toArray + + if (tokenizedSentences.nonEmpty) { + getModelIfNotSet.predictSequenceWithZeroShot( + tokenizedSentences, + sentences, + $(candidateLabels), + $(entailmentIdParam), + $(contradictionIdParam), + $(batchSize), + $(maxSentenceLength), + $(caseSensitive), + $(coalesceSentences), + $$(labels), + getActivation) + + } else { + Seq.empty[Annotation] + } + }) + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + val suffix = "_albert_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + AlbertForSequenceClassification.tfFile) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + AlbertForSequenceClassification.onnxFile) + } + + writeSentencePieceModel( + path, + spark, + getModelIfNotSet.spp, + "_albert", + AlbertForSequenceClassification.sppFile) + } + +} + +trait ReadablePretrainedAlbertForZeroShotModel + extends ParamsAndFeaturesReadable[AlbertForZeroShotClassification] + with HasPretrained[AlbertForZeroShotClassification] { + override val defaultModelName: Some[String] = Some("albert_zero_shot_classifier_onnx") + override val defaultLang: String = "en" + + /** Java compliant-overrides */ + override def pretrained(): AlbertForZeroShotClassification = super.pretrained() + + override def pretrained(name: String): AlbertForZeroShotClassification = + super.pretrained(name) + + override def pretrained(name: String, lang: String): AlbertForZeroShotClassification = + super.pretrained(name, lang) + + override def pretrained( + name: String, + lang: String, + remoteLoc: String): 
AlbertForZeroShotClassification = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadAlbertForZeroShotDLModel + extends ReadTensorflowModel + with ReadOnnxModel + with ReadSentencePieceModel { + this: ParamsAndFeaturesReadable[AlbertForZeroShotClassification] => + + override val tfFile: String = "albert_classification_tensorflow" + override val onnxFile: String = "albert_classification_onnx" + override val sppFile: String = "albert_spp" + + def readModel( + instance: AlbertForZeroShotClassification, + path: String, + spark: SparkSession): Unit = { + + val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile) + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_albert_classification_tf") + instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "albert_zero_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new Exception(notSupportedEngineError) + } + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): AlbertForZeroShotClassification = { + + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) + + val spModel = loadSentencePieceAsset(localModelPath, "spiece.model") + val labels = loadTextAsset(localModelPath, "labels.txt").zipWithIndex.toMap + + val entailmentIds = labels.filter(x => x._1.toLowerCase().startsWith("entail")).values.toArray + val contradictionIds = + labels.filter(x => x._1.toLowerCase().startsWith("contradict")).values.toArray + + require( + entailmentIds.length == 1 && contradictionIds.length == 1, + s"""This annotator supports classifiers trained on NLI datasets. 
You must have at least 2 and at most 3 labels in your dataset:
+      "SnowFlakeEmbeddings" -> SnowFlakeEmbeddings
println(s"total docs: $totalDocs")