
Add support for exporting LLaMA to ONNX format #922

Closed
wants to merge 10 commits

Conversation

@nenkoru (Contributor) commented Mar 24, 2023

This PR adds support for exporting LLaMA models to the ONNX format.

Fixes #918

Before submitting

  • [] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@nenkoru (Contributor, Author) commented Mar 24, 2023

First-time contribution. I need help running CI.
I was unable to run the tests myself because I don't have powerful enough GPUs.

Comment on lines 78 to 79
"opt": "hf-internal-testing/tiny-random-llama",
"llama": "hf-internal-testing/tiny-random-OPTModel",

It seems you inverted OPT and Llama here
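
In other words, the corrected mapping would presumably read as follows (a sketch of the swapped-back test fixture only, not the final committed code; the thread later settles on a different tiny checkpoint):

"opt": "hf-internal-testing/tiny-random-OPTModel",
"llama": "hf-internal-testing/tiny-random-llama",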

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@nenkoru (Contributor, Author) commented Mar 25, 2023

Okay, so the hf-internal-testing model doesn't have any pytorch_model*.bin.
Would it be okay to use this one?
https://huggingface.co/decapoda-research/llama-7b-hf

And for test_pipeline_ort_model, I guess we need to wait until the transformers 4.28 release, because it fails to get a configuration for the model.

I will change tiny-random to llama if that's fine, as well as fix the comment from @regisss.

@nenkoru (Contributor, Author) commented Mar 25, 2023

Okay, I managed to get a powerful machine with lots of RAM and can confirm it works fine.

Tested the 7-billion-parameter model.
It produces some warnings during loading, but the model produces normal outputs.
The process takes more than 100 GB of disk space, and peak RAM usage was 77-80 GB for the whole thing, from loading the model into ORT to exporting it. The exported model takes 51 GB of hard-drive space.

Inference:
"Tell me about Alpacas"
use_cache=False
num_beams=1
max_new_tokens=128
temperature=0.1
top_p=0.75
top_k=40

fp32
AMD EPYC, 32 cores

28-31 s ONNX, no optimizations
vs
305 s plain torch
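
A minimal sketch of how this comparison can be set up, assuming the decapoda-research/llama-7b-hf checkpoint and the generation settings listed above (the actual benchmark script was not shared in this thread, so treat this as illustrative only):

import time

from transformers import AutoModelForCausalLM, LlamaTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "decapoda-research/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_id)
inputs = tokenizer("Tell me about Alpacas", return_tensors="pt")

def time_generate(model):
    # Generation with the settings quoted above.
    start = time.time()
    model.generate(
        **inputs,
        use_cache=False,
        num_beams=1,
        max_new_tokens=128,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
    )
    return time.time() - start

# Export to ONNX on the fly and run through ONNX Runtime (CPU, fp32).
# Note: as described above, the export itself needs a lot of RAM and disk space.
onnx_model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False)
# Plain PyTorch baseline (CPU, fp32).
torch_model = AutoModelForCausalLM.from_pretrained(model_id)

print("ONNX Runtime:", time_generate(onnx_model), "s")
print("PyTorch:", time_generate(torch_model), "s")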

The argument `from_transformers` is deprecated, and will be removed in optimum 2.0. Use `export` instead
Framework not specified. Using pt to export to ONNX.
Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]
Using framework PyTorch: 2.0.0+cu117
Overriding 1 configuration item(s)
	- use_cache -> True
/home/user/miniconda3/envs/jupyter/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:475: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_shape[-1] > 1:
/home/user/miniconda3/envs/jupyter/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:46: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
/home/user/miniconda3/envs/jupyter/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:108: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if seq_len > self.max_seq_len_cached:
/home/user/miniconda3/envs/jupyter/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:231: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
/home/user/miniconda3/envs/jupyter/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:238: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
/home/user/miniconda3/envs/jupyter/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:243: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
/home/user/miniconda3/envs/jupyter/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:249: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Saving external data to one file...
Using framework PyTorch: 2.0.0+cu117
Overriding 1 configuration item(s)
- use_cache -> True
Asked a sequence length of 16, but a sequence length of 1 will be used with use_past == True for input_ids.
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
IOStream.flush timed out
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Saving external data to one file...
Generation config file not found, using a generation config created from the model config.

@nenkoru (Contributor, Author) commented Mar 27, 2023

Small update:
After a lot of digging, I understand that the exported model being larger than the PyTorch model is not expected behaviour.
There could be a lot going on behind the scenes that I lack the ML background to understand. However, common sense led me to this issue, which suggests trying to remove shared weights; I don't know whether this model has any.
On top of that, running the exported model without its ...with_past.onnx counterpart doesn't go well either. Even if use_cache=False is passed to the ...CausalLM initializer, the model loads, but inference fails with an exception (traceback below). With that in mind, I will try merging the decoders to produce a single model file that contains both the with-past and the basic model.

Exception tb
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[4], line 3
      1 device = torch.device("cuda")
      2 t1 = time.time()
----> 3 evall = evaluate("Tell me about Alpacas", use_cache=False)
      4 print(time.time() - t1)

Cell In[3], line 44, in evaluate(instruction, input, temperature, top_p, top_k, num_beams, max_new_tokens, **kwargs)
36 generation_config = GenerationConfig(
37 temperature=temperature,
38 top_p=top_p,
(...)
41 **kwargs,
42 )
43 with torch.no_grad():
---> 44 generation_output = model_onnx.generate(
45 input_ids=input_ids,
46 generation_config=generation_config,
47 max_new_tokens=max_new_tokens,
48 return_dict_in_generate=True,
49 #use_cache=False
50 )
51 #print(generation_output)
52 s = generation_output

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/generation/utils.py:1416, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, **kwargs)
1410 raise ValueError(
1411 f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
1412 " greedy search."
1413 )
1415 # 11. run greedy search
-> 1416 return self.greedy_search(
1417 input_ids,
1418 logits_processor=logits_processor,
1419 stopping_criteria=stopping_criteria,
1420 pad_token_id=generation_config.pad_token_id,
1421 eos_token_id=generation_config.eos_token_id,
1422 output_scores=generation_config.output_scores,
1423 return_dict_in_generate=generation_config.return_dict_in_generate,
1424 synced_gpus=synced_gpus,
1425 **model_kwargs,
1426 )
1428 elif is_contrastive_search_gen_mode:
1429 if generation_config.num_return_sequences > 1:

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/generation/utils.py:2211, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
2208 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2210 # forward pass to get next token
-> 2211 outputs = self(
2212 **model_inputs,
2213 return_dict=True,
2214 output_attentions=output_attentions,
2215 output_hidden_states=output_hidden_states,
2216 )
2218 if synced_gpus and this_peer_finished:
2219 continue # don't waste resources running the code we don't need

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/optimum/modeling_base.py:85, in OptimizedModel.__call__(self, *args, **kwargs)
84 def __call__(self, *args, **kwargs):
---> 85 return self.forward(*args, **kwargs)

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/optimum/onnxruntime/modeling_decoder.py:573, in ORTModelForCausalLM.forward(self, input_ids, attention_mask, past_key_values, labels, **kwargs)
556 @add_start_docstrings_to_model_forward(
557 CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length")
558 + TEXT_GENERATION_EXAMPLE.format(
(...)
570 **kwargs,
571 ) -> CausalLMOutputWithCrossAttentions:
572 if past_key_values is None or self.use_cache is False:
--> 573 outputs = self.decoder(
574 input_ids=input_ids,
575 attention_mask=attention_mask,
576 past_key_values=past_key_values,
577 labels=labels,
578 )
579 elif self.use_merged is True:
580 outputs = self.decoder(
581 input_ids=input_ids[:, -1:],
582 past_key_values=past_key_values,
583 attention_mask=attention_mask,
584 )

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/optimum/onnxruntime/base.py:63, in ORTModelPart.__call__(self, *args, **kwargs)
62 def __call__(self, *args, **kwargs):
---> 63 return self.forward(*args, **kwargs)

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/optimum/onnxruntime/base.py:307, in ORTDecoder.forward(self, input_ids, attention_mask, past_key_values, labels, use_cache_branch)
304 model_inputs.append(labels)
305 known_output_shapes.update({"loss": []})
--> 307 io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
308 self.session,
309 *model_inputs,
310 known_output_shapes=known_output_shapes,
311 ordered_input_names=self._ordered_input_names,
312 )
314 io_binding.synchronize_inputs()
315 self.session.run_with_iobinding(io_binding)

File ~/miniconda3/envs/llama/lib/python3.10/site-packages/optimum/onnxruntime/modeling_ort.py:746, in ORTModel._prepare_io_binding(self, model, ordered_input_names, known_output_shapes, outputs_to_not_bind, *model_inputs)
744 if tensor is None:
745 continue
--> 746 name = ordered_input_names[idx]
747 input_name_to_tensor[name] = tensor
748 tensor = tensor.contiguous()

IndexError: list index out of range

@fxmarty (Contributor) commented Mar 28, 2023

Hi @nenkoru , would that work https://huggingface.co/HuggingFaceM4/tiny-random-LlamaForCausalLM/tree/main?

@nenkoru (Contributor, Author) commented Mar 28, 2023

Hi @nenkoru , would that work https://huggingface.co/HuggingFaceM4/tiny-random-LlamaForCausalLM/tree/main?

I guess it won't work, because for some reason ORTModelForCausalLM fails to work with the use_cache=False option. I mean, the class instance initializes, but inference doesn't work (traceback provided in a previous comment, under a spoiler).
After some monkey patching, it turned out that the model still tries to use past_key_values after a few iterations of the forward method (I might be wrong about what happens behind the scenes of decoder-based model inference).

UPD: tried this tiny-random llama (converted to ONNX using optimum-cli) and the same thing happens (traceback in the comment above).
optimum-cli cmd: optimum-cli export onnx --model HuggingFaceM4/tiny-random-LlamaForCausalLM --task causal-lm --for-ort --fp16 --device cuda tiny-llama-onnx/
Code used for inference:

from transformers import LlamaTokenizer, GenerationConfig
from optimum.onnxruntime import ORTModelForCausalLM
import torch

# Generation settings (the original snippet left these undefined; placeholder
# values taken from the benchmark earlier in this thread).
prompt = "Tell me about Alpacas"
temperature = 0.1
top_p = 0.75
top_k = 40
num_beams = 1
max_new_tokens = 128

# Load the model previously converted to ORT (ONNX)
model_onnx = ORTModelForCausalLM.from_pretrained(
    "/root/tiny-llama-onnx",
    #from_transformers=True,
    provider="CUDAExecutionProvider",
    torch_dtype=torch.float16,
    use_cache=False,
    #low_cpu_mem_usage=True
)

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
input_ids = inputs["input_ids"].to("cuda")
print(inputs)

generation_config = GenerationConfig(
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    num_beams=num_beams,
)
with torch.no_grad():
    generation_output = model_onnx.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        return_dict_in_generate=True,
        #use_cache=False
    )

@nenkoru (Contributor, Author) commented Mar 29, 2023

Small observation:
use_io_binding=False fixes the problem, but performance dips.

UPD: monkey-patching ORTDecoder's forward method fixes the problem.
I am not sure how to implement ignoring past_key_values, either within the method or somewhere else. I believe it's out of scope for this PR.

UPD2: basically, use_cache=False does not make the forward method ignore past_key_values at all, and because there are no input names for the inputs that get built when executing the model, it fails with an exception. That is why this hack exists: it patches the forward method so that it does not use past_key_values.

from optimum.onnxruntime.base import ORTDecoder

def wrapper(func):
    # Drop any past_key_values before calling the original forward, so the
    # ORT session is never handed cache inputs it has no input names for.
    def decorator(*args, **kwargs):
        kwargs['past_key_values'] = None
        return func(*args, **kwargs)
    return decorator


wrapped = wrapper(ORTDecoder.forward)
setattr(ORTDecoder, "forward", wrapped)

@fxmarty (Contributor) commented Apr 4, 2023

@nenkoru You opened this PR from the main branch of your fork, so it appears I cannot push there. Using trl-internal-testing/tiny-random-LlamaForCausalLM, all ONNX and ORT tests pass (including the GPU ones: pytest tests/onnxruntime/test_*.py -k "llama" -s --exitfirst -m "gpu_test").

So to me, once the conflict is resolved and transformers is released, this looks good!

@nenkoru (Contributor, Author) commented Apr 4, 2023

@nenkoru You opened this PR from the main branch of your fork, so it appears I cannot push there. Using trl-internal-testing/tiny-random-LlamaForCausalLM, all ONNX and ORT tests pass (including the GPU ones: pytest tests/onnxruntime/test_*.py -k "llama" -s --exitfirst -m "gpu_test").

I will change the link. But what do you think about the issue I reported above?

@fxmarty (Contributor) commented Apr 4, 2023

About this issue:

The process takes more than 100 GB of disk space, and peak RAM usage was 77-80 GB for the whole thing, from loading the model into ORT to exporting it

It could be that the large memory requirement during the export has been fixed in #932 (if the issue was when merging the decoders). Although it could also be that, even before this postprocessing step, we have memory duplication at some point during the export, as we export the model in two parts (decoder_model.onnx and decoder_with_past_model.onnx). I will check that, but it is a separate issue.

The exported model takes 51 GB of hard-drive space.

To me, the exported model sizes are decent for trl-internal-testing/tiny-random-LlamaForCausalLM compared to the PyTorch .bin:
(screenshot of the repository file sizes)
https://huggingface.co/trl-internal-testing/tiny-random-LlamaForCausalLM/tree/main

@fxmarty (Contributor) commented Apr 4, 2023

About the issue with the code snippet you gave: CUDA EP + use_io_binding=True (the default) + use_cache=False is not tested (see

@parameterized.expand(
    grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
)
@require_torch_gpu
@pytest.mark.gpu_test
def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):

). I'll do a PR to raise an error if a user tries this setting.
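
(For illustration only, a guard of the kind described might look roughly like this; this is a hedged sketch, not the actual optimum code, and the helper name is hypothetical.)

def check_io_binding_settings(provider: str, use_io_binding: bool, use_cache: bool) -> None:
    # Hypothetical validation helper: reject the untested combination
    # CUDA EP + IO binding + use_cache=False discussed above.
    if provider == "CUDAExecutionProvider" and use_io_binding and not use_cache:
        raise ValueError(
            "use_io_binding=True together with use_cache=False is not supported on "
            "CUDAExecutionProvider; pass use_io_binding=False or use_cache=True."
        )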

@nenkoru (Contributor, Author) commented Apr 4, 2023

About the issue with the code snippet you gave: CUDA EP + use_io_binding=True (the default) + use_cache=False is not tested (see

@parameterized.expand(
    grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
)
@require_torch_gpu
@pytest.mark.gpu_test
def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):

). I'll do a PR to raise an error if a user tries this setting.

But what about my way of fixing it? It was working quite well with that fancy monkey patching.

@fxmarty (Contributor) commented Apr 4, 2023

@nenkoru Yes, I guess it works! Ideally we would want to fix the forward/IO Binding code itself and add a test for IO Binding + use_cache=False, but I think there's no strong reason not to reuse past key values.

@nenkoru (Contributor, Author) commented Apr 4, 2023

@nenkoru Yes, I guess it works! Ideally we would want to fix the forward/IO Binding code itself and add a test for IO Binding + use_cache=False, but I think there's no strong reason not to reuse past key values.

For me it was necessary. Either way, as long as the exporter generates two files, one for the base model and one for the model with past keys, it was loading 2x more VRAM, which is not expected behaviour. And yep, I was unable to merge them even with 256 GB of RAM; I guess that is fixed now by #932, as you mentioned above. Still, I think running with no past keys should work as expected. Just my two cents.

@gilljon commented Apr 12, 2023

I've exported a Llama-based model (using past key values + merged); however, when I run inference on it, the memory usage is greater (~67GB for 13B parameter model) AND the runtime is significantly slower compared to the original model:

NUM TOKENS:  50
ORT TIMES:  [4.2505810260772705, 3.958254814147949, 3.9548168182373047, 3.8123669624328613, 3.7423346042633057]
NORMAL TIMES:  [1.8341279029846191, 1.831620693206787, 1.8367419242858887, 1.827073097229004, 1.8272616863250732]
NUM TOKENS:  100
ORT TIMES:  [12.765657663345337, 11.290537118911743, 11.22416877746582, 11.383142709732056, 11.26672911643982]
NORMAL TIMES:  [3.7361843585968018, 3.7319560050964355, 3.737166166305542, 3.7369601726531982, 3.735769748687744]
NUM TOKENS:  200
ORT TIMES:  [40.04459619522095, 35.00810360908508, 33.88215517997742, 33.990755558013916, 36.0823016166687]
NORMAL TIMES:  [7.595353603363037, 7.5860066413879395, 7.58692479133606, 7.597402095794678, 7.584235429763794]

And when I try to convert it with the fp16 flag, I get: failed: Type Error: Type parameter (T) of Optype (Max) bound to different types (tensor(float16) and tensor(float) in node (Max_1670)).

Any suggestions? Mind providing me with your package versions?

@fxmarty (Contributor) commented Apr 12, 2023

Hi @gilljon, is this with the CPU execution provider or the CUDA EP? Hopefully the transformers release comes soon so that we can merge the PR and test it!

@gilljon commented Apr 12, 2023

Hi @gilljon, is this with the CPU execution provider or the CUDA EP? Hopefully the transformers release comes soon so that we can merge the PR and test it!

It seems the issue was an outdated torch; upgrading to torch==2.1.0 fixed the conversion issue. That being said, when performance-testing the model, I still notice a decrease in generation speed:

NUM TOKENS:  20
ORT TIMES:  [0.5817182064056396, 0.5831649303436279, 0.5821049213409424, 0.5825436115264893, 0.5824661254882812]
NORMAL TIMES:  [0.4072282314300537, 0.4071035385131836, 0.4074513912200928, 0.40746521949768066, 0.407412052154541]
NUM TOKENS:  50
ORT TIMES:  [1.5875389575958252, 1.565403938293457, 1.5717802047729492, 1.5658864974975586, 1.5674340724945068]
NORMAL TIMES:  [1.0803217887878418, 1.0725362300872803, 1.0720314979553223, 1.0732998847961426, 1.0723967552185059]
NUM TOKENS:  100
ORT TIMES:  [3.256457805633545, 3.2374284267425537, 3.2630958557128906, 3.241764545440674, 3.2146661281585693]
NORMAL TIMES:  [2.1945881843566895, 2.1934919357299805, 2.2603249549865723, 2.1941864490509033, 2.19608736038208]
NUM TOKENS:  200
ORT TIMES:  [6.571287631988525, 7.072286605834961, 6.634500503540039, 6.614758729934692, 6.610734224319458]
NORMAL TIMES:  [4.461749792098999, 4.541143417358398, 4.538911581039429, 4.538933992385864, 4.537773847579956]

The model has the following ort_config.json:

{
  "one_external_file": true,
  "opset": null,
  "optimization": {
    "disable_attention": null,
    "disable_attention_fusion": false,
    "disable_bias_gelu": null,
    "disable_bias_gelu_fusion": false,
    "disable_bias_skip_layer_norm": null,
    "disable_bias_skip_layer_norm_fusion": false,
    "disable_embed_layer_norm": true,
    "disable_embed_layer_norm_fusion": true,
    "disable_gelu": null,
    "disable_gelu_fusion": false,
    "disable_layer_norm": null,
    "disable_layer_norm_fusion": false,
    "disable_shape_inference": true,
    "disable_skip_layer_norm": null,
    "disable_skip_layer_norm_fusion": false,
    "enable_gelu_approximation": false,
    "enable_transformers_specific_optimizations": true,
    "fp16": false,
    "no_attention_mask": false,
    "optimization_level": 2,
    "optimize_for_gpu": false,
    "optimize_with_onnxruntime_only": null,
    "use_mask_index": false
  },
  "optimum_version": "1.7.4.dev0",
  "quantization": {},
  "transformers_version": "4.28.0.dev0",
  "use_external_data_format": true
}

and was created using the standard optimum export command with the flags: --fp16 --device cuda --optimize O2.

Any ideas? This is with CUDAExecutionProvider. Also worth noting is the memory usage: 67GB for the 7B model.

@fxmarty (Contributor) commented Apr 13, 2023

For runtime, I am not sure; I will profile it in more detail once this is merged.

For memory, this is probably related: microsoft/onnxruntime#14526

I've seen a similar issue with gpt-j.
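
(As a side note, not something verified in this thread: the ORT CUDA EP memory arena can be tuned through provider options. A hedged sketch, assuming the optimum version used here exposes provider_options; the path and values are placeholders.)

from optimum.onnxruntime import ORTModelForCausalLM

# Hypothetical arena tuning for the CUDA execution provider. The option names
# come from ONNX Runtime's CUDA EP options; whether they help with the memory
# growth discussed above is an assumption.
provider_options = {
    "arena_extend_strategy": "kSameAsRequested",   # grow the arena only as much as requested
    "gpu_mem_limit": 40 * 1024 * 1024 * 1024,      # cap the arena at ~40 GB
}

model = ORTModelForCausalLM.from_pretrained(
    "/path/to/llama-onnx",                 # placeholder path to an exported model
    provider="CUDAExecutionProvider",
    provider_options=provider_options,
    use_cache=True,
)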

@sam-h-bean (Contributor)

What is the status on this one, @fxmarty / @gilljon? I'd love to try hosting this model on Triton in ONNX.

@gilljon commented Apr 16, 2023

@sam-h-bean You can definitely host it on Triton in ONNX, although the memory footprint is still greater than with raw Python. Alternatively, you can deploy it on Triton using the Python backend.

@sam-h-bean (Contributor)

@gilljon Do you think it will get to the point where the memory usage in ONNX is smaller? At 67 GB it wouldn't even fit on an A100.

@fxmarty (Contributor) commented Apr 17, 2023

Using #975 instead of this branch, as I could not push here. LLaMA ONNX export will be included in today's release!

Not sure if this will work well with TRT, and I always have issues with ORT CUDA EP memory usage.

@fxmarty fxmarty closed this Apr 17, 2023