Skip to content

Commit

Permalink
hf: warn of deprecating internal callback (#740)
Browse files Browse the repository at this point in the history
* hf: warn of deprecating internal callback

* hf: update notebook

* fix notebook

* hf: test without passing live instance

* merge hf notebooks and fix dvclivecallback import

* hf: test HF_DVCLIVE_LOG_MODEL env var

* fix ci: account for huggingface transformers changes

* see if loss is broken

* show what is going through on_log

* try unparallelize

* check next_step call count

* try report_to none

* revert all code

* clean up hf tests

* revert moving spy call

* hf: test log_model=None

* Update src/dvclive/huggingface.py

---------

Co-authored-by: Matt Seddon <[email protected]>
  • Loading branch information
Dave Berenbaum and mattseddon authored Feb 27, 2024
1 parent 2c7c378 commit a45efba
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 97 deletions.
162 changes: 88 additions & 74 deletions examples/DVCLive-HuggingFace.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
},
"outputs": [],
"source": [
"!pip install accelerate datasets dvclive evaluate pandas 'transformers[torch]' --upgrade"
"!pip install datasets dvclive evaluate pandas 'transformers[torch]' --upgrade"
]
},
{
Expand Down Expand Up @@ -153,85 +153,102 @@
"### Training and Tracking experiments with DVCLive\n",
"\n",
"Track experiments in DVC by changing a few lines of your Python code.\n",
"Enable experiment tracking using `save_dvc_exp=True` and save model artifacts using `log_model=True`.\n"
"Save model artifacts using `HF_DVCLIVE_LOG_MODEL=true`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"referenced_widgets": [
"425795652e7047eab04fdb8816d85fe7",
"47f9d18c35a04f69b9bfe4e835a98d13",
"eae5029fcd6a49aab13a1aa11bb55a77",
"6b272c5eff1b49aab4f906cc0cea84bf",
"8f86d83b55b04afdb74ced76a4326f98",
"f5a4ea7cfa8a4cbbad83bf3a33db4172",
"cc96d129d3014c1a9fbb598a985c4c88",
"fb0f45a06d47475d83b7b9eec42d4e06",
"2c92c08cd14e46179fb1be7d5e36e2d1",
"1b760f9910934859ba90b92d94460855",
"30cb39221ea6463d9070fd9d0eefbfae",
"fc06215ebdfc4be0bcf8b2a3c227f4de",
"9db967ac8d4048d3bfedc3b2c21e15ba",
"0306ec09003e4271804a17a37a6e3db5",
"d2192a8e3e7b4df09d3cd7904b104519",
"ad9df876a82040f2853f63899c36e5fa",
"e40116d5bdbc43d0b0ffc27476c9eb6a",
"20c15ecc60794e07b70d96083514655b",
"674a922aef3b4d82b5b65241ef5d4de1",
"cec588e36e814237add88c6695d54952",
"1236f4a882fc47389c392fa88f85d6cc",
"3288c1eaa7b640f4bcf74fcdf7b12aaa",
"d7f7f24763b844128f44e7b787e848ff",
"d17ee4b49c184d63afa71fee9f15d4a1",
"92afad02b9064c1ba40396e5bc36931c",
"94d4335888e24e69964197f029f2389d",
"c6472537577a4900a3b190b64285b7a9",
"ab41d610477b47708634c7be45df9bc6",
"b802e2906072462abebd4b7b41fc1ab9",
"87576f1a78664458be6f7c4cbaa806a8",
"6f6a505d2bf84d588dc5a23dc34ef988",
"c9b665b4da694075b02ed09189dee813",
"30fad6f36d9f4f8eb6a1a3f70a93fe31"
]
},
"id": "gKKSTh0ZdmsW",
"outputId": "54733639-db04-4b82-f8ac-3d20f81712bb"
"id": "-A1oXCxE4zGi"
},
"outputs": [],
"source": [
"from dvclive.huggingface import DVCLiveCallback\n",
"%env HF_DVCLIVE_LOG_MODEL=true"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gKKSTh0ZdmsW"
},
"outputs": [],
"source": [
"from transformers.integrations import DVCLiveCallback\n",
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"for epochs in (5, 10, 15):\n",
" model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
" for param in model.base_model.parameters():\n",
" param.requires_grad = False\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
"for param in model.base_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"lr = 3e-4\n",
"\n",
"training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=lr,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=5,\n",
" output_dir=\"output\",\n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KKJCw0Vj6UTw"
},
"source": [
"To customize tracking, include `transformers.integrations.DVCLiveCallback` in the `Trainer` callbacks and pass additional keyword arguments to `dvclive.Live`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M4FKUYTi5zYQ"
},
"outputs": [],
"source": [
"from dvclive import Live\n",
"from transformers.integrations import DVCLiveCallback\n",
"\n",
"lr = 1e-4\n",
"\n",
" training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=3e-4,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=epochs,\n",
" output_dir=\"output\",\n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\",\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
" )\n",
"training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=lr,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=5,\n",
" output_dir=\"output\",\n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
")\n",
"\n",
" trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(log_model=True, report=\"notebook\")]\n",
" )\n",
" trainer.train()"
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(live=Live(report=\"notebook\"), log_model=True)],\n",
")\n",
"trainer.train()"
]
},
{
Expand All @@ -249,8 +266,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wwMwHvVtdmsW",
"outputId": "7db5ce6b-4f70-4a5f-fa3e-84218d2c1446"
"id": "wwMwHvVtdmsW"
},
"outputs": [],
"source": [
Expand All @@ -270,8 +286,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TNBGUqoCdmsW",
"outputId": "2a4ebf29-e7e3-40f7-ec56-2ccc857c215f"
"id": "TNBGUqoCdmsW"
},
"outputs": [],
"source": [
Expand All @@ -282,8 +297,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sL5pH4X5dmsW",
"outputId": "168ffdad-baff-4b79-8596-c54a5f4e4c86"
"id": "sL5pH4X5dmsW"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -311,7 +325,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
5 changes: 5 additions & 0 deletions src/dvclive/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def __init__(
log_model: Optional[Union[Literal["all"], bool]] = None,
**kwargs,
):
logger.warning(
"This callback is deprecated and will be removed in DVCLive 4.0"
" in favor of `transformers.integrations.DVCLiveCallback`"
" https://dvc.org/doc/dvclive/ml-frameworks/huggingface."
)
super().__init__()
self._log_model = log_model
self.live = live if live is not None else Live(**kwargs)
Expand Down
77 changes: 54 additions & 23 deletions tests/frameworks/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
Trainer,
TrainingArguments,
)
from transformers.integrations import DVCLiveCallback as ExternalCallback

from dvclive.huggingface import DVCLiveCallback
from dvclive.huggingface import DVCLiveCallback as InternalCallback
except ImportError:
pytest.skip("skipping huggingface tests", allow_module_level=True)

Expand Down Expand Up @@ -101,21 +102,21 @@ def args():
evaluation_strategy="epoch",
num_train_epochs=2,
save_strategy="epoch",
report_to="none",
report_to="none", # Disable auto-reporting to avoid duplication
)


