From 029ad3319eccf1217e191335f872ef9f0148c903 Mon Sep 17 00:00:00 2001 From: deepak-gowda-narayana Date: Mon, 6 Jan 2025 20:45:35 +0000 Subject: [PATCH 1/6] Optimized SD3 pipeline: * HPU graphs enabled * Batching for inference enabled * Fused SDPA integrated * FP8 quantization enabled Co-authored-by: Daniel Socek --- examples/stable-diffusion/README.md | 45 +- .../stable-diffusion-3/measure_config.json | 5 + .../stable-diffusion-3/quantize_config.json | 6 + .../pipeline_stable_diffusion_3.py | 683 +++++++++++++----- 4 files changed, 564 insertions(+), 175 deletions(-) create mode 100644 examples/stable-diffusion/quantization/stable-diffusion-3/measure_config.json create mode 100644 examples/stable-diffusion/quantization/stable-diffusion-3/quantize_config.json diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index 5b94202fba..3adca760da 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -305,7 +305,6 @@ huggingface-cli login Here is how to generate SD3 images with a single prompt: ```bash -PT_HPU_MAX_COMPOUND_OP_SIZE=1 \ python text_to_image_generation.py \ --model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \ --prompts "Sailing ship painting by Van Gogh" \ @@ -321,9 +320,47 @@ python text_to_image_generation.py \ --bf16 ``` -> [!NOTE] -> For improved performance of the SD3 pipeline on Gaudi, it is recommended to configure the environment -> by setting PT_HPU_MAX_COMPOUND_OP_SIZE to 1. +This model can also be quantized with some ops running in FP8 precision. + +Before quantization, run stats collection using measure mode: + +```bash +QUANT_CONFIG=quantization/stable-diffusion-3/measure_config.json \ +python text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \ + --prompts "Sailing ship painting by Van Gogh" \ + --num_images_per_prompt 10 \ + --batch_size 1 \ + --num_inference_steps 28 \ + --image_save_dir /tmp/stable_diffusion_3_images \ + --scheduler default \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --sdp_on_bf16 \ + --bf16 + --quant_mode measure +``` + +After stats collection, here is how to run SD3 in quantization mode: + +```bash +QUANT_CONFIG=quantization/stable-diffusion-3/quantize_config.json \ +python text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers \ + --prompts "Sailing ship painting by Van Gogh" \ + --num_images_per_prompt 10 \ + --batch_size 1 \ + --num_inference_steps 28 \ + --image_save_dir /tmp/stable_diffusion_3_images \ + --scheduler default \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --sdp_on_bf16 \ + --bf16 + --quant_mode quantize +``` ### FLUX.1 diff --git a/examples/stable-diffusion/quantization/stable-diffusion-3/measure_config.json b/examples/stable-diffusion/quantization/stable-diffusion-3/measure_config.json new file mode 100644 index 0000000000..ebf3baa292 --- /dev/null +++ b/examples/stable-diffusion/quantization/stable-diffusion-3/measure_config.json @@ -0,0 +1,5 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "dump_stats_path": "quantization/stable-diffusion-3/measure_all/fp8" +} \ No newline at end of file diff --git a/examples/stable-diffusion/quantization/stable-diffusion-3/quantize_config.json b/examples/stable-diffusion/quantization/stable-diffusion-3/quantize_config.json new file mode 100644 index 0000000000..1fa98ebce0 --- /dev/null +++ 
b/examples/stable-diffusion/quantization/stable-diffusion-3/quantize_config.json @@ -0,0 +1,6 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "scale_method": "maxabs_hw_opt_weight", + "dump_stats_path": "quantization/stable-diffusion-3/measure_all/fp8" +} \ No newline at end of file diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index bcd4b6f172..9f178deaa4 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -20,6 +20,8 @@ import numpy as np import PIL import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import SD3Transformer2DModel from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline @@ -76,6 +78,97 @@ class GaudiStableDiffusion3PipelineOutput(BaseOutput): """ +# ToDo: Look into FusedJointAttnProcessor2_0 usage for sd3 pipeline, and check its perf using fused sdpa +class GaudiJointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. 
+ if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + from habana_frameworks.torch.hpex.kernels import FusedSDPA + + hidden_states = FusedSDPA.apply(query, key, value, None, 0.0, False, None, "fast", None) + + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class GaudiStableDiffusion3Pipeline(GaudiDiffusionPipeline, StableDiffusion3Pipeline): r""" Adapted from: https://github.com/huggingface/diffusers/blob/v0.29.2/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py#L128 @@ -165,6 +258,93 @@ def __init__( self.to(self._device) + @classmethod + def _split_inputs_into_batches( + cls, + batch_size, + latents, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ): + # Use torch.split to generate num_batches batches of size batch_size + latents_batches = list(torch.split(latents, batch_size)) + prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size)) + + if negative_prompt_embeds is not None: + negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size)) + if pooled_prompt_embeds is not None: + pooled_prompt_embeds_batches = list(torch.split(pooled_prompt_embeds, batch_size)) + if negative_pooled_prompt_embeds is not None: + negative_pooled_prompt_embeds_batches = list(torch.split(negative_pooled_prompt_embeds, batch_size)) + + # If the last batch has less samples than batch_size, pad it with dummy samples + num_dummy_samples = 0 + if latents_batches[-1].shape[0] < batch_size: + num_dummy_samples = batch_size - latents_batches[-1].shape[0] + # Pad latents_batches + sequence_to_stack = (latents_batches[-1],) + tuple( + torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) 
+ latents_batches[-1] = torch.vstack(sequence_to_stack) + # Pad prompt_embeds_batches + sequence_to_stack = (prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + # Pad negative_prompt_embeds_batches if necessary + if negative_prompt_embeds is not None: + sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + # Pad add_text_embeds_batches if necessary + if pooled_prompt_embeds is not None: + sequence_to_stack = (pooled_prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(pooled_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + # Pad negative_pooled_prompt_embeds_batches if necessary + if negative_pooled_prompt_embeds is not None: + sequence_to_stack = (negative_pooled_prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(negative_pooled_prompt_embeds_batches[-1][0][None, :]) + for _ in range(num_dummy_samples) + ) + negative_pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + + # Stack batches in the same tensor + latents_batches = torch.stack(latents_batches) + # if self.do_classifier_free_guidance: + + if negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate( + zip(negative_prompt_embeds_batches, prompt_embeds_batches[:]) + ): + prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch]) + + prompt_embeds_batches = torch.stack(prompt_embeds_batches) + + if pooled_prompt_embeds is not None: + if negative_pooled_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + for i, (negative_pooled_prompt_embeds_batch, pooled_prompt_embeds_batch) in enumerate( + zip(negative_pooled_prompt_embeds_batches, pooled_prompt_embeds_batches[:]) + ): + pooled_prompt_embeds_batches[i] = torch.cat( + [negative_pooled_prompt_embeds_batch, pooled_prompt_embeds_batch] + ) + pooled_prompt_embeds_batches = torch.stack(pooled_prompt_embeds_batches) + else: + pooled_prompt_embeds_batches = None + + return latents_batches, prompt_embeds_batches, pooled_prompt_embeds_batches, num_dummy_samples + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -302,203 +482,364 @@ def __call__( [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ + import habana_frameworks.torch as ht import habana_frameworks.torch.core as htcore - with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast): - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. 
Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - prompt_3, - height, - width, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) + quant_mode = kwargs.get("quant_mode", None) + if quant_mode == "measure" or quant_mode == "quantize": + import os + + quant_config_path = os.getenv("QUANT_CONFIG") + htcore.hpu_set_env() + from neural_compressor.torch.quantization import FP8Config, convert, prepare + + config = FP8Config.from_json_file(quant_config_path) + if config.measure: + self.transformer = prepare(self.transformer, config) + elif config.quantize: + self.transformer = convert(self.transformer, config) + htcore.hpu_initialize(self.transformer, mark_only_scales_as_const=True) + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + num_prompts = 1 + elif prompt is not None and isinstance(prompt, list): + num_prompts = len(prompt) + else: + num_prompts = prompt_embeds.shape[0] + num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size) + + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] + # 4. 
Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + num_prompts * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) - device = self._execution_device + logger.info( + f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt," + f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." + ) + if num_batches < 3: + logger.warning("The first two iterations are slower so it is recommended to feed more batches.") - ( + throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3) + use_warmup_inference_steps = ( + num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps + ) + + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index() + + hb_profiler = HabanaProfile( + warmup=profiling_warmup_steps, + active=profiling_steps, + record_shapes=False, + ) + + hb_profiler.start() + + # 6. Split Input data to batches (HPU-specific step) + latents_batches, text_embeddings_batches, pooled_prompt_embeddings_batches, num_dummy_samples = ( + self._split_inputs_into_batches( + batch_size, + latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_3=prompt_3, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - do_classifier_free_guidance=self.do_classifier_free_guidance, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - device=device, - clip_skip=self.clip_skip, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, ) + ) - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + outputs = { + "images": [], + } - # 5-1. Define call parameters - if prompt is not None and isinstance(prompt, str): - num_prompts = 1 - elif prompt is not None and isinstance(prompt, list): - num_prompts = len(prompt) - else: - num_prompts = prompt_embeds.shape[0] - num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size) - logger.info( - f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt," - f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." 
- ) - if num_batches < 3: - logger.warning("The first two iterations are slower so it is recommended to feed more batches.") + for block in self.transformer.transformer_blocks: + block.attn.processor = GaudiJointAttnProcessor2_0() + ht.hpu.synchronize() - throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3) + t0 = time.time() + t1 = t0 - t0 = time.time() - t1 = t0 + # 7. Denoising loop + for j in range(num_batches): + latents_batch = latents_batches[0] + latents_batches = torch.roll(latents_batches, shifts=-1, dims=0) + text_embeddings_batch = text_embeddings_batches[0] + text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0) + pooled_prompt_embeddings_batch = pooled_prompt_embeddings_batches[0] + pooled_prompt_embeddings_batches = torch.roll(pooled_prompt_embeddings_batches, shifts=-1, dims=0) - hb_profiler = HabanaProfile( - warmup=profiling_warmup_steps, - active=profiling_steps, - record_shapes=False, - ) - hb_profiler.start() - - # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # because compilation occurs in the first two iterations - if i == throughput_warmup_steps: - t1 = time.time() - - if self.interrupt: - continue - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) + if hasattr(self.scheduler, "_init_step_index"): + # Reset scheduler step index for next batch + self.scheduler.timesteps = timesteps + self.scheduler._init_step_index(timesteps[0]) - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() + # Throughput is calculated after warmup iterations + if j == throughput_warmup_steps: + t1 = time.time() - hb_profiler.step() - htcore.mark_step(sync=True) + for i in self.progress_bar(range(len(timesteps))): + timestep = timesteps[0] + timesteps = torch.roll(timesteps, shifts=-1, dims=0) - hb_profiler.stop() + if use_warmup_inference_steps and i == throughput_warmup_steps and j == num_batches - 1: + t1 = time.time() - t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps) - speed_metrics_prefix = "generation" - speed_measures = speed_metrics( - split=speed_metrics_prefix, - start_time=t0, - num_samples=num_batches * batch_size, - num_steps=num_batches * batch_size * num_inference_steps, - start_time_after_warmup=t1, + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch + ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_batch = timestep.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer_hpu( + latent_model_input, + timestep_batch, + text_embeddings_batch, + pooled_prompt_embeddings_batch, + self.joint_attention_kwargs, ) - logger.info(f"Speed metrics: {speed_measures}") + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents_batch.dtype + latents_batch = self.scheduler.step(noise_pred, timestep, latents_batch, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + + latents_batch = callback_outputs.pop("latents", latents_batch) + + _prompt_embeds = callback_outputs.pop("prompt_embeds", None) + _negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", None) + if _prompt_embeds is not None and _negative_prompt_embeds is not None: + text_embeddings_batch = torch.cat([_negative_prompt_embeds, _prompt_embeds]) + _pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", None) + _negative_pooled_prompt_embeds = callback_outputs.pop("negative_pooled_prompt_embeds", None) + 
if _pooled_prompt_embeds is not None and _negative_pooled_prompt_embeds is not None: + pooled_prompt_embeddings_batch = torch.cat( + [_negative_pooled_prompt_embeds, _pooled_prompt_embeds] + ) + + hb_profiler.step() + htcore.mark_step(sync=True) + if output_type == "latent": - image = latents + image = latents_batch else: - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - - image = self.vae.decode(latents, return_dict=False)[0] + latents_batch = (latents_batch / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents_batch, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) - # Offload all models - self.maybe_free_model_hooks() + outputs["images"].append(image) + + # End of Denoising loop + + hb_profiler.stop() + + ht.hpu.synchronize() + speed_metrics_prefix = "generation" + if use_warmup_inference_steps: + t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps) + speed_measures = speed_metrics( + split=speed_metrics_prefix, + start_time=t0, + num_samples=batch_size + if t1 == t0 or use_warmup_inference_steps + else (num_batches - throughput_warmup_steps) * batch_size, + num_steps=batch_size * num_inference_steps + if use_warmup_inference_steps + else (num_batches - throughput_warmup_steps) * batch_size * num_inference_steps, + start_time_after_warmup=t1, + ) + logger.info(f"Speed metrics: {speed_measures}") + + if quant_mode == "measure": + from neural_compressor.torch.quantization import finalize_calibration + + finalize_calibration(self.transformer) - if not return_dict: - return (image,) + # 8 Output Images + # Remove dummy generations if needed + if num_dummy_samples > 0: + outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples] - return GaudiStableDiffusion3PipelineOutput( - images=image, - throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], + # Process generated images + for i, image in enumerate(outputs["images"][:]): + if i == 0: + outputs["images"].clear() + + # image = self.image_processor.postprocess(image, output_type=output_type) + + if output_type == "pil" and isinstance(image, list): + outputs["images"] += image + elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray): + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = np.concatenate((outputs["images"], image), axis=0) + else: + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = torch.cat((outputs["images"], image), 0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return outputs["images"] + + return GaudiStableDiffusion3PipelineOutput( + images=outputs["images"], + throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], + ) + + @torch.no_grad() + def transformer_hpu( + self, + latent_model_input, + timestep, + text_embeddings_batch, + pooled_prompt_embeddings_batch, + joint_attention_kwargs, + ): + if self.use_hpu_graphs: + return self.capture_replay( + latent_model_input, + timestep, + text_embeddings_batch, + pooled_prompt_embeddings_batch, + joint_attention_kwargs, ) + else: + return self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=text_embeddings_batch, + pooled_projections=pooled_prompt_embeddings_batch, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=False, + )[0] + + @torch.no_grad() + def capture_replay( + self, + 
latent_model_input, + timestep, + encoder_hidden_states, + pooled_prompt_embeddings_batch, + joint_attention_kwargs, + ): + inputs = [ + latent_model_input, + timestep, + encoder_hidden_states, + pooled_prompt_embeddings_batch, + joint_attention_kwargs, + ] + h = self.ht.hpu.graphs.input_hash(inputs) + cached = self.cache.get(h) + + if cached is None: + # Capture the graph and cache it + with self.ht.hpu.stream(self.hpu_stream): + graph = self.ht.hpu.HPUGraph() + graph.capture_begin() + + outputs = self.transformer( + hidden_states=inputs[0], + timestep=inputs[1], + encoder_hidden_states=inputs[2], + pooled_projections=inputs[3], + joint_attention_kwargs=inputs[4], + return_dict=False, + )[0] + + graph.capture_end() + graph_inputs = inputs + graph_outputs = outputs + self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph) + return outputs + + # Replay the cached graph with updated inputs + self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs) + cached.graph.replay() + self.ht.core.hpu.default_stream().synchronize() + + return cached.graph_outputs From 5755dbd4f1bd09bca6e11dc82ecdc57a2b0b9a56 Mon Sep 17 00:00:00 2001 From: deepak-gowda-narayana Date: Tue, 21 Jan 2025 20:57:19 +0000 Subject: [PATCH 2/6] Add lora_scale support --- .../text_to_image_generation.py | 9 +++++++++ .../pipeline_stable_diffusion_3.py | 19 ++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index 8fc4e0de4c..f42f7ca981 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -305,6 +305,12 @@ def main(): default=None, help="The file with prompts (for large number of images generation).", ) + parser.add_argument( + "--lora_scale", + type=float, + default=None, + help="A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.", + ) args = parser.parse_args() if args.optimize and not args.use_habana: @@ -379,6 +385,9 @@ def main(): if args.throughput_warmup_steps is not None: kwargs_call["throughput_warmup_steps"] = args.throughput_warmup_steps + + if args.lora_scale is not None: + kwargs_call["lora_scale"] = args.lora_scale negative_prompts = args.negative_prompts if args.distributed: diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 9f178deaa4..da40fefd19 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -80,7 +80,11 @@ class GaudiStableDiffusion3PipelineOutput(BaseOutput): # ToDo: Look into FusedJointAttnProcessor2_0 usage for sd3 pipeline, and check its perf using fused sdpa class GaudiJointAttnProcessor2_0: - """Attention processor used typically in processing the SD3-like self-attention projections.""" + """Attention processor used typically in processing the SD3-like self-attention projections. + Copied from JointAttnProcessor2_0.forward: https://github.com/huggingface/diffusers/blob/89e4d6219805975bd7d253a267e1951badc9f1c0/src/diffusers/models/attention_processor.py + The only differences are: + - applied Fused SDPA from Habana's framework. 
+ """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -490,6 +494,14 @@ def __call__( import os quant_config_path = os.getenv("QUANT_CONFIG") + + if not quant_config_path: + raise ImportError( + "Error: QUANT_CONFIG path is not defined. Please define path to quantization configuration JSON file." + ) + elif not os.path.isfile(quant_config_path): + raise ImportError(f"Error: QUANT_CONFIG path '{quant_config_path}' is not valid") + htcore.hpu_set_env() from neural_compressor.torch.quantization import FP8Config, convert, prepare @@ -537,6 +549,10 @@ def __call__( device = self._execution_device + lora_scale = ( + kwargs.get("lora_scale", None) if kwargs is not None else None + ) + ( prompt_embeds, negative_prompt_embeds, @@ -558,6 +574,7 @@ def __call__( clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) # 4. Prepare timesteps From f9751612d9994e21901ac6d1c450d7c9876e239a Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Tue, 21 Jan 2025 22:44:19 +0000 Subject: [PATCH 3/6] Fix bf16 sdp and autocast context Signed-off-by: Daniel Socek --- .../pipeline_stable_diffusion_3.py | 526 +++++++++--------- 1 file changed, 267 insertions(+), 259 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index da40fefd19..1641cdd68d 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -489,303 +489,311 @@ def __call__( import habana_frameworks.torch as ht import habana_frameworks.torch.core as htcore - quant_mode = kwargs.get("quant_mode", None) - if quant_mode == "measure" or quant_mode == "quantize": - import os - - quant_config_path = os.getenv("QUANT_CONFIG") - - if not quant_config_path: - raise ImportError( - "Error: QUANT_CONFIG path is not defined. Please define path to quantization configuration JSON file." - ) - elif not os.path.isfile(quant_config_path): - raise ImportError(f"Error: QUANT_CONFIG path '{quant_config_path}' is not valid") - - htcore.hpu_set_env() - from neural_compressor.torch.quantization import FP8Config, convert, prepare - - config = FP8Config.from_json_file(quant_config_path) - if config.measure: - self.transformer = prepare(self.transformer, config) - elif config.quantize: - self.transformer = convert(self.transformer, config) - htcore.hpu_initialize(self.transformer, mark_only_scales_as_const=True) - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - prompt_3, - height, - width, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - - # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): - num_prompts = 1 - elif prompt is not None and isinstance(prompt, list): - num_prompts = len(prompt) - else: - num_prompts = prompt_embeds.shape[0] - num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size) - - device = self._execution_device - - lora_scale = ( - kwargs.get("lora_scale", None) if kwargs is not None else None - ) - - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_3=prompt_3, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - do_classifier_free_guidance=self.do_classifier_free_guidance, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - device=device, - clip_skip=self.clip_skip, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - num_prompts * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + # Set dtype to BF16 only if --bf16 is used, else use device's default autocast precision + # When --bf16 is used, bf16_full_eval=True, which disables use_torch_autocast + with torch.autocast( + device_type="hpu", + enabled=self.gaudi_config.use_torch_autocast, + dtype=torch.bfloat16 if not self.gaudi_config.use_torch_autocast else None + ): + + quant_mode = kwargs.get("quant_mode", None) + if quant_mode == "measure" or quant_mode == "quantize": + import os + + quant_config_path = os.getenv("QUANT_CONFIG") + + if not quant_config_path: + raise ImportError( + "Error: QUANT_CONFIG path is not defined. Please define path to quantization configuration JSON file." + ) + elif not os.path.isfile(quant_config_path): + raise ImportError(f"Error: QUANT_CONFIG path '{quant_config_path}' is not valid") + + htcore.hpu_set_env() + from neural_compressor.torch.quantization import FP8Config, convert, prepare + + config = FP8Config.from_json_file(quant_config_path) + if config.measure: + self.transformer = prepare(self.transformer, config) + elif config.quantize: + self.transformer = convert(self.transformer, config) + htcore.hpu_initialize(self.transformer, mark_only_scales_as_const=True) + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) - logger.info( - f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt," - f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." - ) - if num_batches < 3: - logger.warning("The first two iterations are slower so it is recommended to feed more batches.") + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False - throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3) - use_warmup_inference_steps = ( - num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps - ) - - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index() + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + num_prompts = 1 + elif prompt is not None and isinstance(prompt, list): + num_prompts = len(prompt) + else: + num_prompts = prompt_embeds.shape[0] + num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size) - hb_profiler = HabanaProfile( - warmup=profiling_warmup_steps, - active=profiling_steps, - record_shapes=False, - ) + device = self._execution_device - hb_profiler.start() + lora_scale = ( + kwargs.get("lora_scale", None) if kwargs is not None else None + ) - # 6. Split Input data to batches (HPU-specific step) - latents_batches, text_embeddings_batches, pooled_prompt_embeddings_batches, num_dummy_samples = ( - self._split_inputs_into_batches( - batch_size, - latents, + ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) - ) - - outputs = { - "images": [], - } - for block in self.transformer.transformer_blocks: - block.attn.processor = GaudiJointAttnProcessor2_0() - ht.hpu.synchronize() + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + num_prompts * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) - t0 = time.time() - t1 = t0 + logger.info( + f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt," + f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." + ) + if num_batches < 3: + logger.warning("The first two iterations are slower so it is recommended to feed more batches.") - # 7. Denoising loop - for j in range(num_batches): - latents_batch = latents_batches[0] - latents_batches = torch.roll(latents_batches, shifts=-1, dims=0) - text_embeddings_batch = text_embeddings_batches[0] - text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0) - pooled_prompt_embeddings_batch = pooled_prompt_embeddings_batches[0] - pooled_prompt_embeddings_batches = torch.roll(pooled_prompt_embeddings_batches, shifts=-1, dims=0) + throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3) + use_warmup_inference_steps = ( + num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps + ) - if hasattr(self.scheduler, "_init_step_index"): - # Reset scheduler step index for next batch - self.scheduler.timesteps = timesteps - self.scheduler._init_step_index(timesteps[0]) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index() - # Throughput is calculated after warmup iterations - if j == throughput_warmup_steps: - t1 = time.time() + hb_profiler = HabanaProfile( + warmup=profiling_warmup_steps, + active=profiling_steps, + record_shapes=False, + ) - for i in self.progress_bar(range(len(timesteps))): - timestep = timesteps[0] - timesteps = torch.roll(timesteps, shifts=-1, dims=0) + hb_profiler.start() + + # 6. Split Input data to batches (HPU-specific step) + latents_batches, text_embeddings_batches, pooled_prompt_embeddings_batches, num_dummy_samples = ( + self._split_inputs_into_batches( + batch_size, + latents, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + ) - if use_warmup_inference_steps and i == throughput_warmup_steps and j == num_batches - 1: + outputs = { + "images": [], + } + + for block in self.transformer.transformer_blocks: + block.attn.processor = GaudiJointAttnProcessor2_0() + ht.hpu.synchronize() + + t0 = time.time() + t1 = t0 + + # 7. 
Denoising loop + for j in range(num_batches): + latents_batch = latents_batches[0] + latents_batches = torch.roll(latents_batches, shifts=-1, dims=0) + text_embeddings_batch = text_embeddings_batches[0] + text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0) + pooled_prompt_embeddings_batch = pooled_prompt_embeddings_batches[0] + pooled_prompt_embeddings_batches = torch.roll(pooled_prompt_embeddings_batches, shifts=-1, dims=0) + + if hasattr(self.scheduler, "_init_step_index"): + # Reset scheduler step index for next batch + self.scheduler.timesteps = timesteps + self.scheduler._init_step_index(timesteps[0]) + + # Throughput is calculated after warmup iterations + if j == throughput_warmup_steps: t1 = time.time() - if self.interrupt: - continue + for i in self.progress_bar(range(len(timesteps))): + timestep = timesteps[0] + timesteps = torch.roll(timesteps, shifts=-1, dims=0) - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch - ) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep_batch = timestep.expand(latent_model_input.shape[0]) - - noise_pred = self.transformer_hpu( - latent_model_input, - timestep_batch, - text_embeddings_batch, - pooled_prompt_embeddings_batch, - self.joint_attention_kwargs, - ) + if use_warmup_inference_steps and i == throughput_warmup_steps and j == num_batches - 1: + t1 = time.time() - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + if self.interrupt: + continue - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents_batch.dtype - latents_batch = self.scheduler.step(noise_pred, timestep, latents_batch, return_dict=False)[0] + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch + ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_batch = timestep.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer_hpu( + latent_model_input, + timestep_batch, + text_embeddings_batch, + pooled_prompt_embeddings_batch, + self.joint_attention_kwargs, + ) - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - latents_batch = callback_outputs.pop("latents", latents_batch) + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents_batch.dtype + latents_batch = self.scheduler.step(noise_pred, timestep, latents_batch, return_dict=False)[0] - _prompt_embeds = callback_outputs.pop("prompt_embeds", None) - _negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", None) - if _prompt_embeds is not None and _negative_prompt_embeds is not None: - text_embeddings_batch = torch.cat([_negative_prompt_embeds, _prompt_embeds]) - _pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", None) - _negative_pooled_prompt_embeds = 
callback_outputs.pop("negative_pooled_prompt_embeds", None) - if _pooled_prompt_embeds is not None and _negative_pooled_prompt_embeds is not None: - pooled_prompt_embeddings_batch = torch.cat( - [_negative_pooled_prompt_embeds, _pooled_prompt_embeds] - ) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) - hb_profiler.step() - htcore.mark_step(sync=True) + latents_batch = callback_outputs.pop("latents", latents_batch) - if output_type == "latent": - image = latents_batch + _prompt_embeds = callback_outputs.pop("prompt_embeds", None) + _negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", None) + if _prompt_embeds is not None and _negative_prompt_embeds is not None: + text_embeddings_batch = torch.cat([_negative_prompt_embeds, _prompt_embeds]) + _pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", None) + _negative_pooled_prompt_embeds = callback_outputs.pop("negative_pooled_prompt_embeds", None) + if _pooled_prompt_embeds is not None and _negative_pooled_prompt_embeds is not None: + pooled_prompt_embeddings_batch = torch.cat( + [_negative_pooled_prompt_embeds, _pooled_prompt_embeds] + ) - else: - latents_batch = (latents_batch / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents_batch, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - outputs["images"].append(image) - - # End of Denoising loop - - hb_profiler.stop() - - ht.hpu.synchronize() - speed_metrics_prefix = "generation" - if use_warmup_inference_steps: - t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps) - speed_measures = speed_metrics( - split=speed_metrics_prefix, - start_time=t0, - num_samples=batch_size - if t1 == t0 or use_warmup_inference_steps - else (num_batches - throughput_warmup_steps) * batch_size, - num_steps=batch_size * num_inference_steps - if use_warmup_inference_steps - else (num_batches - throughput_warmup_steps) * batch_size * num_inference_steps, - start_time_after_warmup=t1, - ) - logger.info(f"Speed metrics: {speed_measures}") + hb_profiler.step() + htcore.mark_step(sync=True) - if quant_mode == "measure": - from neural_compressor.torch.quantization import finalize_calibration + if output_type == "latent": + image = latents_batch - finalize_calibration(self.transformer) + else: + latents_batch = (latents_batch / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents_batch, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + outputs["images"].append(image) + + # End of Denoising loop + + hb_profiler.stop() + + ht.hpu.synchronize() + speed_metrics_prefix = "generation" + if use_warmup_inference_steps: + t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps) + speed_measures = speed_metrics( + split=speed_metrics_prefix, + start_time=t0, + num_samples=batch_size + if t1 == t0 or use_warmup_inference_steps + else (num_batches - throughput_warmup_steps) * batch_size, + num_steps=batch_size * num_inference_steps + if use_warmup_inference_steps + else (num_batches - throughput_warmup_steps) * batch_size * num_inference_steps, + start_time_after_warmup=t1, + ) + logger.info(f"Speed metrics: {speed_measures}") - # 8 Output Images - # Remove 
dummy generations if needed - if num_dummy_samples > 0: - outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples] + if quant_mode == "measure": + from neural_compressor.torch.quantization import finalize_calibration - # Process generated images - for i, image in enumerate(outputs["images"][:]): - if i == 0: - outputs["images"].clear() + finalize_calibration(self.transformer) - # image = self.image_processor.postprocess(image, output_type=output_type) + # 8 Output Images + # Remove dummy generations if needed + if num_dummy_samples > 0: + outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples] - if output_type == "pil" and isinstance(image, list): - outputs["images"] += image - elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray): - if len(outputs["images"]) == 0: - outputs["images"] = image - else: - outputs["images"] = np.concatenate((outputs["images"], image), axis=0) - else: - if len(outputs["images"]) == 0: - outputs["images"] = image + # Process generated images + for i, image in enumerate(outputs["images"][:]): + if i == 0: + outputs["images"].clear() + + # image = self.image_processor.postprocess(image, output_type=output_type) + + if output_type == "pil" and isinstance(image, list): + outputs["images"] += image + elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray): + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = np.concatenate((outputs["images"], image), axis=0) else: - outputs["images"] = torch.cat((outputs["images"], image), 0) + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = torch.cat((outputs["images"], image), 0) - # Offload all models - self.maybe_free_model_hooks() + # Offload all models + self.maybe_free_model_hooks() - if not return_dict: - return outputs["images"] + if not return_dict: + return outputs["images"] - return GaudiStableDiffusion3PipelineOutput( - images=outputs["images"], - throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], - ) + return GaudiStableDiffusion3PipelineOutput( + images=outputs["images"], + throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], + ) @torch.no_grad() def transformer_hpu( From 3df23533990c88c352771b814c4572ad5606d7d4 Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Wed, 29 Jan 2025 17:50:26 +0000 Subject: [PATCH 4/6] Rebase and fix style Signed-off-by: Daniel Socek --- examples/stable-diffusion/text_to_image_generation.py | 2 +- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index f42f7ca981..b32fc5c3f6 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -385,7 +385,7 @@ def main(): if args.throughput_warmup_steps is not None: kwargs_call["throughput_warmup_steps"] = args.throughput_warmup_steps - + if args.lora_scale is not None: kwargs_call["lora_scale"] = args.lora_scale diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 1641cdd68d..731e0434bf 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -494,9 +494,8 @@ 
def __call__( with torch.autocast( device_type="hpu", enabled=self.gaudi_config.use_torch_autocast, - dtype=torch.bfloat16 if not self.gaudi_config.use_torch_autocast else None + dtype=torch.bfloat16 if not self.gaudi_config.use_torch_autocast else None, ): - quant_mode = kwargs.get("quant_mode", None) if quant_mode == "measure" or quant_mode == "quantize": import os @@ -557,9 +556,7 @@ def __call__( device = self._execution_device - lora_scale = ( - kwargs.get("lora_scale", None) if kwargs is not None else None - ) + lora_scale = kwargs.get("lora_scale", None) if kwargs is not None else None ( prompt_embeds, From a6f84e8c89d251a5371b1fd807b66e9a94224b0e Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Thu, 30 Jan 2025 13:24:34 +0100 Subject: [PATCH 5/6] Update examples/stable-diffusion/README.md Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- examples/stable-diffusion/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index 3adca760da..a5f8d611be 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -338,7 +338,7 @@ python text_to_image_generation.py \ --use_hpu_graphs \ --gaudi_config Habana/stable-diffusion \ --sdp_on_bf16 \ - --bf16 + --bf16 \ --quant_mode measure ``` From 397803e608850752bf22a3c7be6bf5e28043653d Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Thu, 30 Jan 2025 13:24:46 +0100 Subject: [PATCH 6/6] Update examples/stable-diffusion/README.md Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- examples/stable-diffusion/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index a5f8d611be..98e818a5c6 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -358,7 +358,7 @@ python text_to_image_generation.py \ --use_hpu_graphs \ --gaudi_config Habana/stable-diffusion \ --sdp_on_bf16 \ - --bf16 + --bf16 \ --quant_mode quantize ```
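
For reference, a minimal Python sketch of the workflow this series enables: batched SD3 inference on Gaudi with HPU graphs and the optional FP8 measure/quantize passes. It mirrors the flags of `text_to_image_generation.py` and the `quant_mode`/`QUANT_CONFIG` handling added in patch 1/6; the constructor keywords follow the usual optimum-habana Gaudi pipeline conventions (`use_habana`, `use_hpu_graphs`, `gaudi_config`), and the config paths and output filenames are illustrative assumptions rather than part of the patches.

```python
import os

import torch
from optimum.habana.diffusers import GaudiStableDiffusion3Pipeline

# Point the FP8 hooks at one of the JSON configs added by patch 1/6.
# Run once with measure_config.json (quant_mode="measure") to collect stats,
# then again with quantize_config.json (quant_mode="quantize").
os.environ["QUANT_CONFIG"] = "quantization/stable-diffusion-3/quantize_config.json"

pipeline = GaudiStableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.bfloat16,
    use_habana=True,       # run on HPU
    use_hpu_graphs=True,   # exercises the capture_replay() path added in patch 1/6
    gaudi_config="Habana/stable-diffusion",
)

outputs = pipeline(
    prompt="Sailing ship painting by Van Gogh",
    num_images_per_prompt=10,
    batch_size=1,          # inputs are split/padded by _split_inputs_into_batches()
    num_inference_steps=28,
    quant_mode="quantize",  # or "measure" for the stats-collection pass
)

for i, image in enumerate(outputs.images):
    image.save(f"sd3_image_{i}.png")
print(f"Throughput: {outputs.throughput} samples/s")
```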