From 02c6ed5f413384d543bcf83a3a9094be2c0429a5 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Wed, 15 May 2024 19:06:40 +0400 Subject: [PATCH] Make stable diffusion unet and vae number of channels static (#1840) --- optimum/exporters/onnx/model_configs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 496957b2b5d..d4c4ac934b9 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -981,7 +981,7 @@ class UNetOnnxConfig(VisionOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = { - "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "sample": {0: "batch_size", 2: "height", 3: "width"}, "timestep": {0: "steps"}, "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, } @@ -998,7 +998,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: @property def outputs(self) -> Dict[str, Dict[int, str]]: return { - "out_sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "out_sample": {0: "batch_size", 2: "height", 3: "width"}, } @property @@ -1045,13 +1045,13 @@ class VaeEncoderOnnxConfig(VisionOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: return { - "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "sample": {0: "batch_size", 2: "height", 3: "width"}, } @property def outputs(self) -> Dict[str, Dict[int, str]]: return { - "latent_sample": {0: "batch_size", 1: "num_channels_latent", 2: "height_latent", 3: "width_latent"}, + "latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"}, } @@ -1069,13 +1069,13 @@ class VaeDecoderOnnxConfig(VisionOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: return { - "latent_sample": {0: "batch_size", 1: "num_channels_latent", 2: "height_latent", 3: "width_latent"}, + "latent_sample": {0: "batch_size", 2: "height_latent", 3: "width_latent"}, } @property def outputs(self) -> Dict[str, Dict[int, str]]: return { - "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "sample": {0: "batch_size", 2: "height", 3: "width"}, }