From ba113e5680774f369243d79c3e3a2e2fc0017902 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 6 Oct 2023 13:50:46 +0200
Subject: [PATCH] Fix llama ONNX export (#1432)

* fix

* fix export
---
Review notes (not part of the commit message; this region sits after the
three-dash separator and before the first file diff):

* LlamaDummyPastKeyValuesGenerator.__init__ takes task, normalized_config,
  batch_size, sequence_length, random_batch_size_range,
  random_sequence_length_range and **kwargs, forwards all of them unchanged
  (by keyword) to the parent constructor, and then stores
  normalized_config.num_key_value_heads as self.num_key_value_heads.
* LlamaDummyPastKeyValuesGenerator.generate returns a list holding one
  2-tuple per self.num_layers. Both tuple members come from
  self.random_float_tensor, called with the shape
  (batch_size, num_key_value_heads, sequence_length,
  hidden_size // num_attention_heads) plus the caller's framework and
  float_dtype.
* generate accepts input_name and int_dtype but does not read them in the
  body added here.
* The second dimension uses num_key_value_heads, not num_attention_heads.
  Presumably this is for checkpoints where the two differ (grouped-query
  attention) -- TODO confirm against the base generator.
* self.batch_size, self.sequence_length, self.hidden_size,
  self.num_attention_heads, self.num_layers and self.random_float_tensor are
  not defined in this patch; presumably they are provided by
  DummyPastKeyValuesGenerator or its bases -- verify.
* NOTE(review): normalized_config.num_key_value_heads is read
  unconditionally. Confirm NormalizedTextConfig can resolve that attribute
  for every config routed through LlamaOnnxConfig.
* LlamaOnnxConfig gains two class attributes:
  DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,
  LlamaDummyPastKeyValuesGenerator) and
  DUMMY_PKV_GENERATOR_CLASS = LlamaDummyPastKeyValuesGenerator.

 optimum/exporters/onnx/model_configs.py | 41 +++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 401d995fdc..d0372f3145 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -216,7 +216,48 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            **kwargs,
+        )
+        self.num_key_value_heads = normalized_config.num_key_value_heads
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.hidden_size // self.num_attention_heads,
+        )
+        return [
+            (
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
 class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, LlamaDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = LlamaDummyPastKeyValuesGenerator
+
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 