Allegro VAE fix (#9811)
fix
a-r-r-o-w authored Oct 30, 2024
1 parent 0d1d267 commit 9a92b81
Showing 1 changed file with 2 additions and 8 deletions.
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py (10 changes: 2 additions & 8 deletions)
@@ -1091,8 +1091,6 @@ def forward(
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-        encoder_local_batch_size: int = 2,
-        decoder_local_batch_size: int = 2,
     ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
@@ -1103,18 +1101,14 @@
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
             generator (`torch.Generator`, *optional*):
                 PyTorch random number generator.
-            encoder_local_batch_size (`int`, *optional*, defaults to 2):
-                Local batch size for the encoder's batch inference.
-            decoder_local_batch_size (`int`, *optional*, defaults to 2):
-                Local batch size for the decoder's batch inference.
         """
         x = sample
-        posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
+        posterior = self.encode(x).latent_dist
         if sample_posterior:
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
+        dec = self.decode(z).sample

         if not return_dict:
             return (dec,)
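The change drops the `encoder_local_batch_size` / `decoder_local_batch_size` plumbing from `forward()`, presumably because `encode()` and `decode()` in this module do not accept a `local_batch_size` argument, so forwarding it would fail. For context, a minimal usage sketch of the simplified `forward()` is shown below; the checkpoint id, `vae` subfolder, and tensor shape are assumptions for illustration, not part of the commit.

```python
import torch
from diffusers import AutoencoderKLAllegro

# Assumed checkpoint layout: the Allegro release exposes its VAE under a "vae" subfolder.
vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae")
vae.eval()

# Dummy video batch (batch, channels, frames, height, width); shape chosen only for illustration.
video = torch.randn(1, 3, 8, 64, 64)

with torch.no_grad():
    # After this commit, forward() no longer takes encoder_local_batch_size /
    # decoder_local_batch_size; it encodes, optionally samples the posterior, and decodes.
    out = vae(video, sample_posterior=True, generator=torch.Generator().manual_seed(0))

print(out.sample.shape)  # DecoderOutput.sample holds the reconstructed video
```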
