
Remove conditional import of flash_attn #43

Open
b8raoult opened this issue Sep 16, 2024 · 0 comments

b8raoult commented Sep 16, 2024

What happened?

Inference crashes without any meaningful error message (it just exits back to the shell) when training was done with flash_attn installed but inference is run without it, and vice versa.

What are the steps to reproduce the bug?

Two ways to reproduce the issue:

  1. Train a model with flash_attn installed, then run inference in an environment where flash_attn is not installed.
  2. Train the model without flash_attn, then run inference in an environment where flash_attn is installed.

The same mismatch can also occur when a training run is restarted from a checkpoint in a different environment.

Version

all

Platform (OS and architecture)

any

Relevant log output

# Conditional import in question: if flash_attn is not installed, the code
# silently falls back to PyTorch's scaled_dot_product_attention.
try:
    from flash_attn import flash_attn_func as attn_func
except ImportError:
    from torch.nn.functional import scaled_dot_product_attention as attn_func

    _FLASH_ATTENTION_AVAILABLE = False
else:
    _FLASH_ATTENTION_AVAILABLE = True
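
One possible direction for removing the conditional import (a minimal sketch, not the project's actual API; the use_flash_attention flag and get_attention_function helper below are hypothetical) is to make the attention backend an explicit, recorded setting and raise a clear error when the requested backend is not available, instead of silently switching implementations:

def get_attention_function(use_flash_attention: bool):
    """Return the attention kernel explicitly requested by the configuration.

    `use_flash_attention` is a hypothetical flag stored with the model
    configuration/checkpoint, so the choice travels with the model instead of
    being inferred from whatever happens to be importable.
    """
    if use_flash_attention:
        try:
            from flash_attn import flash_attn_func
        except ImportError as exc:
            # Fail loudly with an actionable message instead of silently
            # falling back to a different attention implementation.
            raise RuntimeError(
                "The model was configured to use flash_attn, but flash_attn "
                "is not installed in the current environment."
            ) from exc
        return flash_attn_func

    from torch.nn.functional import scaled_dot_product_attention

    return scaled_dot_product_attention

With the choice recorded in the configuration, a checkpoint trained with one backend either loads with the same backend or fails with an explicit error, rather than exiting without a message.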

Accompanying data

No response

Organisation

No response
