-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
hf: warn of deprecating internal callback #740
Changes from 24 commits
8ecc790
4ef660e
c386e89
42cd1ab
cb89b98
e2c23d7
99aff66
801178e
fa244f9
8e9afeb
d52461b
b0c3c5c
f7b4efb
6bef79e
a9f01b0
88931f0
be1cfde
ec73bf7
e6f1845
27e38cb
8c25642
153f72b
4043b8d
8f2f91b
63c40f5
1977009
7e9d1f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,9 @@ | |
Trainer, | ||
TrainingArguments, | ||
) | ||
from transformers.integrations import DVCLiveCallback as ExternalCallback | ||
|
||
from dvclive.huggingface import DVCLiveCallback | ||
from dvclive.huggingface import DVCLiveCallback as InternalCallback | ||
Comment on lines
+20
to
+22
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test old callback (inside dvclive) and new callback (inside transformers) |
||
except ImportError: | ||
pytest.skip("skipping huggingface tests", allow_module_level=True) | ||
|
||
|
@@ -101,21 +102,21 @@ def args(): | |
evaluation_strategy="epoch", | ||
num_train_epochs=2, | ||
save_strategy="epoch", | ||
report_to="none", | ||
report_to="none", # Disable auto-reporting to avoid duplication | ||
) | ||
|
||
|
||
def test_huggingface_integration(tmp_dir, model, args, data, mocker): | ||
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) | ||
def test_huggingface_integration(tmp_dir, model, args, data, mocker, callback): | ||
trainer = Trainer( | ||
model, | ||
args, | ||
train_dataset=data[0], | ||
eval_dataset=data[1], | ||
compute_metrics=compute_metrics, | ||
) | ||
callback = DVCLiveCallback() | ||
live = callback.live | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
spy = mocker.spy(live, "end") | ||
callback = callback() | ||
spy = mocker.spy(Live, "end") | ||
trainer.add_callback(callback) | ||
trainer.train() | ||
spy.assert_called_once() | ||
|
@@ -125,8 +126,6 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): | |
|
||
logs, _ = parse_metrics(live) | ||
|
||
assert len(logs) == 10 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does not seem like a good test to have in dvclive since it relies on how many metrics hf automatically logs, including system metrics over which dvclive has no control. |
||
|
||
scalars = os.path.join(live.plots_dir, Metric.subfolder) | ||
assert os.path.join(scalars, "eval", "foo.tsv") in logs | ||
assert os.path.join(scalars, "eval", "loss.tsv") in logs | ||
|
@@ -138,19 +137,31 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): | |
assert params["num_train_epochs"] == 2 | ||
|
||
|
||
@pytest.mark.parametrize("log_model", ["all", True, None]) | ||
@pytest.mark.parametrize("log_model", ["all", True, False]) | ||
@pytest.mark.parametrize("best", [True, False]) | ||
def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best): | ||
live_callback = DVCLiveCallback(log_model=log_model) | ||
log_artifact = mocker.patch.object(live_callback.live, "log_artifact") | ||
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) | ||
def test_huggingface_log_model( | ||
tmp_dir, | ||
mocked_dvc_repo, | ||
model, | ||
data, | ||
args, | ||
monkeypatch, | ||
mocker, | ||
log_model, | ||
best, | ||
callback, | ||
): | ||
live = Live() | ||
log_artifact = mocker.patch.object(live, "log_artifact") | ||
if callback == ExternalCallback: | ||
monkeypatch.setenv("HF_DVCLIVE_LOG_MODEL", str(log_model)) | ||
live_callback = callback(live=live) | ||
else: | ||
live_callback = callback(live=live, log_model=log_model) | ||
|
||
args.load_best_model_at_end = best | ||
|
||
args = TrainingArguments( | ||
"foo", | ||
evaluation_strategy="epoch", | ||
num_train_epochs=2, | ||
save_strategy="epoch", | ||
load_best_model_at_end=best, | ||
) | ||
trainer = Trainer( | ||
model, | ||
args, | ||
|
@@ -164,11 +175,11 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best): | |
expected_call_count = { | ||
"all": 2, | ||
True: 1, | ||
None: 0, | ||
False: 0, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} | ||
assert log_artifact.call_count == expected_call_count[log_model] | ||
|
||
if log_model == "last": | ||
if log_model is True: | ||
Comment on lines
-171
to
+183
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This condition was never met since the expected value is |
||
name = "best" if best else "last" | ||
log_artifact.assert_called_with( | ||
os.path.join(args.output_dir, name), | ||
|
@@ -178,8 +189,27 @@ def test_huggingface_log_model(tmp_dir, model, data, mocker, log_model, best): | |
) | ||
|
||
|
||
def test_huggingface_pass_logger(): | ||
@pytest.mark.parametrize("callback", [ExternalCallback, InternalCallback]) | ||
def test_huggingface_pass_logger(callback): | ||
logger = Live("train_logs") | ||
|
||
assert DVCLiveCallback().live is not logger | ||
assert DVCLiveCallback(live=logger).live is logger | ||
assert callback().live is not logger | ||
assert callback(live=logger).live is logger | ||
|
||
|
||
@pytest.mark.parametrize("report_to", ["all", "dvclive", "none"]) | ||
def test_huggingface_report_to(model, report_to): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since there are no real tests for the callbacks in |
||
args = TrainingArguments("foo", report_to=report_to) | ||
trainer = Trainer( | ||
model, | ||
args, | ||
) | ||
live_cbs = [ | ||
cb | ||
for cb in trainer.callback_handler.callbacks | ||
if isinstance(cb, ExternalCallback) | ||
] | ||
if report_to == "none": | ||
assert not any(live_cbs) | ||
else: | ||
assert any(live_cbs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Q] Do you want to deprecate it now and remove it in 4.0?