From 12c530ef0a35cf9ab17ff7114d93716ee8df4980 Mon Sep 17 00:00:00 2001
From: JingyaHuang
Date: Thu, 28 Nov 2024 09:17:29 +0000
Subject: [PATCH] unbundle

---
 optimum/utils/input_generators.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index fbb77e6800a..18a2a5a3fd1 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -897,14 +897,14 @@ def __init__(
     ):
         self.task = task
         self.vocab_size = normalized_config.vocab_size
-        self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
-        self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
+        self.text_encoder_projection_dim = getattr(normalized_config, "text_encoder_projection_dim", None)
+        self.time_ids = 5 if getattr(normalized_config, "requires_aesthetics_score", False) else 6
         if random_batch_size_range:
             low, high = random_batch_size_range
             self.batch_size = random.randint(low, high)
         else:
             self.batch_size = batch_size
-        self.time_cond_proj_dim = normalized_config.config.time_cond_proj_dim
+        self.time_cond_proj_dim = getattr(normalized_config.config, "time_cond_proj_dim", None)

     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         if input_name == "timestep":
@@ -912,8 +912,16 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
             return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)

         if input_name == "text_embeds":
+            if self.text_encoder_projection_dim is None:
+                raise ValueError(
+                    "Unable to infer the value of `text_encoder_projection_dim` for generating `text_embeds`, please double check the config of your model."
+                )
             dim = self.text_encoder_projection_dim
         elif input_name == "timestep_cond":
+            if self.time_cond_proj_dim is None:
+                raise ValueError(
+                    "Unable to infer the value of `time_cond_proj_dim` for generating `timestep_cond`, please double check the config of your model."
+                )
             dim = self.time_cond_proj_dim
         else:
             dim = self.time_ids
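
Note: the patch replaces unconditional attribute access with getattr(..., default) so that construction never fails on configs that lack the optional fields, and it defers the failure to the point where the missing value is actually needed. Below is a minimal, self-contained sketch of that "read optionally, fail lazily" pattern; the class and attribute names are illustrative stand-ins, not the actual optimum classes.

    # Minimal sketch of the pattern applied in the patch (illustrative names only).
    from types import SimpleNamespace


    class DummyGenerator:  # hypothetical stand-in for the optimum input generator
        def __init__(self, config):
            # The attribute may be absent for some models, so fall back to None
            # instead of raising AttributeError at construction time.
            self.text_encoder_projection_dim = getattr(config, "text_encoder_projection_dim", None)

        def generate(self, input_name):
            if input_name == "text_embeds":
                # Fail only when the missing value is actually required.
                if self.text_encoder_projection_dim is None:
                    raise ValueError(
                        "Unable to infer `text_encoder_projection_dim`; please double check the config of your model."
                    )
                return [0.0] * self.text_encoder_projection_dim
            raise KeyError(input_name)


    # Construction succeeds even without the attribute...
    gen = DummyGenerator(SimpleNamespace())
    # ...and a clear error surfaces only when `text_embeds` is requested.
    try:
        gen.generate("text_embeds")
    except ValueError as err:
        print(err)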