Skip to content

Commit

Permalink
SD3 small tweaks for numerics
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jun 17, 2024
1 parent 6c9d96d commit d7c709e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
59 changes: 50 additions & 9 deletions models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,18 @@ def generate_images(
numpy_images = []

for i in range(batch_count):
generator = torch.random.manual_seed(seed + i)
generator = torch.Generator().manual_seed(int(seed))
shape = (
self.batch_size,
16,
self.height // 8,
self.width // 8,
)
rand_sample = torch.randn(
(
self.batch_size,
16,
self.height // 8,
self.width // 8,
),
shape,
generator=generator,
dtype=torch_dtype,
dtype=torch.float32,
layout=torch.strided,
)
samples.append(
ireert.asdevicearray(
Expand Down Expand Up @@ -499,7 +501,6 @@ def generate_images(
prompt_embeds, pooled_prompt_embeds = self.runners[
"text_encoders"
].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)

encode_prompts_end = time.time()

for i in range(batch_count):
Expand Down Expand Up @@ -617,11 +618,51 @@ def generate_images(
image.save(img_path)
print(img_path, "saved")
return

def run_diffusers_cpu(
hf_model_name,
prompt,
negative_prompt,
guidance_scale,
seed,
height,
width,
num_inference_steps,
):
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(hf_model_name, torch_dtype=torch.float32)
pipe = pipe.to("cpu")
generator = torch.Generator().manual_seed(int(seed))

image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
height=height,
width=width,
generator=generator,
).images[0]
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
image.save(f"diffusers_reference_output_{timestamp}.png")


if __name__ == "__main__":
from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args

if args.compare_vs_torch:
run_diffusers_cpu(
args.hf_model_name,
args.prompt,
args.negative_prompt,
args.guidance_scale,
args.seed,
args.height,
args.width,
args.num_inference_steps,
)
exit()
map = empty_pipe_dict
mlirs = copy.deepcopy(map)
vmfbs = copy.deepcopy(map)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def __init__(
def initialize(self, sample):
step_count = torch.tensor(len(self.timesteps))
timesteps = self.model.timesteps
# ops.trace_tensor("timesteps", self.timesteps)
ops.trace_tensor("sample", sample[:,:,0,0])
return (
sample.type(self.dtype),
sample,
step_count,
timesteps.type(torch.float32),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def forward(self, tokens_g, tokens_l, tokens_t5xxl, neg_g, neg_l, neg_t5):
neg_cond, neg_cond_pool = self.get_cond(neg_l, neg_g, neg_t5)

prompt_embeds = torch.cat([neg_cond, conditioning], dim=0)
pooled_prompt_embeds = torch.cat([cond_pool, neg_cond_pool], dim=0)
pooled_prompt_embeds = torch.cat([neg_cond_pool, cond_pool], dim=0)

return prompt_embeds, pooled_prompt_embeds

Expand Down

0 comments on commit d7c709e

Please sign in to comment.