
RuntimeError: Tensor on device cuda:0 is not on the expected device meta! #737

LucasMagnana opened this issue on Mar 3, 2025

🐛 Bug

Hello, I am trying to train LLMs on a language modelling task with differential privacy using opacus. My code works with gpt2, but it throws RuntimeError: Tensor on device cuda:0 is not on the expected device meta! when using bert-base-cased.

To Reproduce

The code I use is the following; the model is an AutoModelForLanguageModelling from the transformers library:

# Imports used by the method (module-level in the original file)
import numpy as np
import torch
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from torch.utils.data import DataLoader
from tqdm import tqdm


def train(self, model, lr, train_dataset, eval_dataset, num_epochs):
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=self.config.train_batch_size,
        collate_fn=self.data_collator,
    )
    model = model.to(self.device)
    # Set the model to train mode (HuggingFace models load in eval mode)
    model = model.train()
    # Define optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    DELTA = 1 / len(train_dataloader)

    privacy_engine = PrivacyEngine()

    model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dataloader,
        target_delta=DELTA,
        target_epsilon=7.5,
        epochs=num_epochs,
        max_grad_norm=0.1,
    )

    for epoch in range(1, num_epochs + 1):
        losses = []

        with BatchMemoryManager(
            data_loader=train_dataloader,
            max_physical_batch_size=4,
            optimizer=optimizer,
        ) as memory_safe_data_loader:
            for step, batch in enumerate(tqdm(memory_safe_data_loader)):
                optimizer.zero_grad()

                inputs = {k: batch[k].to(self.device) for k in batch if k != "labels"}

                outputs = model(**inputs)  # output = loss, logits, hidden_states, attentions

                loss = outputs[0].mean()
                loss.backward()
                losses.append(loss.item())

                optimizer.step()

                if step > 0 and step % 5000 == 0:
                    train_loss = np.mean(losses)
                    eps = privacy_engine.get_epsilon(DELTA)

                    print(
                        f"Epoch: {epoch} | "
                        f"Step: {step} | "
                        f"Train loss: {train_loss:.3f} | "
                        f"ɛ: {eps:.2f}"
                    )
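For completeness, the model and data collator are created along these lines before train() is called (a simplified sketch rather than my exact code: the masked-LM head, tokenizer settings and wiring shown here are placeholders):

# Simplified setup sketch -- class choices and settings are illustrative.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")

# Collator that masks tokens for the language-modelling objective.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# In train(), self.config, self.data_collator and self.device hold these objects;
# the datasets passed in are already tokenized.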
The full error:

Traceback (most recent call last):
  File "/home/lmagnana/nlp-attacks/examples/special_finetunings/n2c2_ner_mlm_finetuning.py", line 72, in <module>
    models, metrics = finetuner.run(dataset, test_size, epochs, pathlib.Path(output_dir), output_name=ouput_name)
  File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/Finetuner.py", line 248, in run
    model = self.train(model, self.config.learning_rate, ds["train"], ds["test"], epochs) 
  File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/PrivacyPreservingLanguageModelling.py", line 141, in train
    loss.backward()
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 72, in __call__
    return self.hook(module, *args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/grad_sample_module.py", line 340, in capture_backprops_hook
    grad_samples = grad_sampler_fn(module, activations, backprops)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 108, in ft_compute_per_sample_gradient
    per_sample_grads = layer.ft_compute_sample_grad(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 363, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1285, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1249, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 85, in compute_loss_stateless_model
    output = flayer(params, batched_activations)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 50, in fmodel
    return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/utils/stateless.py", line 263, in _functional_call
    return module(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 787, in forward
    hidden_states = self.decoder(hidden_states)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 137, in _fn
    result = fn(**bound.arguments)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_refs/__init__.py", line 1091, in add
    output = prims.add(a, b)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_ops.py", line 594, in __call__
    return self_._op(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims/__init__.py", line 359, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/__init__.py", line 740, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!

Expected behavior

The code should work with both a gpt2 and a bert-base-cased model.

Environment

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.5.82
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] torch==2.3.1
[pip3] opacus==1.5.3
[pip3] triton==2.3.1
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.5.82                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
[conda] torch                     2.3.1                    pypi_0    pypi
[conda] triton                    2.3.1                    pypi_0    pypi

Thanks in advance for your replies.

@EnayatUllah (Contributor)

Thank you for your question. Could you please provide a minimal reproducible example using Colab? We provide an example you can follow here: https://colab.research.google.com/drive/1R-dnKipK9LOVV4_oKbuHoq4VvGKGbDnd?usp=sharing.

We can then help with debugging.

@LucasMagnana (Author)

I reproduced the bug in this notebook:
https://colab.research.google.com/drive/1b5ARlULcQapVpYTeEp-d8zpD19aN4-97
