Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hf: warn of deprecating internal callback #740

Merged
merged 27 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8ecc790
hf: warn of deprecating internal callback
Nov 17, 2023
4ef660e
hf: update notebook
Nov 17, 2023
c386e89
Merge branch 'main' into hf-deprecation-warning
Nov 17, 2023
42cd1ab
fix notebook
Nov 17, 2023
cb89b98
Merge branch 'main' into hf-deprecation-warning
Dec 12, 2023
e2c23d7
hf: test without passing live instance
Dec 12, 2023
99aff66
Merge branch 'hf-deprecation-warning' of github.com:iterative/dvclive…
Dec 12, 2023
801178e
Merge branch 'main' into hf-deprecation-warning
Dec 12, 2023
fa244f9
Merge branch 'main' into hf-deprecation-warning
Dec 22, 2023
8e9afeb
Merge branch 'main' into hf-deprecation-warning
Jan 22, 2024
d52461b
merge hf notebooks and fix dvclivecallback import
Jan 22, 2024
b0c3c5c
Merge branch 'hf-deprecation-warning' of github.com:iterative/dvclive…
Jan 22, 2024
f7b4efb
hf: test HF_DVCLIVE_LOG_MODEL env var
Jan 22, 2024
6bef79e
fix ci: account for huggingface transformers changes
mattseddon Feb 12, 2024
a9f01b0
see if loss is broken
mattseddon Feb 12, 2024
88931f0
show what is going through on_log
mattseddon Feb 12, 2024
be1cfde
try unparallelize
mattseddon Feb 12, 2024
ec73bf7
check next_step call count
mattseddon Feb 12, 2024
e6f1845
try report_to none
mattseddon Feb 12, 2024
27e38cb
revert all code
mattseddon Feb 12, 2024
8c25642
Merge branch 'check-huggingface' into hf-deprecation-warning
Feb 12, 2024
153f72b
clean up hf tests
Feb 12, 2024
4043b8d
merge
Feb 12, 2024
8f2f91b
revert moving spy call
Feb 12, 2024
63c40f5
hf: test log_model=None
Feb 12, 2024
1977009
Merge branch 'main' into hf-deprecation-warning
Feb 27, 2024
7e9d1f9
Update src/dvclive/huggingface.py
Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 88 additions & 74 deletions examples/DVCLive-HuggingFace.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
},
"outputs": [],
"source": [
"!pip install accelerate datasets dvclive evaluate pandas 'transformers[torch]' --upgrade"
"!pip install datasets dvclive evaluate pandas 'transformers[torch]' --upgrade"
]
},
{
Expand Down Expand Up @@ -153,85 +153,102 @@
"### Training and Tracking experiments with DVCLive\n",
"\n",
"Track experiments in DVC by changing a few lines of your Python code.\n",
"Enable experiment tracking using `save_dvc_exp=True` and save model artifacts using `log_model=True`.\n"
"Save model artifacts using `HF_DVCLIVE_LOG_MODEL=true`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"referenced_widgets": [
"425795652e7047eab04fdb8816d85fe7",
"47f9d18c35a04f69b9bfe4e835a98d13",
"eae5029fcd6a49aab13a1aa11bb55a77",
"6b272c5eff1b49aab4f906cc0cea84bf",
"8f86d83b55b04afdb74ced76a4326f98",
"f5a4ea7cfa8a4cbbad83bf3a33db4172",
"cc96d129d3014c1a9fbb598a985c4c88",
"fb0f45a06d47475d83b7b9eec42d4e06",
"2c92c08cd14e46179fb1be7d5e36e2d1",
"1b760f9910934859ba90b92d94460855",
"30cb39221ea6463d9070fd9d0eefbfae",
"fc06215ebdfc4be0bcf8b2a3c227f4de",
"9db967ac8d4048d3bfedc3b2c21e15ba",
"0306ec09003e4271804a17a37a6e3db5",
"d2192a8e3e7b4df09d3cd7904b104519",
"ad9df876a82040f2853f63899c36e5fa",
"e40116d5bdbc43d0b0ffc27476c9eb6a",
"20c15ecc60794e07b70d96083514655b",
"674a922aef3b4d82b5b65241ef5d4de1",
"cec588e36e814237add88c6695d54952",
"1236f4a882fc47389c392fa88f85d6cc",
"3288c1eaa7b640f4bcf74fcdf7b12aaa",
"d7f7f24763b844128f44e7b787e848ff",
"d17ee4b49c184d63afa71fee9f15d4a1",
"92afad02b9064c1ba40396e5bc36931c",
"94d4335888e24e69964197f029f2389d",
"c6472537577a4900a3b190b64285b7a9",
"ab41d610477b47708634c7be45df9bc6",
"b802e2906072462abebd4b7b41fc1ab9",
"87576f1a78664458be6f7c4cbaa806a8",
"6f6a505d2bf84d588dc5a23dc34ef988",
"c9b665b4da694075b02ed09189dee813",
"30fad6f36d9f4f8eb6a1a3f70a93fe31"
]
},
"id": "gKKSTh0ZdmsW",
"outputId": "54733639-db04-4b82-f8ac-3d20f81712bb"
"id": "-A1oXCxE4zGi"
},
"outputs": [],
"source": [
"from dvclive.huggingface import DVCLiveCallback\n",
"%env HF_DVCLIVE_LOG_MODEL=true"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gKKSTh0ZdmsW"
},
"outputs": [],
"source": [
"from transformers.integrations import DVCLiveCallback\n",
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"for epochs in (5, 10, 15):\n",
" model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
" for param in model.base_model.parameters():\n",
" param.requires_grad = False\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
"for param in model.base_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"lr = 3e-4\n",
"\n",
"training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=lr,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=5,\n",
" output_dir=\"output\",\n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KKJCw0Vj6UTw"
},
"source": [
"To customize tracking, include `transformers.integrations.DVCLiveCallback` in the `Trainer` callbacks and pass additional keyword arguments to `dvclive.Live`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M4FKUYTi5zYQ"
},
"outputs": [],
"source": [
"from dvclive import Live\n",
"from transformers.integrations import DVCLiveCallback\n",
"\n",
"lr = 1e-4\n",
"\n",
" training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=3e-4,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=epochs,\n",
" output_dir=\"output\",\n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\",\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
" )\n",
"training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\",\n",
" learning_rate=lr,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=5,\n",
" output_dir=\"output\",\n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
")\n",
"\n",
" trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(log_model=True, report=\"notebook\")]\n",
" )\n",
" trainer.train()"
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(live=Live(report=\"notebook\"), log_model=True)],\n",
")\n",
"trainer.train()"
]
},
{
Expand All @@ -249,8 +266,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wwMwHvVtdmsW",
"outputId": "7db5ce6b-4f70-4a5f-fa3e-84218d2c1446"
"id": "wwMwHvVtdmsW"
},
"outputs": [],
"source": [
Expand All @@ -270,8 +286,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TNBGUqoCdmsW",
"outputId": "2a4ebf29-e7e3-40f7-ec56-2ccc857c215f"
"id": "TNBGUqoCdmsW"
},
"outputs": [],
"source": [
Expand All @@ -282,8 +297,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sL5pH4X5dmsW",
"outputId": "168ffdad-baff-4b79-8596-c54a5f4e4c86"
"id": "sL5pH4X5dmsW"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -311,7 +325,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
5 changes: 5 additions & 0 deletions src/dvclive/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def __init__(
log_model: Optional[Union[Literal["all"], bool]] = None,
**kwargs,
):
logger.warning(
"This callback will be deprecated in DVCLive 4.0 in favor of"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Q] Do you want to deprecate it now and remove it in 4.0?

" `transformers.integrations.DVCLiveCallback`"
" https://dvc.org/doc/dvclive/ml-frameworks/huggingface."
dberenbaum marked this conversation as resolved.
Show resolved Hide resolved
)
super().__init__()
self._log_model = log_model
self.live = live if live is not None else Live(**kwargs)
Expand Down
78 changes: 54 additions & 24 deletions tests/frameworks/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
Trainer,
TrainingArguments,
)
from transformers.integrations import DVCLiveCallback as ExternalCallback

