From abbc603cfd8183025d5e51d3fbe9b3331e67df0c Mon Sep 17 00:00:00 2001 From: emilioMaddalena Date: Sat, 26 Apr 2025 15:43:58 +0200 Subject: [PATCH 1/3] Create label_consistency.ipynb First notebook to try and reproduce the error. --- notebooks/label_consistency.ipynb | 249 ++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 notebooks/label_consistency.ipynb diff --git a/notebooks/label_consistency.ipynb b/notebooks/label_consistency.ipynb new file mode 100644 index 00000000..adbcc249 --- /dev/null +++ b/notebooks/label_consistency.ipynb @@ -0,0 +1,249 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e28f757b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from setfit import SetFitModel, Trainer, TrainingArguments\n", + "from datasets import Dataset\n", + "import numpy as np\n", + "import torch\n", + "import random" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c7f08856", + "metadata": {}, + "outputs": [], + "source": [ + "sentences = [\n", + " # food\n", + " \"I enjoy making homemade pizza on weekends.\",\n", + " \"Fresh vegetables add color and flavor to any meal.\",\n", + " \"A freshly baked cake can brighten any celebration.\",\n", + " \"Many people love discovering new coffee shops.\",\n", + " # chemestry\n", + " \"Chemistry focuses on the composition, structure, and properties of matter.\",\n", + " \"Chemical reactions can transform one set of substances into entirely different ones.\",\n", + " \"Acids and bases react to form salts and water in neutralization processes.\",\n", + " \"Organic chemistry studies carbon-based compounds essential to living organisms.\",\n", + " # sports\n", + " \"Regular exercise boosts both health and mood.\",\n", + " \"Soccer is popular in many countries around the world.\",\n", + " \"Professional athletes maintain strict training schedules.\",\n", + " \"A supportive crowd can motivate players to perform better.\",\n", + "]\n", + "\n", + "labels = [\n", + " \"food\",\n", + " \"food\",\n", + " \"food\",\n", + " \"food\",\n", + " \"chemistry\",\n", + " \"chemistry\",\n", + " \"chemistry\",\n", + " \"chemistry\",\n", + " \"sports\",\n", + " \"sports\",\n", + " \"sports\",\n", + " \"sports\",\n", + "]\n", + "\n", + "train_dataset = Dataset.from_dict({\"text\": sentences, \"label\": labels})" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "78368d95", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model_id = \"sentence-transformers/all-MiniLM-L6-v2\"\n", + "model = SetFitModel.from_pretrained(\n", + " model_id,\n", + " labels=[\n", + " \"food\",\n", + " \"chemistry\",\n", + " \"sports\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1c3ab4ab", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/codecarbon/input.py:9: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html\n", + " import pkg_resources\n", + "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/pkg_resources/__init__.py:3147: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.\n", + "Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages\n", + " declare_namespace(pkg)\n", + "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/datasets/utils/_dill.py:385: DeprecationWarning: co_lnotab is deprecated, use co_lines instead.\n", + " obj.co_lnotab, # for < python 3.10 [not counted in args]\n", + "Map: 100%|██████████| 12/12 [00:00<00:00, 3939.24 examples/s]\n", + "***** Running training *****\n", + " Num unique pairs = 96\n", + " Batch size = 32\n", + " Num epochs = 3\n", + "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:683: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.\n", + " warnings.warn(warn_msg)\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [9/9 00:01, Epoch 3/3]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
10.144000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/codecarbon/output_methods/file.py:50: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n", + " df = pd.concat([df, pd.DataFrame.from_records([dict(total.values)])])\n" + ] + } + ], + "source": [ + "SEED = 100\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "\n", + "args = TrainingArguments(\n", + " batch_size=32,\n", + " num_epochs=3,\n", + ")\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=args,\n", + " train_dataset=train_dataset,\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2381c6ce", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I enjoy making homemade pizza on weekends.\n", + "food\n", + "chemistry\n", + "['food', 'chemistry', 'sports']\n", + "['chemistry' 'food' 'sports']\n" + ] + } + ], + "source": [ + "# 'I enjoy making homemade pizza on weekends.'\n", + "sentence = sentences[0]\n", + "print(sentence)\n", + "\n", + "# 'food' \n", + "pred = model.predict(sentence)\n", + "print(pred)\n", + "\n", + "# 'chemistry' \n", + "probs = model.predict_proba(sentence, as_numpy=True)\n", + "pred = model.id2label[int(np.argmax(probs))]\n", + "print(pred)\n", + "\n", + "# ['food', 'chemistry', 'sports']\n", + "print(model.labels)\n", + "# ['chemistry', 'food', 'sports']\n", + "print(model.model_head.classes_)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41c206b9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 510a5eec44767491ac9d57420bbf40e5fdc53fdf Mon Sep 17 00:00:00 2001 From: emilioMaddalena Date: Sat, 26 Apr 2025 17:36:31 +0200 Subject: [PATCH 2/3] Cover mismatching labels case - In predict_proba, probs were consistent with the model head's labels order - Now we make sure probs are always consistent with self.labels --- src/setfit/modeling.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index dfc0face..f917c0ad 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -518,6 +518,13 @@ def predict_proba( probs = torch.stack(probs, axis=1) else: probs = np.stack(probs, axis=1) + if list(self.labels) != list(self.model_head.classes_): + # If the user has specified labels when instantiating the model, we have to take into account + # the possibility of the model head having reordered the labels. + head_labels = list(self.model_head.classes_) + user_labels = list(self.labels) + reorder_map = np.array([head_labels.index(label) for label in user_labels]) + probs = probs[:, reorder_map] outputs = self._output_type_conversion(probs, as_numpy=as_numpy) return outputs[0] if is_singular else outputs From aed6f2952353819cb0ac96650f2865a72c2fc03c Mon Sep 17 00:00:00 2001 From: emilioMaddalena Date: Sat, 26 Apr 2025 17:38:27 +0200 Subject: [PATCH 3/3] Remove notebook --- notebooks/label_consistency.ipynb | 249 ------------------------------ 1 file changed, 249 deletions(-) delete mode 100644 notebooks/label_consistency.ipynb diff --git a/notebooks/label_consistency.ipynb b/notebooks/label_consistency.ipynb deleted file mode 100644 index adbcc249..00000000 --- a/notebooks/label_consistency.ipynb +++ /dev/null @@ -1,249 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "e28f757b", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "from setfit import SetFitModel, Trainer, TrainingArguments\n", - "from datasets import Dataset\n", - "import numpy as np\n", - "import torch\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "c7f08856", - "metadata": {}, - "outputs": [], - "source": [ - "sentences = [\n", - " # food\n", - " \"I enjoy making homemade pizza on weekends.\",\n", - " \"Fresh vegetables add color and flavor to any meal.\",\n", - " \"A freshly baked cake can brighten any celebration.\",\n", - " \"Many people love discovering new coffee shops.\",\n", - " # chemestry\n", - " \"Chemistry focuses on the composition, structure, and properties of matter.\",\n", - " \"Chemical reactions can transform one set of substances into entirely different ones.\",\n", - " \"Acids and bases react to form salts and water in neutralization processes.\",\n", - " \"Organic chemistry studies carbon-based compounds essential to living organisms.\",\n", - " # sports\n", - " \"Regular exercise boosts both health and mood.\",\n", - " \"Soccer is popular in many countries around the world.\",\n", - " \"Professional athletes maintain strict training schedules.\",\n", - " \"A supportive crowd can motivate players to perform better.\",\n", - "]\n", - "\n", - "labels = [\n", - " \"food\",\n", - " \"food\",\n", - " \"food\",\n", - " \"food\",\n", - " \"chemistry\",\n", - " \"chemistry\",\n", - " \"chemistry\",\n", - " \"chemistry\",\n", - " \"sports\",\n", - " \"sports\",\n", - " \"sports\",\n", - " \"sports\",\n", - "]\n", - "\n", - "train_dataset = Dataset.from_dict({\"text\": sentences, \"label\": labels})" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "78368d95", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.\n" - ] - } - ], - "source": [ - "model_id = \"sentence-transformers/all-MiniLM-L6-v2\"\n", - "model = SetFitModel.from_pretrained(\n", - " model_id,\n", - " labels=[\n", - " \"food\",\n", - " \"chemistry\",\n", - " \"sports\",\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "1c3ab4ab", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/codecarbon/input.py:9: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html\n", - " import pkg_resources\n", - "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/pkg_resources/__init__.py:3147: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.\n", - "Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages\n", - " declare_namespace(pkg)\n", - "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/datasets/utils/_dill.py:385: DeprecationWarning: co_lnotab is deprecated, use co_lines instead.\n", - " obj.co_lnotab, # for < python 3.10 [not counted in args]\n", - "Map: 100%|██████████| 12/12 [00:00<00:00, 3939.24 examples/s]\n", - "***** Running training *****\n", - " Num unique pairs = 96\n", - " Batch size = 32\n", - " Num epochs = 3\n", - "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:683: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.\n", - " warnings.warn(warn_msg)\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "

\n", - " \n", - " \n", - " [9/9 00:01, Epoch 3/3]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
10.144000

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/emiliomaddalena/Documents/github/setfit/.venv/lib/python3.12/site-packages/codecarbon/output_methods/file.py:50: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n", - " df = pd.concat([df, pd.DataFrame.from_records([dict(total.values)])])\n" - ] - } - ], - "source": [ - "SEED = 100\n", - "random.seed(SEED)\n", - "np.random.seed(SEED)\n", - "torch.manual_seed(SEED)\n", - "\n", - "args = TrainingArguments(\n", - " batch_size=32,\n", - " num_epochs=3,\n", - ")\n", - "\n", - "trainer = Trainer(\n", - " model=model,\n", - " args=args,\n", - " train_dataset=train_dataset,\n", - ")\n", - "\n", - "trainer.train()" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "2381c6ce", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "I enjoy making homemade pizza on weekends.\n", - "food\n", - "chemistry\n", - "['food', 'chemistry', 'sports']\n", - "['chemistry' 'food' 'sports']\n" - ] - } - ], - "source": [ - "# 'I enjoy making homemade pizza on weekends.'\n", - "sentence = sentences[0]\n", - "print(sentence)\n", - "\n", - "# 'food' \n", - "pred = model.predict(sentence)\n", - "print(pred)\n", - "\n", - "# 'chemistry' \n", - "probs = model.predict_proba(sentence, as_numpy=True)\n", - "pred = model.id2label[int(np.argmax(probs))]\n", - "print(pred)\n", - "\n", - "# ['food', 'chemistry', 'sports']\n", - "print(model.labels)\n", - "# ['chemistry', 'food', 'sports']\n", - "print(model.model_head.classes_)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41c206b9", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}