diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py
index 42c703b78b..6681dc113d 100644
--- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py
+++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py
@@ -757,7 +757,7 @@ def __call__(
 
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )
 
         self._num_timesteps = len(timesteps)
@@ -769,9 +769,6 @@ def __call__(
             if j == throughput_warmup_steps:
                 ht.hpu.synchronize()
                 t1 = time.time()
-            if use_warmup_inference_steps:
-                ht.hpu.synchronize()
-                t0_inf = time.time()
 
             latents = latents_batches[0]
             latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
@@ -841,10 +838,9 @@ def __call__(
                     hb_profiler.step()
             else:
                 for i in range(num_inference_steps):
-                    if use_warmup_inference_steps and i == throughput_warmup_steps:
+                    if use_warmup_inference_steps and i == throughput_warmup_steps and j == num_batches - 1:
                         ht.hpu.synchronize()
-                        t1_inf = time.time()
-                        t1 += t1_inf - t0_inf
+                        t1 = time.time()
 
                     t = timesteps[0]
                     timesteps = torch.roll(timesteps, shifts=-1, dims=0)
@@ -875,10 +871,6 @@ def __call__(
                         callback_on_step_end_tensor_inputs,
                     )
                     hb_profiler.step()
-            if use_warmup_inference_steps:
-                ht.hpu.synchronize()
-                t1 = warmup_inference_steps_time_adjustment(t1, t1_inf, num_inference_steps, throughput_warmup_steps)
-
             if not output_type == "latent":
                 # make sure the VAE is in float32 mode, as it overflows in float16
                 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -897,18 +889,20 @@ def __call__(
 
             output_images.append(image)
 
+        hb_profiler.stop()
         speed_metrics_prefix = "generation"
         ht.hpu.synchronize()
-
+        if use_warmup_inference_steps:
+            ht.hpu.synchronize()
+            t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps)
         if t1 == t0 or use_warmup_inference_steps:
-            num_samples = num_batches * batch_size
-            num_steps = (num_inference_steps - throughput_warmup_steps) * num_batches * batch_size
+            num_samples = batch_size
+            num_steps = batch_size * num_inference_steps
         else:
             num_samples = (num_batches - throughput_warmup_steps) * batch_size
             num_steps = (num_batches - throughput_warmup_steps) * num_inference_steps * batch_size
-
         speed_measures = speed_metrics(
             split=speed_metrics_prefix,
             start_time=t0,