from dvclive.huggingface import DVCLiveCallback
from dvclive.huggingface import DVCLiveCallback as InternalCallback
Comment on lines +20 to +22
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test both the old callback (inside dvclive) and the new callback (inside transformers).

except ImportError:
pytest.skip("skipping huggingface tests", allow_module_level=True)

Expand Down Expand Up @@ -101,21 +102,21 @@ def args():
evaluation_strategy="epoch",
num_train_epochs=2,
save_strategy="epoch",
report_to="none",
report_to="none", # Disable auto-reporting to avoid duplication
)


def test_huggingface_integration(tmp_dir, model, args, data, mocker):
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback])
def test_huggingface_integration(tmp_dir, model, args, data, mocker, callback):
trainer = Trainer(
model,
args,
train_dataset=data[0],
eval_dataset=data[1],
compute_metrics=compute_metrics,
)
callback = DVCLiveCallback()
live = callback.live
Copy link
Collaborator Author

@dberenbaum dberenbaum Feb 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`callback.live` is not initialized yet at this point in the external callback, so the spy is attached to `Live.end` on the class instead of on the instance.

spy = mocker.spy(live, "end")
callback = callback()
spy = mocker.spy(Live, "end")
trainer.add_callback(callback)
trainer.train()
spy.assert_called_once()
Expand All @@ -125,8 +126,6 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker):

