Why Does tile_decode Only Tile Along the Temporal Dimension? #8

Open
WilliamLLee opened this issue Feb 7, 2025 · 2 comments

@WilliamLLee

In the tile_decode implementation, tiling is performed only along the temporal dimension (t); the spatial dimensions (h, w) are not tiled. This differs from implementations such as CogVideoX in diffusers, which tile along both the spatial and temporal dimensions.

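```python
import torch

def tile_decode(self, x):
    # x: latent video of shape (batch, channels, time, height, width)
    b, c, t, h, w = x.shape

    # (start, end) latent-frame ranges, one per temporal chunk
    start_end = self.build_chunk_start_end(t, decoder_mode=True)

    result = []
    for idx, (start, end) in enumerate(start_end):
        # Flag whether this is the first chunk (presumably so the causal
        # layers pad the start instead of reusing cached frames)
        self._set_first_chunk(idx == 0)

        # Every chunk except the last takes one extra latent frame of overlap.
        # h and w are never split: each chunk is decoded at full resolution.
        if end + 1 < t:
            chunk = x[:, :, start:end + 1, :, :]
        else:
            chunk = x[:, :, start:end, :, :]

        if self.use_quant_layer:
            chunk = self.post_quant_conv(chunk)
        chunk = self.decoder(chunk)[0]

        # Drop the last temporal_uptimes decoded frames, presumably the
        # frames produced from the extra overlap latent frame
        if end + 1 < t:
            chunk = chunk[:, :, :-self.temporal_uptimes]
        result.append(chunk.clone())

    # Stitch the decoded chunks back together along the time axis
    return torch.cat(result, dim=2)
```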

Why does this implementation tile only along the temporal dimension and leave the spatial dimensions unchanged? By comparison, CogVideoX in diffusers tiles along both the spatial and temporal dimensions.

Would it be possible to extend this implementation to support spatial tiling as well?

@qqingzheng
Collaborator

qqingzheng commented Feb 7, 2025

Q1: Why does our implementation only perform tiling along the temporal dimension?
By taking advantage of causal convolution, we can tile along the time axis without any loss: each output frame of a causal convolution depends only on the current and earlier input frames, so tiled inference produces exactly the same result as direct (untiled) inference. In addition, temporal tiling alone reduced memory usage enough to meet our training needs, so we did not need to tile further.
Note: You can see the details in our report.
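A short derivation of why temporal tiling can be lossless, assuming a single temporal convolution with kernel size $k$, stride 1 and no dilation (the real decoder also upsamples in time, so treat this as the core idea rather than the full story). With causal padding standing in for $x_t$ when $t < 0$:

$$
y_t = \sum_{i=0}^{k-1} W_i \, x_{t-i}
$$

The outputs of a chunk $s \le t < e$ therefore read only the inputs $x_{s-k+1}, \dots, x_{e-1}$. If each layer keeps the last $k-1$ input frames of the previous chunk and prepends them to the next chunk, every $y_t$ is computed from exactly the same inputs as in untiled inference, so concatenating the chunk outputs reproduces the full output. Stacking causal layers preserves this property, with one such cache per layer.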

Q2: Can it be extended to spatial tiling?
Yes, it can be extended to spatial tiling in the same way as CogVideoX. However, spatial tiling may disrupt the latent space, which poses unknown risks for diffusion training and can degrade video reconstruction quality. Since temporal tiling already meets our memory requirements for video generation pretraining, we chose not to implement spatial tiling.
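To see why the spatial case differs, here is a minimal toy sketch (illustrative only, not WF-VAE or diffusers code). Spatial convolutions are usually not causal, so a pixel near a tile border sees zero padding instead of its real neighbors when the tile is processed on its own. Splitting a 16×16 input into two 8-pixel-wide tiles for a single 3×3 convolution changes the two columns at the seam and nothing else:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# A plain (non-causal) spatial conv: each output pixel sees a 3x3 neighborhood
conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

x = torch.randn(1, 3, 16, 16)  # (N, C, H, W)
with torch.no_grad():
    full = conv(x)
    # Naive spatial tiling: split the width in half and run each tile alone
    left = conv(x[..., :8])
    right = conv(x[..., 8:])
tiled = torch.cat([left, right], dim=-1)

# Maximum absolute error per output column
err = (full - tiled).abs().amax(dim=(0, 1, 2))
print(err)
```

A deep decoder stacks many such layers, so the affected band around each seam widens with the receptive field. Tiled decoding for CogVideoX in diffusers handles this by overlapping neighboring tiles and linearly blending the overlap, which hides the seams but only approximates the untiled result.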

Thank you for your interest in our project.

Related code for lossless temporal tiling, which takes advantage of the causal property of causal convolution:
https://github.com/PKU-YuanGroup/WF-VAE/blob/main/causalvideovae/model/modules/conv.py#L89
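For illustration, a minimal PyTorch sketch of the caching idea, assuming first-frame replication as the causal padding and a temporal stride of 1. `CachedCausalConv3d`, `is_first_chunk` and `cache` are hypothetical names chosen for this example (the flag loosely mirrors the `_set_first_chunk` call in `tile_decode`); the actual module in the linked `conv.py` may differ.

```python
import torch
import torch.nn as nn


class CachedCausalConv3d(nn.Module):
    """Hypothetical causal 3D conv that carries the last (kt - 1) input
    frames across temporal chunks. Illustrative sketch, not WF-VAE code."""

    def __init__(self, in_ch, out_ch, kernel_size=(3, 3, 3)):
        super().__init__()
        self.kt, kh, kw = kernel_size
        # Spatial dims get ordinary symmetric padding; time is padded by hand
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size,
                              padding=(0, kh // 2, kw // 2))
        self.is_first_chunk = True
        self.cache = None  # trailing (kt - 1) padded input frames

    def forward(self, x):  # x: (B, C, T, H, W)
        if self.is_first_chunk:
            # Assumed causal padding: repeat the first frame kt - 1 times
            pad = x[:, :, :1].repeat(1, 1, self.kt - 1, 1, 1)
        else:
            # Reuse the real preceding frames saved from the previous chunk
            pad = self.cache
        x = torch.cat([pad, x], dim=2)
        # Save the last kt - 1 frames as context for the next chunk
        self.cache = x[:, :, x.shape[2] - (self.kt - 1):].detach()
        return self.conv(x)


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = CachedCausalConv3d(2, 4)
    x = torch.randn(1, 2, 9, 8, 8)
    with torch.no_grad():
        layer.is_first_chunk = True
        full = layer(x)  # untiled reference
        chunks = []
        for i, (s, e) in enumerate([(0, 4), (4, 7), (7, 9)]):
            layer.is_first_chunk = i == 0
            chunks.append(layer(x[:, :, s:e]))
    tiled = torch.cat(chunks, dim=2)
    print((full - tiled).abs().max())
```

Running the layer once on the whole clip and then chunk by chunk gives the same output up to floating-point rounding, because every output frame sees the same kt input frames in both cases.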

@WilliamLLee
Author

Thank you for your explanation, much appreciated!
