Skip to content

Commit

Permalink
fix
Browse files · Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Dec 20, 2024
1 parent a27d657 commit 4a59f9d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions optimum/onnxruntime/runs/__init__.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -110,9 +110,9 @@ def __init__(self, run_config):
model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task))
self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"])

self.return_body[
"model_type"
] = self.torch_model.config.model_type # return_body is initialized in parent class
self.return_body["model_type"] = (
self.torch_model.config.model_type
) # return_body is initialized in parent class

def _launch_time(self, trial):
batch_size = trial.suggest_categorical("batch_size", self.batch_sizes)
Expand Down
12 changes: 6 additions & 6 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def generate(
None,
)

return super().generate(input_name, framework=framework, int_dtype=int_dtype)
return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)


class DummyPastKeyValuesGenerator(DummyInputGenerator):
Expand Down Expand Up @@ -1670,7 +1670,7 @@ def generate(
shape = [self.batch_size] # With transformer diffusers, timestep is a 1D tensor
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)
return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)


class DummyTransformerVisionInputGenerator(DummyVisionInputGenerator):
Expand All @@ -1691,14 +1691,14 @@ def generate(
framework: Optional[str] = None,
):
if input_name == "encoder_hidden_states":
return super().generate(input_name, framework, int_dtype, float_dtype)[0]
return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)[0]

elif input_name == "pooled_projections":
return self.random_float_tensor(
[self.batch_size, self.normalized_config.projection_size], framework=framework, dtype=float_dtype
)

return super().generate(input_name, framework, int_dtype, float_dtype)
return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)


class DummyFluxTransformerVisionInputGenerator(DummyTransformerVisionInputGenerator):
Expand All @@ -1725,7 +1725,7 @@ def generate(
)
return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)
return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)


class DummyFluxTransformerTextInputGenerator(DummyTransformerTextInputGenerator):
Expand Down Expand Up @@ -1754,4 +1754,4 @@ def generate(
shape = [self.batch_size]
return self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework, dtype=float_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)
return super().generate(input_name, int_dtype=int_dtype, float_dtype=float_dtype, framework=framework)

0 comments on commit 4a59f9d

Please sign in to comment.