From 4a59f9d1e042b74fd7e84d20b1e42ed1025f019e Mon Sep 17 00:00:00 2001 From: Jingya Date: Fri, 20 Dec 2024 15:40:47 +0000 Subject: [PATCH] fix --- optimum/onnxruntime/runs/__init__.py | 6 +++--- optimum/utils/input_generators.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py index 1d982949344..d21db2a4aca 100644 --- a/optimum/onnxruntime/runs/__init__.py +++ b/optimum/onnxruntime/runs/__init__.py @@ -110,9 +110,9 @@ def __init__(self, run_config): model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task)) self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"]) - self.return_body[ - "model_type" - ] = self.torch_model.config.model_type # return_body is initialized in parent class + self.return_body["model_type"] = ( + self.torch_model.config.model_type + ) # return_body is initialized in parent class def _launch_time(self, trial): batch_size = trial.suggest_categorical("batch_size", self.batch_sizes) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 47a4a85395b..4c2a9711103 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -672,7 +672,7 @@ def generate( None, ) - return super().generate(input_name, framework=framework, int_dtype=int_dtype) + return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework) class DummyPastKeyValuesGenerator(DummyInputGenerator): @@ -1670,7 +1670,7 @@ def generate( shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype) - return super().generate(input_name, framework, int_dtype, float_dtype) + return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework) class DummyTransformerVisionInputGenerator(DummyVisionInputGenerator): @@ -1691,14 +1691,14 @@ def generate( framework: Optional[str] = None, ): if input_name == "encoder_hidden_states": - return super().generate(input_name, framework, int_dtype, float_dtype)[0] + return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)[0] elif input_name == "pooled_projections": return self.random_float_tensor( [self.batch_size, self.normalized_config.projection_size], framework=framework, dtype=float_dtype ) - return super().generate(input_name, framework, int_dtype, float_dtype) + return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework) class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenerator): @@ -1725,7 +1725,7 @@ def generate( ) return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) - return super().generate(input_name, framework, int_dtype, float_dtype) + return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework) class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator): @@ -1754,4 +1754,4 @@ def generate( shape = [self.batch_size] return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype) - return super().generate(input_name, framework, int_dtype, float_dtype) + return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)