diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md
index 5b94202fba..98e818a5c6 100644
--- a/examples/stable-diffusion/README.md
+++ b/examples/stable-diffusion/README.md
@@ -305,7 +305,6 @@ huggingface-cli login
 Here is how to generate SD3 images with a single prompt:
 
 ```bash
-PT_HPU_MAX_COMPOUND_OP_SIZE=1 \
 python text_to_image_generation.py \
     --model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \
     --prompts "Sailing ship painting by Van Gogh" \
@@ -321,9 +320,47 @@ python text_to_image_generation.py \
     --bf16
 ```
 
-> [!NOTE]
-> For improved performance of the SD3 pipeline on Gaudi, it is recommended to configure the environment
-> by setting PT_HPU_MAX_COMPOUND_OP_SIZE to 1.
+This model can also be quantized so that some ops run in FP8 precision.
+
+Before quantization, collect calibration statistics by running the pipeline in measure mode:
+
+```bash
+QUANT_CONFIG=quantization/stable-diffusion-3/measure_config.json \
+python text_to_image_generation.py \
+    --model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \
+    --prompts "Sailing ship painting by Van Gogh" \
+    --num_images_per_prompt 10 \
+    --batch_size 1 \
+    --num_inference_steps 28 \
+    --image_save_dir /tmp/stable_diffusion_3_images \
+    --scheduler default \
+    --use_habana \
+    --use_hpu_graphs \
+    --gaudi_config Habana/stable-diffusion \
+    --sdp_on_bf16 \
+    --bf16 \
+    --quant_mode measure
+```
+
+After the statistics have been collected, here is how to run SD3 in quantization mode:
+
+```bash
+QUANT_CONFIG=quantization/stable-diffusion-3/quantize_config.json \
+python text_to_image_generation.py \
+    --model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \
+    --prompts "Sailing ship painting by Van Gogh" \
+    --num_images_per_prompt 10 \
+    --batch_size 1 \
+    --num_inference_steps 28 \
+    --image_save_dir /tmp/stable_diffusion_3_images \
+    --scheduler default \
+    --use_habana \
+    --use_hpu_graphs \
+    --gaudi_config Habana/stable-diffusion \
+    --sdp_on_bf16 \
+    --bf16 \
+    --quant_mode quantize
+```
 
 ### FLUX.1
 
diff --git a/examples/stable-diffusion/quantization/stable-diffusion-3/measure_config.json b/examples/stable-diffusion/quantization/stable-diffusion-3/measure_config.json
new file mode 100644
index 0000000000..ebf3baa292
--- /dev/null
+++ b/examples/stable-diffusion/quantization/stable-diffusion-3/measure_config.json
@@ -0,0 +1,5 @@
+{
+    "method": "HOOKS",
+    "mode": "MEASURE",
+    "dump_stats_path": "quantization/stable-diffusion-3/measure_all/fp8"
+}
\ No newline at end of file
diff --git a/examples/stable-diffusion/quantization/stable-diffusion-3/quantize_config.json b/examples/stable-diffusion/quantization/stable-diffusion-3/quantize_config.json
new file mode 100644
index 0000000000..1fa98ebce0
--- /dev/null
+++ b/examples/stable-diffusion/quantization/stable-diffusion-3/quantize_config.json
@@ -0,0 +1,6 @@
+{
+    "method": "HOOKS",
+    "mode": "QUANTIZE",
+    "scale_method": "maxabs_hw_opt_weight",
+    "dump_stats_path": "quantization/stable-diffusion-3/measure_all/fp8"
+}
\ No newline at end of file
diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py
index 8fc4e0de4c..b32fc5c3f6 100755
--- a/examples/stable-diffusion/text_to_image_generation.py
+++ b/examples/stable-diffusion/text_to_image_generation.py
@@ -305,6 +305,12 @@ def main():
         default=None,
         help="The file with prompts (for large number of images generation).",
     )
+    parser.add_argument(
+        "--lora_scale",
+        type=float,
+        default=None,
+        help="A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.",
+    )
     args = parser.parse_args()
 
     if args.optimize and not args.use_habana:
@@ -380,6 +386,9 @@ def main():
     if args.throughput_warmup_steps is not None:
         kwargs_call["throughput_warmup_steps"] = args.throughput_warmup_steps
 
+    if args.lora_scale is not None:
+        kwargs_call["lora_scale"] = args.lora_scale
+
     negative_prompts = args.negative_prompts
     if args.distributed:
         distributed_state = PartialState()
diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index bcd4b6f172..731e0434bf 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -20,6 +20,8 @@
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention
 from diffusers.models.autoencoders import AutoencoderKL
 from diffusers.models.transformers import SD3Transformer2DModel
 from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline
@@ -76,6 +78,101 @@ class GaudiStableDiffusion3PipelineOutput(BaseOutput):
     """
 
 
+# TODO: Look into FusedJointAttnProcessor2_0 usage for the SD3 pipeline and check its performance with fused SDPA
+class GaudiJointAttnProcessor2_0:
+    """Attention processor typically used for processing SD3-like self-attention projections.
+    Copied from JointAttnProcessor2_0.__call__: https://github.com/huggingface/diffusers/blob/89e4d6219805975bd7d253a267e1951badc9f1c0/src/diffusers/models/attention_processor.py
+    The only difference is:
+    - the attention computation uses Habana's FusedSDPA kernel instead of F.scaled_dot_product_attention.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
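+        # The text-encoder ("context") states below get their own q/k/v projections and are
+        # concatenated to the sample projections along the sequence dimension, so a single
+        # fused attention call covers both streams.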
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
+        from habana_frameworks.torch.hpex.kernels import FusedSDPA
+
+        hidden_states = FusedSDPA.apply(query, key, value, None, 0.0, False, None, "fast", None)
+
+        # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class GaudiStableDiffusion3Pipeline(GaudiDiffusionPipeline, StableDiffusion3Pipeline):
     r"""
     Adapted from: https://github.com/huggingface/diffusers/blob/v0.29.2/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py#L128
@@ -165,6 +262,93 @@ def __init__(
 
         self.to(self._device)
 
+    @classmethod
+    def _split_inputs_into_batches(
+        cls,
+        batch_size,
+        latents,
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ):
+        # Use torch.split to generate num_batches batches of size batch_size
+        latents_batches = list(torch.split(latents, batch_size))
+        prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
+
+        if negative_prompt_embeds is not None:
+            negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
+        if pooled_prompt_embeds is not None:
+            pooled_prompt_embeds_batches = list(torch.split(pooled_prompt_embeds, batch_size))
+        if negative_pooled_prompt_embeds is not None:
+            negative_pooled_prompt_embeds_batches = list(torch.split(negative_pooled_prompt_embeds, batch_size))
+
+        # If the last batch has fewer samples than batch_size, pad it with dummy samples
+        num_dummy_samples = 0
+        if latents_batches[-1].shape[0] < batch_size:
+            num_dummy_samples = batch_size - latents_batches[-1].shape[0]
+            # Pad latents_batches
+            sequence_to_stack = (latents_batches[-1],) + tuple(
+                torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+            )
+            latents_batches[-1] = torch.vstack(sequence_to_stack)
+            # Pad prompt_embeds_batches
+            sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
+                torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+            )
+            prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+            # Pad negative_prompt_embeds_batches if necessary
+            if negative_prompt_embeds is not None:
+                sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
+                    torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+                )
+                negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+            # Pad pooled_prompt_embeds_batches if necessary
+            if pooled_prompt_embeds is not None:
+                sequence_to_stack = (pooled_prompt_embeds_batches[-1],) + tuple(
+                    torch.zeros_like(pooled_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
+                )
+                pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+            # Pad negative_pooled_prompt_embeds_batches if necessary
+            if negative_pooled_prompt_embeds is not None:
+                sequence_to_stack = (negative_pooled_prompt_embeds_batches[-1],) + tuple(
+                    torch.zeros_like(negative_pooled_prompt_embeds_batches[-1][0][None, :])
+                    for _ in range(num_dummy_samples)
+                )
+                negative_pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
+
+        # Stack batches in the same tensor
+        latents_batches = torch.stack(latents_batches)
+
+        if negative_prompt_embeds is not None:
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate(
+                zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
+            ):
+                prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
+
+        prompt_embeds_batches = torch.stack(prompt_embeds_batches)
+
+        if pooled_prompt_embeds is not None:
+            if negative_pooled_prompt_embeds is not None:
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                for i, (negative_pooled_prompt_embeds_batch, pooled_prompt_embeds_batch) in enumerate(
+                    zip(negative_pooled_prompt_embeds_batches, pooled_prompt_embeds_batches[:])
+                ):
+                    pooled_prompt_embeds_batches[i] = torch.cat(
+                        [negative_pooled_prompt_embeds_batch, pooled_prompt_embeds_batch]
+                    )
+            pooled_prompt_embeds_batches = torch.stack(pooled_prompt_embeds_batches)
+        else:
+            pooled_prompt_embeds_batches = None
+
+        return latents_batches, prompt_embeds_batches, pooled_prompt_embeds_batches, num_dummy_samples
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -302,9 +486,39 @@ def __call__(
             [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
""" + import habana_frameworks.torch as ht import habana_frameworks.torch.core as htcore - with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast): + # Set dtype to BF16 only if --bf16 is used, else use device's default autocast precision + # When --bf16 is used, bf16_full_eval=True, which disables use_torch_autocast + with torch.autocast( + device_type="hpu", + enabled=self.gaudi_config.use_torch_autocast, + dtype=torch.bfloat16 if not self.gaudi_config.use_torch_autocast else None, + ): + quant_mode = kwargs.get("quant_mode", None) + if quant_mode == "measure" or quant_mode == "quantize": + import os + + quant_config_path = os.getenv("QUANT_CONFIG") + + if not quant_config_path: + raise ImportError( + "Error: QUANT_CONFIG path is not defined. Please define path to quantization configuration JSON file." + ) + elif not os.path.isfile(quant_config_path): + raise ImportError(f"Error: QUANT_CONFIG path '{quant_config_path}' is not valid") + + htcore.hpu_set_env() + from neural_compressor.torch.quantization import FP8Config, convert, prepare + + config = FP8Config.from_json_file(quant_config_path) + if config.measure: + self.transformer = prepare(self.transformer, config) + elif config.quantize: + self.transformer = convert(self.transformer, config) + htcore.hpu_initialize(self.transformer, mark_only_scales_as_const=True) + height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor @@ -333,14 +547,17 @@ def __call__( # 2. Define call parameters if prompt is not None and isinstance(prompt, str): - batch_size = 1 + num_prompts = 1 elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) + num_prompts = len(prompt) else: - batch_size = prompt_embeds.shape[0] + num_prompts = prompt_embeds.shape[0] + num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size) device = self._execution_device + lora_scale = kwargs.get("lora_scale", None) if kwargs is not None else None + ( prompt_embeds, negative_prompt_embeds, @@ -362,12 +579,9 @@ def __call__( clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -376,7 +590,7 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( - batch_size * num_images_per_prompt, + num_prompts * num_images_per_prompt, num_channels_latents, height, width, @@ -386,14 +600,6 @@ def __call__( latents, ) - # 5-1. Define call parameters - if prompt is not None and isinstance(prompt, str): - num_prompts = 1 - elif prompt is not None and isinstance(prompt, list): - num_prompts = len(prompt) - else: - num_prompts = prompt_embeds.shape[0] - num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size) logger.info( f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt," f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." 
@@ -402,40 +608,86 @@ def __call__(
                 logger.warning("The first two iterations are slower so it is recommended to feed more batches.")
 
             throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+            use_warmup_inference_steps = (
+                num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            )
 
-            t0 = time.time()
-            t1 = t0
+            if hasattr(self.scheduler, "set_begin_index"):
+                self.scheduler.set_begin_index()
 
             hb_profiler = HabanaProfile(
                 warmup=profiling_warmup_steps,
                 active=profiling_steps,
                 record_shapes=False,
             )
+            hb_profiler.start()
 
-            # 6. Denoising loop
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
-                for i, t in enumerate(timesteps):
-                    # because compilation occurs in the first two iterations
-                    if i == throughput_warmup_steps:
+            # 6. Split input data into batches (HPU-specific step)
+            latents_batches, text_embeddings_batches, pooled_prompt_embeddings_batches, num_dummy_samples = (
+                self._split_inputs_into_batches(
+                    batch_size,
+                    latents,
+                    prompt_embeds,
+                    negative_prompt_embeds,
+                    pooled_prompt_embeds,
+                    negative_pooled_prompt_embeds,
+                )
+            )
+
+            outputs = {
+                "images": [],
+            }
+
+            for block in self.transformer.transformer_blocks:
+                block.attn.processor = GaudiJointAttnProcessor2_0()
+            ht.hpu.synchronize()
+
+            t0 = time.time()
+            t1 = t0
+
+            # 7. Denoising loop
+            for j in range(num_batches):
+                latents_batch = latents_batches[0]
+                latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
+                text_embeddings_batch = text_embeddings_batches[0]
+                text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)
+                pooled_prompt_embeddings_batch = pooled_prompt_embeddings_batches[0]
+                pooled_prompt_embeddings_batches = torch.roll(pooled_prompt_embeddings_batches, shifts=-1, dims=0)
+
+                if hasattr(self.scheduler, "_init_step_index"):
+                    # Reset scheduler step index for next batch
+                    self.scheduler.timesteps = timesteps
+                    self.scheduler._init_step_index(timesteps[0])
+
+                # Throughput is calculated after warmup iterations
+                if j == throughput_warmup_steps:
+                    t1 = time.time()
+
+                for i in self.progress_bar(range(len(timesteps))):
+                    timestep = timesteps[0]
+                    timesteps = torch.roll(timesteps, shifts=-1, dims=0)
+
+                    if use_warmup_inference_steps and i == throughput_warmup_steps and j == num_batches - 1:
                         t1 = time.time()
 
                     if self.interrupt:
                         continue
 
                     # expand the latents if we are doing classifier free guidance
-                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                    latent_model_input = (
+                        torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
+                    )
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                    timestep = t.expand(latent_model_input.shape[0])
+                    timestep_batch = timestep.expand(latent_model_input.shape[0])
 
-                    noise_pred = self.transformer(
-                        hidden_states=latent_model_input,
-                        timestep=timestep,
-                        encoder_hidden_states=prompt_embeds,
-                        pooled_projections=pooled_prompt_embeds,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
+                    noise_pred = self.transformer_hpu(
+                        latent_model_input,
+                        timestep_batch,
+                        text_embeddings_batch,
+                        pooled_prompt_embeddings_batch,
+                        self.joint_attention_kwargs,
+                    )
 
                     # perform guidance
                     if self.do_classifier_free_guidance:
@@ -443,62 +695,173 @@ def __call__(
                         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                     # compute the previous noisy sample x_t -> x_t-1
-                    latents_dtype = latents.dtype
-                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-                    if latents.dtype != latents_dtype:
-                        if torch.backends.mps.is_available():
-                            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                            latents = latents.to(latents_dtype)
+                    latents_dtype = latents_batch.dtype
+                    latents_batch = self.scheduler.step(noise_pred, timestep, latents_batch, return_dict=False)[0]
 
                     if callback_on_step_end is not None:
                         callback_kwargs = {}
                         for k in callback_on_step_end_tensor_inputs:
                             callback_kwargs[k] = locals()[k]
-                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                        callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
 
-                        latents = callback_outputs.pop("latents", latents)
-                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                        negative_pooled_prompt_embeds = callback_outputs.pop(
-                            "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                        )
+                        latents_batch = callback_outputs.pop("latents", latents_batch)
 
-                    # call the callback, if provided
-                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                        progress_bar.update()
+                        _prompt_embeds = callback_outputs.pop("prompt_embeds", None)
+                        _negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", None)
+                        if _prompt_embeds is not None and _negative_prompt_embeds is not None:
+                            text_embeddings_batch = torch.cat([_negative_prompt_embeds, _prompt_embeds])
+                        _pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", None)
+                        _negative_pooled_prompt_embeds = callback_outputs.pop("negative_pooled_prompt_embeds", None)
+                        if _pooled_prompt_embeds is not None and _negative_pooled_prompt_embeds is not None:
+                            pooled_prompt_embeddings_batch = torch.cat(
+                                [_negative_pooled_prompt_embeds, _pooled_prompt_embeds]
+                            )
 
                     hb_profiler.step()
                     htcore.mark_step(sync=True)
 
-            hb_profiler.stop()
+                if output_type == "latent":
+                    image = latents_batch
 
-            t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps)
-            speed_metrics_prefix = "generation"
-            speed_measures = speed_metrics(
-                split=speed_metrics_prefix,
-                start_time=t0,
-                num_samples=num_batches * batch_size,
-                num_steps=num_batches * batch_size * num_inference_steps,
-                start_time_after_warmup=t1,
-            )
-            logger.info(f"Speed metrics: {speed_measures}")
 
-            if output_type == "latent":
-                image = latents
+                else:
+                    latents_batch = (latents_batch / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+                    image = self.vae.decode(latents_batch, return_dict=False)[0]
+                    image = self.image_processor.postprocess(image, output_type=output_type)
 
-            else:
-                latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+                outputs["images"].append(image)
+
+            # End of Denoising loop
+
+            hb_profiler.stop()
 
-                image = self.vae.decode(latents, return_dict=False)[0]
-                image = self.image_processor.postprocess(image, output_type=output_type)
+            ht.hpu.synchronize()
+            speed_metrics_prefix = "generation"
+            if use_warmup_inference_steps:
+                t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps)
+            speed_measures = speed_metrics(
+                split=speed_metrics_prefix,
+                start_time=t0,
+                num_samples=batch_size
+                if t1 == t0 or use_warmup_inference_steps
+                else (num_batches - throughput_warmup_steps) * batch_size,
+                num_steps=batch_size * num_inference_steps
+                if use_warmup_inference_steps
+                else (num_batches - throughput_warmup_steps) * batch_size * num_inference_steps,
+                start_time_after_warmup=t1,
+            )
+            logger.info(f"Speed metrics: {speed_measures}")
+
+            if quant_mode == "measure":
+                from neural_compressor.torch.quantization import finalize_calibration
+
+                finalize_calibration(self.transformer)
+
+            # 8. Output images
+            # Remove dummy generations if needed
+            if num_dummy_samples > 0:
+                outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]
+
+            # Process generated images
+            for i, image in enumerate(outputs["images"][:]):
+                if i == 0:
+                    outputs["images"].clear()
+
+                if output_type == "pil" and isinstance(image, list):
+                    outputs["images"] += image
+                elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray):
+                    if len(outputs["images"]) == 0:
+                        outputs["images"] = image
+                    else:
+                        outputs["images"] = np.concatenate((outputs["images"], image), axis=0)
+                else:
+                    if len(outputs["images"]) == 0:
+                        outputs["images"] = image
+                    else:
+                        outputs["images"] = torch.cat((outputs["images"], image), 0)
 
             # Offload all models
             self.maybe_free_model_hooks()
 
             if not return_dict:
-                return (image,)
+                return outputs["images"]
 
             return GaudiStableDiffusion3PipelineOutput(
-                images=image,
+                images=outputs["images"],
                 throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
             )
+
+    @torch.no_grad()
+    def transformer_hpu(
+        self,
+        latent_model_input,
+        timestep,
+        text_embeddings_batch,
+        pooled_prompt_embeddings_batch,
+        joint_attention_kwargs,
+    ):
+        if self.use_hpu_graphs:
+            return self.capture_replay(
+                latent_model_input,
+                timestep,
+                text_embeddings_batch,
+                pooled_prompt_embeddings_batch,
+                joint_attention_kwargs,
+            )
+        else:
+            return self.transformer(
+                hidden_states=latent_model_input,
+                timestep=timestep,
+                encoder_hidden_states=text_embeddings_batch,
+                pooled_projections=pooled_prompt_embeddings_batch,
+                joint_attention_kwargs=joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+    @torch.no_grad()
+    def capture_replay(
+        self,
+        latent_model_input,
+        timestep,
+        encoder_hidden_states,
+        pooled_prompt_embeddings_batch,
+        joint_attention_kwargs,
+    ):
+        inputs = [
+            latent_model_input,
+            timestep,
+            encoder_hidden_states,
+            pooled_prompt_embeddings_batch,
+            joint_attention_kwargs,
+        ]
+        h = self.ht.hpu.graphs.input_hash(inputs)
+        cached = self.cache.get(h)
+
+        if cached is None:
+            # Capture the graph and cache it
+            with self.ht.hpu.stream(self.hpu_stream):
+                graph = self.ht.hpu.HPUGraph()
+                graph.capture_begin()
+
+                outputs = self.transformer(
+                    hidden_states=inputs[0],
+                    timestep=inputs[1],
+                    encoder_hidden_states=inputs[2],
+                    pooled_projections=inputs[3],
+                    joint_attention_kwargs=inputs[4],
+                    return_dict=False,
+                )[0]
+
+                graph.capture_end()
+                graph_inputs = inputs
+                graph_outputs = outputs
+                self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+            return outputs
+
+        # Replay the cached graph with updated inputs
+        self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+        cached.graph.replay()
+        self.ht.core.hpu.default_stream().synchronize()
+
+        return cached.graph_outputs
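For reference, the new `lora_scale` and `quant_mode` keyword arguments are forwarded from `text_to_image_generation.py` to `GaudiStableDiffusion3Pipeline.__call__` through `kwargs_call`. Below is a minimal sketch of driving the pipeline directly from Python; it assumes the class is exported from `optimum.habana.diffusers` like the other Gaudi pipelines, and the `lora_scale` value and output filename are illustrative.

```python
# Minimal usage sketch (see assumptions above), mirroring what
# examples/stable-diffusion/text_to_image_generation.py builds via kwargs_call.
import torch

from optimum.habana.diffusers import GaudiStableDiffusion3Pipeline

pipeline = GaudiStableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.bfloat16,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)

outputs = pipeline(
    prompt="Sailing ship painting by Van Gogh",
    num_images_per_prompt=10,
    batch_size=1,
    num_inference_steps=28,
    lora_scale=0.7,  # illustrative value; forwarded to encode_prompt for loaded LoRA layers
    # quant_mode="quantize",  # requires QUANT_CONFIG to point to quantize_config.json
)
outputs.images[0].save("sd3_ship.png")  # illustrative output path
```

With `use_hpu_graphs=True`, each denoising step goes through `transformer_hpu`, which captures the transformer once per input shape and replays the cached HPU graph on subsequent steps.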
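The `quant_mode` handling in `__call__` follows Intel Neural Compressor's FP8 flow: `prepare` attaches observers in measure mode, `convert` swaps supported ops to FP8 in quantize mode, and `finalize_calibration` dumps the collected statistics. A condensed, standalone sketch is shown below; the `apply_fp8` wrapper and its `calibrate` callback are illustrative, while the individual calls are the ones used in the diff.

```python
# Condensed sketch of the FP8 measure/quantize flow used in __call__ above.
# The apply_fp8/calibrate wrapper is illustrative; the individual calls mirror the diff.
import habana_frameworks.torch.core as htcore
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare


def apply_fp8(transformer, quant_config_path, calibrate=None):
    htcore.hpu_set_env()
    config = FP8Config.from_json_file(quant_config_path)
    if config.measure:
        # Measure mode: attach observers, run calibration batches, then dump stats
        # to config.dump_stats_path (see measure_config.json).
        transformer = prepare(transformer, config)
        if calibrate is not None:
            calibrate(transformer)
        finalize_calibration(transformer)
    elif config.quantize:
        # Quantize mode: convert supported ops to FP8 using the dumped statistics.
        transformer = convert(transformer, config)
        htcore.hpu_initialize(transformer, mark_only_scales_as_const=True)
    return transformer
```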