Skip to content

Commit

Permalink
Add gradient blending to tile seams in MultiDiffusion.
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick authored and hipsterusername committed Jul 19, 2024
1 parent 97a7f51 commit e16faa6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
_, _, latent_height, latent_width = latents.shape

# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
# - How much overlap 'context' to provide for each denoising step.
# - How much overlap to use during merging/blending.
# - Should we 'jitter' the tile locations in each step so that the seams are in different places?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
Expand Down
54 changes: 39 additions & 15 deletions invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def multi_diffusion_denoise(
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
assert isinstance(latents, torch.Tensor) # For static type checking.

# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
# cropping into regions.
Expand Down Expand Up @@ -122,29 +123,52 @@ def multi_diffusion_denoise(
control_data=region_conditioning.control_data,
)

# Store the results from the region.
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
# affected tiles to achieve the target overlap.
# Build a region_weight matrix that applies gradient blending to the edges of the region.
region = region_conditioning.region
top_adjustment = max(0, region.overlap.top - target_overlap)
left_adjustment = max(0, region.overlap.left - target_overlap)
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
:, :, top_adjustment:, left_adjustment:
]
# For now, we treat every region as having the same weight.
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
_, _, region_height, region_width = step_output.prev_sample.shape
region_weight = torch.ones(
(1, 1, region_height, region_width),
dtype=latents.dtype,
device=latents.device,
)
if region.overlap.left > 0:
left_grad = torch.linspace(
0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
).view((1, 1, 1, -1))
region_weight[:, :, :, : region.overlap.left] *= left_grad
if region.overlap.top > 0:
top_grad = torch.linspace(
0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
).view((1, 1, -1, 1))
region_weight[:, :, : region.overlap.top, :] *= top_grad
if region.overlap.right > 0:
right_grad = torch.linspace(
1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
).view((1, 1, 1, -1))
region_weight[:, :, :, -region.overlap.right :] *= right_grad
if region.overlap.bottom > 0:
bottom_grad = torch.linspace(
1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
).view((1, 1, -1, 1))
region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad

# Update the merged results with the region results.
merged_latents[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += step_output.prev_sample * region_weight
merged_latents_weights[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += region_weight

pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
# If one region has pred_original_sample, then we can assume that all regions will have it, because
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
:, :, top_adjustment:, left_adjustment:
]
merged_pred_original[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += pred_orig_sample

# Normalize the merged results.
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
Expand Down

0 comments on commit e16faa6

Please sign in to comment.