Add llama onnx export & onnxruntime support (#975)
* Add config for Llama

* Register Llama in tasks

* Add llama and its corresponding tiny-random model from hf into tests

* Add tests for modeling and exporters

* Add entry for Llama

* Add llama into supported normalized configs

* Add optimization support for llama

* Change tiny-llama source to trl-internal-testing

* Change tiny-llama source to trl-internal-testing

* can I push?

* fix tests

* fix task map

---------

Co-authored-by: Chernenko Ruslan <[email protected]>
fxmarty and nenkoru authored Apr 17, 2023
1 parent 5486bc7 commit 22b10a4
Showing 9 changed files with 45 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -53,6 +53,7 @@ Supported architectures:
- LayoutLM-v3
- Levit
- LongT5
- Llama
- M2M100
- Marian
- MBart
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -208,6 +208,11 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class LlamaOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_shape = (
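With LlamaOnnxConfig registered, a LLaMA checkpoint can be exported through the regular ONNX export entry point. A minimal sketch, assuming the main_export helper from optimum.exporters.onnx; the output directory name is illustrative:

# Export a LLaMA checkpoint to ONNX (sketch; assumes optimum's main_export helper).
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="fxmarty/tiny-llama-fast-tokenizer",  # tiny test checkpoint used in this PR
    output="llama_onnx",  # hypothetical output directory
    task="text-generation-with-past",  # decoder export with KV-cache support
)
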
8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
@@ -659,6 +659,14 @@ class TasksManager:
"text-classification",
onnx="OPTOnnxConfig",
),
"llama": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="LlamaOnnxConfig",
),
"pegasus": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
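Once the llama entry is in TasksManager, the tasks it supports for the ONNX exporter can be queried programmatically. A small sketch, assuming the get_supported_tasks_for_model_type helper behaves here as it does for other architectures:

from optimum.exporters.tasks import TasksManager

# Expected to contain feature-extraction, text-generation and text-classification,
# with their -with-past variants where applicable, per the mapping above.
llama_tasks = TasksManager.get_supported_tasks_for_model_type("llama", exporter="onnx")
print(sorted(llama_tasks.keys()))
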
1 change: 0 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -682,7 +682,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
"token_type_ids": None,
}

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
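The token_type_ids entry is dropped from prepare_inputs_for_generation because decoder-only ONNX models such as the exported LLaMA do not take that input, while the LLaMA tokenizer returns it by default (hence the return_token_type_ids=False workaround in the tests below). A rough end-to-end inference sketch; the export=True flag and prompt are assumptions for illustration:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "fxmarty/tiny-llama-fast-tokenizer"  # tiny test checkpoint used in this PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch checkpoint to ONNX on the fly (assumed supported in this optimum version).
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt", return_token_type_ids=False)
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
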
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
@@ -114,6 +114,7 @@ class ORTConfigManager:
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
"marian": "bart",
"mbart": "bart",
"mt5": "bart",
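Mapping llama to the gpt2 optimization model type is what lets ONNX Runtime's transformer graph optimizer process LLaMA models. A sketch of running graph optimization on an exported model, assuming ORTOptimizer and OptimizationConfig behave as they do for the other decoders; directory names are illustrative:

from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

# Load an already-exported LLaMA ONNX model (e.g. the directory produced by the export sketch above).
model = ORTModelForCausalLM.from_pretrained("llama_onnx")
optimizer = ORTOptimizer.from_pretrained(model)

# Conservative optimization level; see the longt5 comment above about O4 causing issues for some models.
optimization_config = OptimizationConfig(optimization_level=2)
optimizer.optimize(optimization_config=optimization_config, save_dir="llama_onnx_optimized")
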
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -202,6 +202,7 @@ class NormalizedConfigManager:
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
"longt5": T5LikeNormalizedTextConfig,
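The normalized config is what lets architecture-agnostic code (such as the ONNX Runtime decoder modeling) read hidden size, head count and layer count off a LLaMA config without knowing its attribute names. A small sketch, assuming NormalizedConfigManager.get_normalized_config_class works for the new entry as for the existing ones:

from transformers import AutoConfig
from optimum.utils import NormalizedConfigManager

config = AutoConfig.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")
normalized_config_class = NormalizedConfigManager.get_normalized_config_class("llama")
normalized_config = normalized_config_class(config)

# Uniform attribute names regardless of the underlying architecture.
print(normalized_config.hidden_size, normalized_config.num_attention_heads, normalized_config.num_layers)
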
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -63,6 +63,7 @@
"levit": "hf-internal-testing/tiny-random-LevitModel",
"layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
"layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"longt5": "fxmarty/tiny-random-working-LongT5Model",
# "longformer": "allenai/longformer-base-4096",
"m2m-100": "hf-internal-testing/tiny-random-m2m_100",
@@ -171,6 +172,7 @@
"levit": "facebook/levit-128S",
"layoutlm": "microsoft/layoutlm-base-uncased",
"layoutlmv3": "microsoft/layoutlmv3-base",
"llama": "decapoda-research/llama-65b-hf",
"longt5": "fxmarty/tiny-random-working-LongT5Model", # Not using google/long-t5-local-base because it takes too much time for testing.
# "longformer": "allenai/longformer-base-4096",
"m2m-100": "hf-internal-testing/tiny-random-m2m_100", # Not using facebook/m2m100_418M because it takes too much time for testing.
35 changes: 26 additions & 9 deletions tests/onnxruntime/test_modeling.py
@@ -1970,6 +1970,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"gpt_neo",
"gpt_neox",
"gptj",
"llama",
]

FULL_GRID = {
@@ -2021,7 +2022,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name])
tokenizer = get_preprocessor(model_id)
text = "This is a sample output"
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

# General case
outputs = model.generate(**tokens)
@@ -2030,7 +2031,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
self.assertTrue(len(res[0]) > len(text))

# With input ids
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)
outputs = model.generate(input_ids=tokens["input_ids"])
res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertIsInstance(res[0], str)
@@ -2118,7 +2119,11 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
set_seed(SEED)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = get_preprocessor(model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt")
tokens = tokenizer(
"This is a sample output",
return_tensors="pt",
return_token_type_ids=False if model_arch == "llama" else None,
)
onnx_outputs = onnx_model(**tokens)

self.assertTrue("logits" in onnx_outputs)
@@ -2217,12 +2222,16 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st

# build engine for a short sequence
text = ["short"]
encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
encoded_input = tokenizer(
text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
).to("cuda")
_ = onnx_model(**encoded_input)

# build engine for a long sequence
text = [" a very long input just for demo purpose, this is very long" * 10]
encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
encoded_input = tokenizer(
text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
).to("cuda")
_ = onnx_model(**encoded_input)

pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer, device=0)
@@ -2235,7 +2244,11 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
self.assertTrue(isinstance(outputs[0]["generated_text"], str))
self.assertTrue(len(outputs[0]["generated_text"]) > len(text))

encoded_input = tokenizer(["Replace me by any text you'd like."], return_tensors="pt").to("cuda")
encoded_input = tokenizer(
["Replace me by any text you'd like."],
return_tensors="pt",
return_token_type_ids=False if model_arch == "llama" else None,
).to("cuda")
_ = onnx_model.generate(**encoded_input)

gc.collect()
@@ -2251,7 +2264,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch):
model_id = MODEL_NAMES[model_arch]
tokenizer = get_preprocessor(model_id)
text = "My Name is Philipp and i live"
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

model_with_pkv = ORTModelForCausalLM.from_pretrained(
self.onnx_model_dirs[model_arch + "_True"], use_cache=True
@@ -2302,7 +2315,7 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode
model_id = MODEL_NAMES[model_arch]
tokenizer = get_preprocessor(model_id)
text = "My Name is Philipp and i live"
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"]
model_merged_dir = self.onnx_model_dirs[test_name + "_True"]
@@ -2372,7 +2385,11 @@ def test_compare_generation_to_io_binding(self, test_name: str, model_arch: str,
io_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to("cuda")

tokenizer = get_preprocessor(model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")
tokens = tokenizer(
"This is a sample output",
return_tensors="pt",
return_token_type_ids=False if model_arch == "llama" else None,
).to("cuda")
onnx_outputs = onnx_model.generate(**tokens)
io_outputs = io_model.generate(**tokens)

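The return_token_type_ids guard recurs throughout these tests because the LLaMA tokenizer emits token_type_ids by default while the exported decoder has no such input. A hypothetical helper (not part of this commit) that would centralize the workaround:

def encode_for_arch(tokenizer, text, model_arch: str, **kwargs):
    # LLaMA tokenizers return token_type_ids by default, but the exported ONNX decoder does not accept them.
    if model_arch == "llama":
        kwargs["return_token_type_ids"] = False
    return tokenizer(text, return_tensors="pt", **kwargs)
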
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -49,6 +49,7 @@
"layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
"layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model",
"longt5": "hf-internal-testing/tiny-random-LongT5Model",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"mbart": "hf-internal-testing/tiny-random-mbart",
