🐛 Bug
Hello, I am trying to train LLMs on a language modelling task with differential privacy using Opacus. My code works with gpt2, but with bert-base-cased it throws: RuntimeError: Tensor on device cuda:0 is not on the expected device meta!
To Reproduce
The code I use is the following; the model is an AutoModelForLanguageModelling from the transformers library:
# Imports used by the training method below
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

def train(self, model, lr, train_dataset, eval_dataset, num_epochs):
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=self.config.train_batch_size,
        collate_fn=self.data_collator,
    )
    model = model.to(self.device)
    # Set the model to train mode (HuggingFace models load in eval mode)
    model = model.train()
    # Define optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    DELTA = 1 / len(train_dataloader)
    privacy_engine = PrivacyEngine()
    model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dataloader,
        target_delta=DELTA,
        target_epsilon=7.5,
        epochs=num_epochs,
        max_grad_norm=0.1,
    )
    for epoch in range(1, num_epochs + 1):
        losses = []
        with BatchMemoryManager(
            data_loader=train_dataloader,
            max_physical_batch_size=4,
            optimizer=optimizer,
        ) as memory_safe_data_loader:
            for step, batch in enumerate(tqdm(memory_safe_data_loader)):
                optimizer.zero_grad()
                inputs = {k: batch[k].to(self.device) for k in batch if k != "labels"}
                outputs = model(**inputs)  # output = loss, logits, hidden_states, attentions
                loss = outputs[0].mean()
                loss.backward()
                losses.append(loss.item())
                optimizer.step()
                if step > 0 and step % 5000 == 0:
                    train_loss = np.mean(losses)
                    eps = privacy_engine.get_epsilon(DELTA)
                    print(
                        f"Epoch: {epoch} | "
                        f"Step: {step} | "
                        f"Train loss: {train_loss:.3f} | "
                        f"ɛ: {eps:.2f}"
                    )
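In case it makes reproduction easier, here is a self-contained sketch of the same setup outside my codebase. The dummy sentences and the use of AutoModelForMaskedLM are stand-ins on my part (my real run goes through the Finetuner class above); the Opacus calls mirror the training function, and the backward pass should hit the same failure with bert-base-cased.

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
from opacus import PrivacyEngine

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-cased").to(device).train()

# Tiny dummy dataset, only so there is something to collate
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)
encodings = [tokenizer(text, truncation=True, max_length=32) for text in ["hello world"] * 8]
train_dataloader = DataLoader(encodings, batch_size=4, shuffle=True, collate_fn=collator)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
privacy_engine = PrivacyEngine()
model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    target_delta=1 / len(train_dataloader),
    target_epsilon=7.5,
    epochs=1,
    max_grad_norm=0.1,
)

for batch in train_dataloader:
    if batch["input_ids"].numel() == 0:  # Opacus' Poisson sampling can yield empty batches
        continue
    inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
    loss = model(**inputs)[0].mean()
    loss.backward()  # this is where the RuntimeError is raised
    break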
The full error:
Traceback (most recent call last):
File "/home/lmagnana/nlp-attacks/examples/special_finetunings/n2c2_ner_mlm_finetuning.py", line 72, in <module>
models, metrics = finetuner.run(dataset, test_size, epochs, pathlib.Path(output_dir), output_name=ouput_name)
File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/Finetuner.py", line 248, in run
model = self.train(model, self.config.learning_rate, ds["train"], ds["test"], epochs)
File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/PrivacyPreservingLanguageModelling.py", line 141, in train
loss.backward()
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
torch.autograd.backward(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
_engine_run_backward(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 72, in __call__
return self.hook(module, *args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/grad_sample_module.py", line 340, in capture_backprops_hook
grad_samples = grad_sampler_fn(module, activations, backprops)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 108, in ft_compute_per_sample_gradient
per_sample_grads = layer.ft_compute_sample_grad(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 188, in wrapped
return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
return _flat_vmap(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
return f(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 363, in wrapper
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1285, in grad_impl
results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
return f(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1249, in grad_and_value_impl
output = func(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 85, in compute_loss_stateless_model
output = flayer(params, batched_activations)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 50, in fmodel
return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
return nn.utils.stateless._functional_call(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/utils/stateless.py", line 263, in _functional_call
return module(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 787, in forward
hidden_states = self.decoder(hidden_states)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
result = fn(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 137, in _fn
result = fn(**bound.arguments)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_refs/__init__.py", line 1091, in add
output = prims.add(a, b)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_ops.py", line 594, in __call__
return self_._op(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims/__init__.py", line 359, in _prim_elementwise_meta
utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/__init__.py", line 740, in check_same_device
raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!
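Possibly relevant context for triage: the frame the traceback ends in (modeling_bert.py, hidden_states = self.decoder(hidden_states)) is BertLMPredictionHead, whose decoder weight is tied to the input embedding matrix and whose bias is a separate shared parameter. A quick check, assuming the model is loaded like AutoModelForMaskedLM:

from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
decoder = model.cls.predictions.decoder                   # the nn.Linear from the traceback
word_embeddings = model.bert.embeddings.word_embeddings  # the input nn.Embedding
print(decoder.weight is word_embeddings.weight)          # True: tied parameters
print(decoder.bias is model.cls.predictions.bias)        # True: the head's own bias parameter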
Expected behavior
The code should work with both a gpt2 and a bert-base-cased model.
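If it helps with triage, this is the compatibility check I would run on the model before wrapping it, sketched with Opacus's ModuleValidator (I have not pasted its output here):

from opacus.validators import ModuleValidator
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
errors = ModuleValidator.validate(model, strict=False)  # list of unsupported-module errors, if any
print(errors)
# ModuleValidator.fix(model) can be used to replace unsupported layers (e.g. BatchNorm)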
Environment
Thanks in advance for your replies.