fix: flash attn
0xDing committed Sep 25, 2023
1 parent 0e3cbfd commit 369fb41
Showing 1 changed file with 5 additions and 4 deletions.
apps/webui/src/webui/utils.py: 5 additions & 4 deletions

@@ -32,7 +32,8 @@
 from webui.constants import ALREADY_CONVERTED_MARK, ChatbotValue, Conversation
 from yuren_core.constants import IM_END_TOKEN, IM_START_TOKEN
 from yuren_core.errors import MaxTokenLengthError
-from yuren_core.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+
+# from yuren_core.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
 
 logging.basicConfig(
     level=logging.INFO,
@@ -248,14 +249,14 @@ def sample_decode(
     im_end_token_id = tokenizer.convert_tokens_to_ids([IM_END_TOKEN])[0]
     for i in range(max_new_tokens):
         if i == 0:
-            outputs = model(torch.as_tensor(input_ids, device=device), use_cache=False)
+            outputs = model(torch.as_tensor(input_ids, device=device), use_cache=True)
             logits = outputs.logits
             past_key_values = outputs.past_key_values
         else:
             attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
             outputs = model(
                 input_ids=torch.as_tensor([[token]], device=device),  # noqa: F821
-                use_cache=False,
+                use_cache=True,
                 attention_mask=attention_mask,
                 past_key_values=past_key_values,
             )
@@ -353,7 +354,7 @@ def load_tokenizer_and_model(base_model, load_8bit=False):
         pass
     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
     if device == "cuda":
-        replace_llama_attn_with_flash_attn()
+        # replace_llama_attn_with_flash_attn()
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=load_8bit,
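
The substance of the change is switching sample_decode() to use_cache=True, so every step after the first reuses the model's key/value cache instead of recomputing attention over the whole prefix, while the LLaMA flash-attention monkey patch is disabled. Below is a minimal sketch of that cached decode loop, assuming a generic Hugging Face causal LM; the checkpoint name and greedy argmax decoding are placeholders, not the project's actual settings.

    # Minimal sketch (not the repository's code) of the cached decode loop that
    # use_cache=True enables in sample_decode(). The checkpoint is a placeholder
    # and greedy argmax stands in for the repo's sampling logic.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = "some-llama-checkpoint"  # placeholder, not the project's model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device).eval()

    input_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(device)
    past_key_values = None
    generated = []

    with torch.no_grad():
        for _ in range(32):
            if past_key_values is None:
                # First step: run the full prompt once and keep the key/value cache.
                outputs = model(input_ids, use_cache=True)
            else:
                # Later steps: feed only the newest token; the mask must cover the
                # cached positions plus the new one (cache length + 1).
                attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
                outputs = model(
                    input_ids=torch.as_tensor([[token]], device=device),
                    use_cache=True,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                )
            past_key_values = outputs.past_key_values
            token = int(torch.argmax(outputs.logits[0, -1, :]))
            generated.append(token)
            if token == tokenizer.eos_token_id:
                break

    print(tokenizer.decode(generated))

With the cache in place only the newest token is fed to the model at each step, which is why the attention mask is sized to the cached sequence length plus one, exactly as in the patched sample_decode().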