def test_huggingface_integration(tmp_dir, model, args, data, mocker):
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback])
def test_huggingface_integration(tmp_dir, model, args, data, mocker, callback):
trainer = Trainer(
model,
args,
train_dataset=data[0],
eval_dataset=data[1],
compute_metrics=compute_metrics,
)
callback = DVCLiveCallback()
live = callback.live
spy = mocker.spy(live, "end")
callback = callback()
spy = mocker.spy(Live, "end")
trainer.add_callback(callback)
trainer.train()
spy.assert_called_once()
Expand All @@ -125,8 +126,6 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker):

logs, _ = parse_metrics(live)

assert len(logs) == 10

scalars = os.path.join(live.plots_dir, Metric.subfolder)
assert os.path.join(scalars, "eval", "foo.tsv") in logs
assert os.path.join(scalars, "eval", "loss.tsv") in logs
Expand All @@ -138,19 +137,31 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker):
assert params["num_train_epochs"] == 2


@pytest.mark.parametrize("log_model", ["all", True, None])
@pytest.mark.parametrize("log_model", ["all", True, False, None])
@pytest.mark.parametrize("best", [True, False])
def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best):
live_callback = DVCLiveCallback(log_model=log_model)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback])
def test_huggingface_log_model(
tmp_dir,
mocked_dvc_repo,
model,
data,
args,
monkeypatch,
mocker,
log_model,
best,
callback,
):
live = Live()
log_artifact = mocker.patch.object(live, "log_artifact")
if callback == ExternalCallback:
monkeypatch.setenv("HF_DVCLIVE_LOG_MODEL", str(log_model))
live_callback = callback(live=live)
else:
live_callback = callback(live=live, log_model=log_model)

args.load_best_model_at_end = best

args = TrainingArguments(
"foo",
evaluation_strategy="epoch",
num_train_epochs=2,
save_strategy="epoch",
load_best_model_at_end=best,
)
trainer = Trainer(
model,
args,
Expand All @@ -164,11 +175,12 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best):
expected_call_count = {
"all": 2,
True: 1,
False: 0,
None: 0,
}
assert log_artifact.call_count == expected_call_count[log_model]

if log_model == "last":
if log_model is True:
name = "best" if best else "last"
log_artifact.assert_called_with(
os.path.join(args.output_dir, name),
Expand All @@ -178,8 +190,27 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best):
)


def test_huggingface_pass_logger():
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback])
def test_huggingface_pass_logger(callback):
logger = Live("train_logs")

assert DVCLiveCallback().live is not logger
assert DVCLiveCallback(live=logger).live is logger
assert callback().live is not logger
assert callback(live=logger).live is logger


@pytest.mark.parametrize("report_to", ["all", "dvclive", "none"])
def test_huggingface_report_to(model, report_to):
args = TrainingArguments("foo", report_to=report_to)
trainer = Trainer(
model,
args,
)
live_cbs = [
cb
for cb in trainer.callback_handler.callbacks
if isinstance(cb, ExternalCallback)
]
if report_to == "none":
assert not any(live_cbs)
else:
assert any(live_cbs)

0 comments on commit a45efba

Please sign in to comment.