
qLoRA support #398

Closed
ehartford opened this issue Jul 30, 2023 · 5 comments

Comments

@ehartford

I tried adding FlashAttention to qLoRA and got the following error:

RuntimeError: FlashAttention only support fp16 and bf16 data type

Is it possible to add support for 4-bit qLoRA?

@ehartford
Author

The related issue in the qLoRA repo:

artidoro/qlora#221

@ehartford
Author

My full stack trace:

Traceback (most recent call last):
  File "/home/eric/git/qlora/qlora.py", line 846, in <module>
    train()
  File "/home/eric/git/qlora/qlora.py", line 808, in train
    train_result = trainer.train()
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/peft/peft_model.py", line 922, in forward
    return self.base_model(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 685, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 681, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/git/qlora/patch_flash_attn.py", line 87, in forward
    output_unpad = flash_attn_varlen_qkvpacked_func(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 408, in flash_attn_varlen_qkvpacked_func
    return FlashAttnVarlenQKVPackedFunc.apply(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 123, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

@tridao
Contributor

tridao commented Jul 31, 2023

We're not planning to write a custom kernel for 4-bit. Can you just cast the inputs (q, k, v) to fp16/bf16, call FlashAttention, then convert the output back to whichever dtype you need?

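For illustration, a minimal sketch of that cast-then-call pattern (the wrapper name and tensor names are assumptions modelled on the Llama attention code in transformers; flash_attn_func is the FlashAttention-2 interface and expects tensors shaped (batch, seqlen, nheads, headdim)):

import torch
from flash_attn import flash_attn_func

def attention_bf16_roundtrip(query_states, key_states, value_states, causal=True):
    # Remember the dtype the surrounding (quantized/LoRA) model is running in.
    orig_dtype = query_states.dtype
    # FlashAttention kernels only accept fp16/bf16, so cast just for this call.
    q = query_states.to(torch.bfloat16)
    k = key_states.to(torch.bfloat16)
    v = value_states.to(torch.bfloat16)
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)
    # Cast back so downstream layers see the dtype they expect.
    return out.to(orig_dtype)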
ehartford closed this as not planned on Jul 31, 2023
@ehartford
Author

Understood, thank you.

@zlh1992

zlh1992 commented Aug 5, 2023

> Understood, thank you.

Hi, did you check the results of FlashAttention 2 with 4-bit qLoRA?

I use the flash_attn_func function to compute GQA attention and cast the inputs (q, k, v) to bf16. During training the loss looks healthy, but the saved qLoRA adapter file appears to be broken: when I load the adapter, the generated output is wrong.

# Cast inputs to bf16 for FlashAttention, then cast the output back afterwards
query_states = query_states.to(torch.bfloat16)
key_states = key_states.to(torch.bfloat16)
value_states = value_states.to(torch.bfloat16)
output = flash_attn_func(query_states, key_states, value_states, dropout_p=0.0, causal=True)
output = output.float()
output = output.contiguous().view(bsz, q_len, -1)

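One thing worth double-checking, following the earlier advice to convert the output back "to whichever dtype": a hedged variant that casts the output to the dtype the inputs arrived in rather than always to fp32 (input_dtype, bsz and q_len follow the Llama attention code; whether this explains the broken adapter is not confirmed in this thread):

input_dtype = query_states.dtype  # dtype before the bf16 cast
q = query_states.to(torch.bfloat16)
k = key_states.to(torch.bfloat16)
v = value_states.to(torch.bfloat16)
output = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
# Cast back to the original dtype instead of unconditionally calling .float()
output = output.to(input_dtype).contiguous().view(bsz, q_len, -1)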