Skip to content

Commit bef1b1a

Browse files
committed
tests: revert change of require_torch_multi_gpu to be device agnostic
Commit 11c27dd modified `require_torch_multi_gpu()` to be device agnostic instead of being CUDA specific. This broke some tests which are rightfully CUDA specific, such as `tests/trainer/test_trainer_distributed.py::TestTrainerDistributed`. In the current Transformers test architecture, `require_torch_multi_accelerator()` should be used to mark multi-GPU tests that are agnostic to device. This change addresses the issue introduced by 11c27dd and reverts the modification of `require_torch_multi_gpu()`. Fixes: 11c27dd ("Enable BNB multi-backend support (#31098)") Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 99e0ab6 commit bef1b1a

File tree

3 files changed

+9
-20
lines changed

3 files changed

+9
-20
lines changed

src/transformers/testing_utils.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,6 @@ def parse_int_from_env(key, default=None):
237237
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
238238

239239

240-
def get_device_count():
241-
import torch
242-
243-
if is_torch_xpu_available():
244-
num_devices = torch.xpu.device_count()
245-
else:
246-
num_devices = torch.cuda.device_count()
247-
248-
return num_devices
249-
250-
251240
def is_pt_tf_cross_test(test_case):
252241
"""
253242
Decorator marking a test as a test that control interactions between PyTorch and TensorFlow.
@@ -770,17 +759,17 @@ def require_spacy(test_case):
770759

771760
def require_torch_multi_gpu(test_case):
772761
"""
773-
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
774-
multiple GPUs.
762+
Decorator marking a test that requires a multi-GPU CUDA setup (in PyTorch). These tests are skipped on a machine without
763+
multiple CUDA GPUs.
775764
776765
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
777766
"""
778767
if not is_torch_available():
779768
return unittest.skip(reason="test requires PyTorch")(test_case)
780769

781-
device_count = get_device_count()
770+
import torch
782771

783-
return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case)
772+
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple CUDA GPUs")(test_case)
784773

785774

786775
def require_torch_multi_accelerator(test_case):

tests/quantization/bnb/test_4bit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
require_bitsandbytes,
3939
require_torch,
4040
require_torch_gpu_if_bnb_not_multi_backend_enabled,
41-
require_torch_multi_gpu,
41+
require_torch_multi_accelerator,
4242
slow,
4343
torch_device,
4444
)
@@ -514,7 +514,7 @@ def test_pipeline(self):
514514
self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
515515

516516

517-
@require_torch_multi_gpu
517+
@require_torch_multi_accelerator
518518
@apply_skip_if_not_implemented
519519
class Bnb4bitTestMultiGpu(Base4bitTest):
520520
def setUp(self):

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
require_bitsandbytes,
4040
require_torch,
4141
require_torch_gpu_if_bnb_not_multi_backend_enabled,
42-
require_torch_multi_gpu,
42+
require_torch_multi_accelerator,
4343
slow,
4444
torch_device,
4545
)
@@ -669,7 +669,7 @@ def test_pipeline(self):
669669
self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
670670

671671

672-
@require_torch_multi_gpu
672+
@require_torch_multi_accelerator
673673
@apply_skip_if_not_implemented
674674
class MixedInt8TestMultiGpu(BaseMixedInt8Test):
675675
def setUp(self):
@@ -698,7 +698,7 @@ def test_multi_gpu_loading(self):
698698
self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
699699

700700

701-
@require_torch_multi_gpu
701+
@require_torch_multi_accelerator
702702
@apply_skip_if_not_implemented
703703
class MixedInt8TestCpuGpu(BaseMixedInt8Test):
704704
def setUp(self):

0 commit comments

Comments (0)