logs, _ = parse_metrics(live)

assert len(logs) == 10
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not seem like a good test to have in dvclive since it relies on how many metrics hf automatically logs, including system metrics over which dvclive has no control.


scalars = os.path.join(live.plots_dir, Metric.subfolder)
assert os.path.join(scalars, "eval", "foo.tsv") in logs
assert os.path.join(scalars, "eval", "loss.tsv") in logs
Expand All @@ -138,19 +137,31 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker):
assert params["num_train_epochs"] == 2


@pytest.mark.parametrize("log_model", ["all", True, None])
@pytest.mark.parametrize("log_model", ["all", True, False])
@pytest.mark.parametrize("best", [True, False])
def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best):
live_callback = DVCLiveCallback(log_model=log_model)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback])
def test_huggingface_log_model(
tmp_dir,
mocked_dvc_repo,
model,
data,
args,
monkeypatch,
mocker,
log_model,
best,
callback,
):
live = Live()
log_artifact = mocker.patch.object(live, "log_artifact")
if callback == ExternalCallback:
monkeypatch.setenv("HF_DVCLIVE_LOG_MODEL", str(log_model))
live_callback = callback(live=live)
else:
live_callback = callback(live=live, log_model=log_model)

args.load_best_model_at_end = best

args = TrainingArguments(
"foo",
evaluation_strategy="epoch",
num_train_epochs=2,
save_strategy="epoch",
load_best_model_at_end=best,
)
trainer = Trainer(
model,
args,
Expand All @@ -164,11 +175,11 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best):
expected_call_count = {
"all": 2,
True: 1,
None: 0,
False: 0,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

False is also a valid value for log_model

}
assert log_artifact.call_count == expected_call_count[log_model]

if log_model == "last":
if log_model is True:
Comment on lines -171 to +183
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition was never met, since the expected value is `True`, not `"last"`.

name = "best" if best else "last"
log_artifact.assert_called_with(
os.path.join(args.output_dir, name),
Expand All @@ -178,8 +189,27 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best):
)


def test_huggingface_pass_logger():
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback])
def test_huggingface_pass_logger(callback):
logger = Live("train_logs")

assert DVCLiveCallback().live is not logger
assert DVCLiveCallback(live=logger).live is logger
assert callback().live is not logger
assert callback(live=logger).live is logger


@pytest.mark.parametrize("report_to", ["all", "dvclive", "none"])
def test_huggingface_report_to(model, report_to):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there are no real tests for the callbacks in transformers, I added a test here to ensure that `report_to` adds the DVCLive callback when expected.

args = TrainingArguments("foo", report_to=report_to)
trainer = Trainer(
model,
args,
)
live_cbs = [
cb
for cb in trainer.callback_handler.callbacks
if isinstance(cb, ExternalCallback)
]
if report_to == "none":
assert not any(live_cbs)
else:
assert any(live_cbs)
Loading