diff --git a/examples/DVCLive-HuggingFace.ipynb b/examples/DVCLive-HuggingFace.ipynb index 2f0a145f..6f9a81dd 100644 --- a/examples/DVCLive-HuggingFace.ipynb +++ b/examples/DVCLive-HuggingFace.ipynb @@ -37,7 +37,7 @@ }, "outputs": [], "source": [ - "!pip install accelerate datasets dvclive evaluate pandas 'transformers[torch]' --upgrade" + "!pip install datasets dvclive evaluate pandas 'transformers[torch]' --upgrade" ] }, { @@ -153,85 +153,102 @@ "### Training and Tracking experiments with DVCLive\n", "\n", "Track experiments in DVC by changing a few lines of your Python code.\n", - "Enable experiment tracking using `save_dvc_exp=True` and save model artifacts using `log_model=True`.\n" + "Save model artifacts using `HF_DVCLIVE_LOG_MODEL=true`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "referenced_widgets": [ - "425795652e7047eab04fdb8816d85fe7", - "47f9d18c35a04f69b9bfe4e835a98d13", - "eae5029fcd6a49aab13a1aa11bb55a77", - "6b272c5eff1b49aab4f906cc0cea84bf", - "8f86d83b55b04afdb74ced76a4326f98", - "f5a4ea7cfa8a4cbbad83bf3a33db4172", - "cc96d129d3014c1a9fbb598a985c4c88", - "fb0f45a06d47475d83b7b9eec42d4e06", - "2c92c08cd14e46179fb1be7d5e36e2d1", - "1b760f9910934859ba90b92d94460855", - "30cb39221ea6463d9070fd9d0eefbfae", - "fc06215ebdfc4be0bcf8b2a3c227f4de", - "9db967ac8d4048d3bfedc3b2c21e15ba", - "0306ec09003e4271804a17a37a6e3db5", - "d2192a8e3e7b4df09d3cd7904b104519", - "ad9df876a82040f2853f63899c36e5fa", - "e40116d5bdbc43d0b0ffc27476c9eb6a", - "20c15ecc60794e07b70d96083514655b", - "674a922aef3b4d82b5b65241ef5d4de1", - "cec588e36e814237add88c6695d54952", - "1236f4a882fc47389c392fa88f85d6cc", - "3288c1eaa7b640f4bcf74fcdf7b12aaa", - "d7f7f24763b844128f44e7b787e848ff", - "d17ee4b49c184d63afa71fee9f15d4a1", - "92afad02b9064c1ba40396e5bc36931c", - "94d4335888e24e69964197f029f2389d", - "c6472537577a4900a3b190b64285b7a9", - "ab41d610477b47708634c7be45df9bc6", - "b802e2906072462abebd4b7b41fc1ab9", - "87576f1a78664458be6f7c4cbaa806a8", - "6f6a505d2bf84d588dc5a23dc34ef988", - "c9b665b4da694075b02ed09189dee813", - "30fad6f36d9f4f8eb6a1a3f70a93fe31" - ] - }, - "id": "gKKSTh0ZdmsW", - "outputId": "54733639-db04-4b82-f8ac-3d20f81712bb" + "id": "-A1oXCxE4zGi" }, "outputs": [], "source": [ - "from dvclive.huggingface import DVCLiveCallback\n", + "%env HF_DVCLIVE_LOG_MODEL=true" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gKKSTh0ZdmsW" + }, + "outputs": [], + "source": [ + "from transformers.integrations import DVCLiveCallback\n", "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", "\n", - "for epochs in (5, 10, 15):\n", - " model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n", - " for param in model.base_model.parameters():\n", - " param.requires_grad = False\n", + "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n", + "for param in model.base_model.parameters():\n", + " param.requires_grad = False\n", + "\n", + "lr = 3e-4\n", + "\n", + "training_args = TrainingArguments(\n", + " evaluation_strategy=\"epoch\",\n", + " learning_rate=lr,\n", + " logging_strategy=\"epoch\",\n", + " num_train_epochs=5,\n", + " output_dir=\"output\",\n", + " overwrite_output_dir=True,\n", + " load_best_model_at_end=True,\n", + " save_strategy=\"epoch\",\n", + " weight_decay=0.01,\n", + ")\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=small_train_dataset,\n", + " eval_dataset=small_eval_dataset,\n", + " compute_metrics=compute_metrics,\n", + ")\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KKJCw0Vj6UTw" + }, + "source": [ + "To customize tracking, include `transformers.integrations.DVCLiveCallback` in the `Trainer` callbacks and pass additional keyword arguments to `dvclive.Live`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M4FKUYTi5zYQ" + }, + "outputs": [], + "source": [ + "from dvclive import Live\n", + "from transformers.integrations import DVCLiveCallback\n", + "\n", + "lr = 1e-4\n", "\n", - " training_args = TrainingArguments(\n", - " evaluation_strategy=\"epoch\",\n", - " learning_rate=3e-4,\n", - " logging_strategy=\"epoch\",\n", - " num_train_epochs=epochs,\n", - " output_dir=\"output\",\n", - " overwrite_output_dir=True,\n", - " load_best_model_at_end=True,\n", - " report_to=\"none\",\n", - " save_strategy=\"epoch\",\n", - " weight_decay=0.01,\n", - " )\n", + "training_args = TrainingArguments(\n", + " evaluation_strategy=\"epoch\",\n", + " learning_rate=lr,\n", + " logging_strategy=\"epoch\",\n", + " num_train_epochs=5,\n", + " output_dir=\"output\",\n", + " overwrite_output_dir=True,\n", + " load_best_model_at_end=True,\n", + " save_strategy=\"epoch\",\n", + " weight_decay=0.01,\n", + ")\n", "\n", - " trainer = Trainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=small_train_dataset,\n", - " eval_dataset=small_eval_dataset,\n", - " compute_metrics=compute_metrics,\n", - " callbacks=[DVCLiveCallback(log_model=True, report=\"notebook\")]\n", - " )\n", - " trainer.train()" + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=small_train_dataset,\n", + " eval_dataset=small_eval_dataset,\n", + " compute_metrics=compute_metrics,\n", + " callbacks=[DVCLiveCallback(live=Live(report=\"notebook\"), log_model=True)],\n", + ")\n", + "trainer.train()" ] }, { @@ -249,8 +266,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "wwMwHvVtdmsW", - "outputId": "7db5ce6b-4f70-4a5f-fa3e-84218d2c1446" + "id": "wwMwHvVtdmsW" }, "outputs": [], "source": [ @@ -270,8 +286,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "TNBGUqoCdmsW", - "outputId": "2a4ebf29-e7e3-40f7-ec56-2ccc857c215f" + "id": "TNBGUqoCdmsW" }, "outputs": [], "source": [ @@ -282,8 +297,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "sL5pH4X5dmsW", - "outputId": "168ffdad-baff-4b79-8596-c54a5f4e4c86" + "id": "sL5pH4X5dmsW" }, "outputs": [], "source": [ @@ -311,7 +325,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index bc14a814..883c1490 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -24,6 +24,11 @@ def __init__( log_model: Optional[Union[Literal["all"], bool]] = None, **kwargs, ): + logger.warning( + "This callback is deprecated and will be removed in DVCLive 4.0" + " in favor of `transformers.integrations.DVCLiveCallback`" + " https://dvc.org/doc/dvclive/ml-frameworks/huggingface." + ) super().__init__() self._log_model = log_model self.live = live if live is not None else Live(**kwargs) diff --git a/tests/frameworks/test_huggingface.py b/tests/frameworks/test_huggingface.py index 8ec39373..ac824588 100644 --- a/tests/frameworks/test_huggingface.py +++ b/tests/frameworks/test_huggingface.py @@ -17,8 +17,9 @@ Trainer, TrainingArguments, ) + from transformers.integrations import DVCLiveCallback as ExternalCallback - from dvclive.huggingface import DVCLiveCallback + from dvclive.huggingface import DVCLiveCallback as InternalCallback except ImportError: pytest.skip("skipping huggingface tests", allow_module_level=True) @@ -101,11 +102,12 @@ def args(): evaluation_strategy="epoch", num_train_epochs=2, save_strategy="epoch", - report_to="none", + report_to="none", # Disable auto-reporting to avoid duplication ) -def test_huggingface_integration(tmp_dir, model, args, data, mocker): +@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) +def test_huggingface_integration(tmp_dir, model, args, data, mocker, callback): trainer = Trainer( model, args, @@ -113,9 +115,8 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): eval_dataset=data[1], compute_metrics=compute_metrics, ) - callback = DVCLiveCallback() - live = callback.live - spy = mocker.spy(live, "end") + callback = callback() + spy = mocker.spy(Live, "end") trainer.add_callback(callback) trainer.train() spy.assert_called_once() @@ -125,8 +126,6 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): logs, _ = parse_metrics(live) - assert len(logs) == 10 - scalars = os.path.join(live.plots_dir, Metric.subfolder) assert os.path.join(scalars, "eval", "foo.tsv") in logs assert os.path.join(scalars, "eval", "loss.tsv") in logs @@ -138,19 +137,31 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): assert params["num_train_epochs"] == 2 -@pytest.mark.parametrize("log_model", ["all", True, None]) +@pytest.mark.parametrize("log_model", ["all", True, False, None]) @pytest.mark.parametrize("best", [True, False]) -def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best): - live_callback = DVCLiveCallback(log_model=log_model) - log_artifact = mocker.patch.object(live_callback.live, "log_artifact") +@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) +def test_huggingface_log_model( + tmp_dir, + mocked_dvc_repo, + model, + data, + args, + monkeypatch, + mocker, + log_model, + best, + callback, +): + live = Live() + log_artifact = mocker.patch.object(live, "log_artifact") + if callback == ExternalCallback: + monkeypatch.setenv("HF_DVCLIVE_LOG_MODEL", str(log_model)) + live_callback = callback(live=live) + else: + live_callback = callback(live=live, log_model=log_model) + + args.load_best_model_at_end = best - args = TrainingArguments( - "foo", - evaluation_strategy="epoch", - num_train_epochs=2, - save_strategy="epoch", - load_best_model_at_end=best, - ) trainer = Trainer( model, args, @@ -164,11 +175,12 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best): expected_call_count = { "all": 2, True: 1, + False: 0, None: 0, } assert log_artifact.call_count == expected_call_count[log_model] - if log_model == "last": + if log_model is True: name = "best" if best else "last" log_artifact.assert_called_with( os.path.join(args.output_dir, name), @@ -178,8 +190,27 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best): ) -def test_huggingface_pass_logger(): +@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) +def test_huggingface_pass_logger(callback): logger = Live("train_logs") - assert DVCLiveCallback().live is not logger - assert DVCLiveCallback(live=logger).live is logger + assert callback().live is not logger + assert callback(live=logger).live is logger + + +@pytest.mark.parametrize("report_to", ["all", "dvclive", "none"]) +def test_huggingface_report_to(model, report_to): + args = TrainingArguments("foo", report_to=report_to) + trainer = Trainer( + model, + args, + ) + live_cbs = [ + cb + for cb in trainer.callback_handler.callbacks + if isinstance(cb, ExternalCallback) + ] + if report_to == "none": + assert not any(live_cbs) + else: + assert any(live_cbs)