llama2 7b model #89

Open
ha-seungwon opened this issue Nov 23, 2024 · 4 comments

Comments

@ha-seungwon

Hello,

Thank you for your interesting project.

Can I use OnnxStream to run the Llama 2 7B fp16 model?
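For scale, here is a rough sketch of how much memory the fp16 weights of Llama 2 7B need. It assumes the standard published Llama-2-7b-hf configuration (vocabulary 32000, hidden size 4096, MLP size 11008, 32 layers, 32 heads, untied input/output embeddings); the numbers are hard-coded, not read from any file. It also computes the fp16 KV-cache size per token, which corresponds to the `(1, 32, past_seq_len, 128)` past_key_values tensors used later in the thread.

```python
# Rough size estimate for Llama 2 7B in fp16.
# Assumed config (standard Llama-2-7b-hf values), hard-coded here.
VOCAB = 32000
HIDDEN = 4096
INTERMEDIATE = 11008
LAYERS = 32
HEADS = 32
HEAD_DIM = HIDDEN // HEADS  # 128
BYTES_FP16 = 2


def llama2_7b_param_count() -> int:
    embed = VOCAB * HIDDEN           # token embedding table
    lm_head = VOCAB * HIDDEN         # output projection (not tied to the embedding)
    attn = 4 * HIDDEN * HIDDEN       # q, k, v, o projections, no biases
    mlp = 3 * HIDDEN * INTERMEDIATE  # gate, up and down projections
    norms = 2 * HIDDEN               # two RMSNorm weight vectors per layer
    per_layer = attn + mlp + norms
    final_norm = HIDDEN
    return embed + lm_head + LAYERS * per_layer + final_norm


def kv_cache_bytes_per_token() -> int:
    # One key and one value vector of HEADS * HEAD_DIM elements per layer.
    return 2 * LAYERS * HEADS * HEAD_DIM * BYTES_FP16


if __name__ == "__main__":
    params = llama2_7b_param_count()
    print(f"parameters: {params:,}")
    print(f"fp16 weights: {params * BYTES_FP16 / 2**30:.2f} GiB")
    print(f"KV cache per token: {kv_cache_bytes_per_token() / 2**20:.2f} MiB")
```

This reports 6,738,415,616 parameters, about 12.55 GiB of fp16 weights and 0.50 MiB of KV cache per token. Because a single protobuf message is limited to 2 GB, an ONNX export of this size has to keep its weights as external data, which is presumably why the large-model variant of onnxsim is used later in the thread.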

@vitoplantamura
Owner

vitoplantamura commented Nov 24, 2024 via email

@ha-seungwon
Author

ha-seungwon commented Nov 25, 2024

Hello,

So isn't it possible to adapt it to a different LLM?

Thanks

@vitoplantamura
Owner

vitoplantamura commented Nov 26, 2024 via email

@ha-seungwon
Author

ha-seungwon commented Nov 27, 2024

Hello,

I already tried, but I get an error. Please help me.

import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the Llama 2 model
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Wrapper model for the ONNX export: past_key_values arrive as a flat list
# (key0, value0, key1, value1, ...) and are regrouped into per-layer pairs.
class LlamaModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, position_ids, *past_key_values):
        past_key_values = tuple(
            (past_key_values[i], past_key_values[i + 1]) for i in range(0, len(past_key_values), 2)
        )
        outputs = self.model(
            use_cache=True,
            return_dict=True,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        pkv = outputs.past_key_values
        # Return the logits followed by every updated past_key_values tensor
        return [outputs.logits] + [item for pair in pkv for item in pair]

PAST_LEN = 4  # length of the dummy KV cache

# Build the dummy inputs
with torch.no_grad():
    dummy_input = (
        torch.tensor([[1, 2, 3]], dtype=torch.int64),  # input_ids: 3 new tokens
        # attention_mask covers cached + new tokens: PAST_LEN + 3 = 7
        torch.ones(1, PAST_LEN + 3, dtype=torch.int64),
        # position_ids continue after the cached tokens: [[4, 5, 6]]
        torch.arange(PAST_LEN, PAST_LEN + 3, dtype=torch.int64).unsqueeze(0),
    )

    # past_key_values for 32 layers (batch_size=1, num_heads=32, past_seq_len=4, head_dim=128)
    for _ in range(32):
        dummy_input += (torch.randn(1, 32, PAST_LEN, 128, dtype=torch.float16),)  # key
        dummy_input += (torch.randn(1, 32, PAST_LEN, 128, dtype=torch.float16),)  # value

    # Input and output names: 32 layers * 2 (key, value) = 64 cache tensors
    input_names = ["input_ids", "attention_mask", "position_ids"] + [f"pkv{i}" for i in range(64)]
    output_names = ["logits"] + [f"opkv{i}" for i in range(64)]

    # Make sure the output directory exists before exporting
    os.makedirs("./onnx_export_model", exist_ok=True)

    # ONNX export
    torch.onnx.export(
        LlamaModel(model),
        dummy_input,
        "./onnx_export_model/model.onnx",
        verbose=False,
        input_names=input_names,
        output_names=output_names,
        opset_version=14,
        do_constant_folding=True,
        export_params=True,
        dynamic_axes={
            "input_ids": {1: "sequence"},
            "attention_mask": {1: "total_sequence"},  # past_seq_len + sequence
            "position_ids": {1: "sequence"},
            **{f"pkv{i}": {2: "past_seq_len"} for i in range(64)},
        },
    )
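When a KV cache is passed in, Hugging Face Llama models expect `attention_mask` to span both the cached and the new tokens, and `position_ids` to continue from the cache length. Below is a small framework-free sketch of the shapes the dummy inputs should have. `export_input_shapes` and `position_ids_for` are hypothetical helpers, and the layer/head defaults assume Llama 2 7B (which uses 32 key/value heads, so no grouped-query attention).

```python
from typing import Dict, List, Tuple


def export_input_shapes(past_len: int, seq_len: int, num_layers: int = 32,
                        num_heads: int = 32, head_dim: int = 128,
                        batch: int = 1) -> Dict[str, Tuple[int, ...]]:
    """Expected shape of every export input (hypothetical helper)."""
    shapes: Dict[str, Tuple[int, ...]] = {
        "input_ids": (batch, seq_len),                  # only the new tokens
        "attention_mask": (batch, past_len + seq_len),  # cached + new tokens
        "position_ids": (batch, seq_len),
    }
    # pkv0..pkv63 alternate key, value, key, value, ... one pair per layer
    for i in range(2 * num_layers):
        shapes[f"pkv{i}"] = (batch, num_heads, past_len, head_dim)
    return shapes


def position_ids_for(past_len: int, seq_len: int) -> List[int]:
    # New tokens are numbered right after the cached ones.
    return list(range(past_len, past_len + seq_len))


if __name__ == "__main__":
    shapes = export_input_shapes(past_len=4, seq_len=3)
    print(shapes["attention_mask"])  # (1, 7)
    print(position_ids_for(4, 3))    # [4, 5, 6]
    print(len(shapes))               # 67 = 3 + 64
```

Giving `attention_mask` its own dynamic axis name matters for the same reason: its length is `past_seq_len + sequence`, not `sequence`.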

After exporting the model, I ran it through "onnxsim_large_model" and then "onnx2txt".

Gather -> 68
Shape -> 37
Add -> 227
Range -> 1
Unsqueeze -> 41
Slice -> 162
Cast -> 136
Equal -> 3
And -> 1
Where -> 2
Expand -> 5
Concat -> 130
Reshape -> 129
ScatterND -> 1
Pow -> 65
ReduceMean -> 65
Sqrt -> 65
Div -> 65
Mul -> 386
MatMul -> 290
Transpose -> 161
Cos -> 1
Sin -> 1
Neg -> 64
Softmax -> 32
Sigmoid -> 32
TOTAL -> 2170

The operator counts above are the output of my onnx2txt run.
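A per-operator listing like the one above can be produced with a few lines of Python. This is only a sketch, not the actual onnx2txt implementation, and `op_summary` is a hypothetical name. As a plausibility check, several counts appear to line up with the 32 decoder layers of Llama 2 7B: Softmax 32 (one per attention block), Sigmoid 32 (SiLU in each MLP, exported as x * Sigmoid(x)), Neg 64 (the rotary embedding applied to queries and keys in each layer), and Pow/ReduceMean/Sqrt/Div 65 each (two RMSNorms per layer plus the final norm).

```python
from collections import Counter
from typing import Iterable


def op_summary(op_types: Iterable[str]) -> str:
    """Format operator counts as 'Op -> count' lines plus a TOTAL line."""
    counts = Counter(op_types)  # keeps operators in order of first appearance
    lines = [f"{op} -> {n}" for op, n in counts.items()]
    lines.append(f"TOTAL -> {sum(counts.values())}")
    return "\n".join(lines)


# With the onnx package installed, the op types of a real model can be fed in
# without loading the multi-GB external weights, e.g.:
#   import onnx
#   m = onnx.load("./onnx_export_model/model.onnx", load_external_data=False)
#   print(op_summary(node.op_type for node in m.graph.node))

if __name__ == "__main__":
    print(op_summary(["MatMul", "Add", "MatMul", "Softmax"]))
```

Only top-level graph nodes are counted here; a model containing subgraphs (If/Loop bodies) would need a recursive walk.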

My error is shown in the two attached screenshots.

How can I fix it?
