Adds sdxl's VAE decoder implementation #653
Merged

Commits (7):
- a2fe5c8 Adds sdxl's VAE decoder implementation (IanNod)
- 7e6c462 Assorted fixes and addressing comments (IanNod)
- 7b310e5 Removes vae test pytest dependencies as it didn't order tests properly (IanNod)
- 8199ce8 Black formating (IanNod)
- 06df348 Update to address import_hf_dataset changes (IanNod)
- 17757dd Changes testing to use MI300 runner (IanNod)
- fb95d88 Merge branch 'nod-ai:main' into vae (IanNod)
@@ -2,3 +2,4 @@ datasets==3.0.0
parameterized
pytest==8.0.0
pytest-html
diffusers
@@ -0,0 +1,28 @@
# VAE decoder

This is a VAE decoder implemented in the style used for SDXL, referenced from the diffusers implementation.

## Preparing dataset
If not sharding or quantizing, the official model can be imported directly from Hugging Face:

```
model_dir=$(huggingface-cli download \
  stabilityai/stable-diffusion-xl-base-1.0 \
  vae/config.json vae/diffusion_pytorch_model.safetensors)

python -m sharktank.models.punet.tools.import_hf_dataset \
  --params $model_dir/vae/diffusion_pytorch_model.safetensors \
  --config-json $model_dir/vae/config.json --output-irpa-file ~/models/vae.irpa
```
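The same two files can also be fetched programmatically; a minimal sketch, assuming `huggingface_hub` is available (it is pulled in by the `diffusers` dependency added above):

```
# Minimal sketch (assumes huggingface_hub, a dependency of diffusers):
# downloads the same VAE config and weights used by import_hf_dataset above.
from huggingface_hub import hf_hub_download

repo = "stabilityai/stable-diffusion-xl-base-1.0"
config_path = hf_hub_download(repo, "vae/config.json")
params_path = hf_hub_download(repo, "vae/diffusion_pytorch_model.safetensors")
print(config_path, params_path)  # local cache paths to pass to import_hf_dataset
```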

## Run the VAE decoder model in eager mode
```
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu
```
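For a rough cross-check of eager outputs, the same latents can be decoded with the upstream diffusers `AutoencoderKL` that this implementation was derived from; a sketch assuming the stock diffusers API:

```
# Rough reference decode with upstream diffusers (assumes the stock AutoencoderKL API).
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
)
vae.eval()

# SDXL latents: 4 channels at 1/8 spatial resolution (128x128 latents -> 1024x1024 image).
latents = torch.randn(1, 4, 128, 128)

with torch.no_grad():
    # Undo the scaling_factor (0.13025 for the SDXL VAE) before decoding,
    # as the diffusers pipelines do.
    image = vae.decode(latents / vae.config.scaling_factor).sample

print(image.shape)  # torch.Size([1, 3, 1024, 1024])
```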

## License

Significant portions of this implementation were derived from diffusers,
licensed under Apache2: https://github.com/huggingface/diffusers
While much was a simple reverse engineering of the config.json and parameters,
code was taken where appropriate.
@@ -0,0 +1,61 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# Significant portions of this implementation were derived from diffusers,
# licensed under Apache2: https://github.com/huggingface/diffusers
# While much was a simple reverse engineering of the config.json and parameters,
# code was taken where appropriate.

from typing import List, Optional, Sequence, Tuple, Union

from dataclasses import dataclass
import inspect
import warnings

__all__ = [
    "HParams",
]


@dataclass
class HParams:
    # Per block sequences. These are normalized from either an int (duplicated
    # to the number of down_blocks) or a list.
    layers_per_block: Tuple[int]

    act_fn: str = "silu"
    block_out_channels: Sequence[int] = (128, 256, 512, 512)
    in_channels: int = 3
    up_block_types: Sequence[str] = (
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    )
    layers_per_block: int = 2
    norm_num_groups: int = 32
    scaling_factor: float = 0.13025

    def assert_default_values(self, attr_names: Sequence[str]):
        for name in attr_names:
            actual = getattr(self, name)
            required = getattr(HParams, name)
            if actual != required:
                raise ValueError(
                    f"NYI: HParams.{name} != {required!r} (got {actual!r})"
                )

    @classmethod
    def from_dict(cls, d: dict):
        if "layers_per_block" not in d:
            d["layers_per_block"] = 2

        allowed = inspect.signature(cls).parameters
        declared_kwargs = {k: v for k, v in d.items() if k in allowed}
        extra_kwargs = [k for k in d.keys() if k not in allowed]
        if extra_kwargs:
            warnings.warn(f"Unhandled vae.HParams: {extra_kwargs}")
        return cls(**declared_kwargs)
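As a rough illustration of how `HParams.from_dict` handles the downloaded `vae/config.json` (the module path and config location below are assumptions for this example):

```
# Illustrative sketch: building HParams from the SDXL vae/config.json.
# The module path and the config location are assumptions for this example.
import json
import warnings

from sharktank.models.vae.config import HParams  # assumed module path

config_path = "/path/to/model_dir/vae/config.json"  # from the README's download step

with open(config_path) as f:
    config = json.load(f)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    hp = HParams.from_dict(config)  # unknown config keys are warned about, not rejected

print(hp.block_out_channels)  # [128, 256, 512, 512] for the SDXL VAE config
print(hp.scaling_factor)      # 0.13025
for w in caught:
    print(w.message)          # e.g. "Unhandled vae.HParams: [...]"
```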
@@ -0,0 +1,261 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional, Sequence, Tuple

import math

import torch
import torch.nn as nn

from sharktank import ops
from sharktank.layers import *
from sharktank.types import *
from sharktank.models.punet.layers import (
    ResnetBlock2D,
    Upsample2D,
    GroupNormLayer,
    AttentionLayer,
)
from .config import *


__all__ = ["UNetMidBlock2D", "UpDecoderBlock2D", "AttentionLayer"]

# TODO Remove and integrate with punet AttentionLayer
class AttentionLayer(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
        heads: int,  # in_channels // attention_head_dim
        dim_head,
        rescale_output_factor: float,
        eps: float,
        norm_num_groups: int,
        residual_connection: bool,
    ):
        super().__init__(theta)
        self.heads = heads
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection

        if norm_num_groups is not None:
            self.group_norm = GroupNormLayer(
                theta("group_norm"), num_groups=norm_num_groups, eps=eps
            )
        else:
            self.group_norm = None

        self.norm_q = None
        self.norm_k = None

        self.norm_cross = None
        self.to_q = LinearLayer(theta("to_q"))
        self.to_k = LinearLayer(theta("to_k"))
        self.to_v = LinearLayer(theta("to_v"))

        self.added_proj_bias = True
        self.to_out = LinearLayer(theta("to_out")(0))

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = self.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        hidden_states = ops.scaled_dot_product_attention(
            query, key, value, a=attention_mask
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )

        # linear proj
        hidden_states = self.to_out(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor
        return hidden_states

class UpDecoderBlock2D(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
        *,
        num_layers: int,
        resnet_eps: float,
        resnet_act_fn: str,
        resnet_groups: int,
        resnet_out_scale_factor: Optional[float],
        resnet_time_scale_shift: str,
        temb_channels: int,
        dropout: float,
        add_upsample: bool,
    ):
        super().__init__(theta)
        resnets = []

        for i in range(num_layers):
            resnets.append(
                ResnetBlock2D(
                    theta("resnets")(i),
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=resnet_out_scale_factor,
                    time_embedding_norm=resnet_time_scale_shift,
                    temb_channels=temb_channels,
                    dropout=dropout,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample2D(theta("upsamplers")("0"), padding=1)]
            )
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
    ) -> torch.Tensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb)
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states

class UNetMidBlock2D(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
        temb_channels: int,
        dropout: float,
        num_layers: int,
        resnet_eps: float,
        resnet_time_scale_shift: str,
        resnet_act_fn: str,
        resnet_groups: int,
        resnet_pre_norm: bool,
        add_attention: bool,
        attention_head_dim: int,
        output_scale_factor: float,
        attn_groups: Optional[int] = None,
    ):
        super().__init__(theta)

        resnet_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        if resnet_time_scale_shift == "spatial":
            # TODO Implement ResnetBlockCondNorm2d block for spatial time scale shift
            raise AssertionError(f"ResnetBlockCondNorm2d not yet implemented")
        else:
            resnets = [
                ResnetBlock2D(
                    theta("resnets")(0),
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    time_embedding_norm=resnet_time_scale_shift,
                    temb_channels=temb_channels,
                    dropout=dropout,
                )
            ]
        # TODO: loop through num_layers properly. Works for sdxl vae specifically but removed for export reasons
        if add_attention:
            self.attention = AttentionLayer(
                theta("attentions")(0),
                heads=1,
                dim_head=attention_head_dim,
                rescale_output_factor=1.0,
                eps=resnet_eps,
                norm_num_groups=attn_groups,
                residual_connection=True,
            )
        else:
            self.attention = None

        if resnet_time_scale_shift == "spatial":
            # TODO Implement ResnetBlock2D for spatial time scale shift support
            raise AssertionError(
                f"ResnetBlock2D spatial time scale shift not yet implemented"
            )
        else:
            resnets.append(
                ResnetBlock2D(
                    theta("resnets")(1),
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    time_embedding_norm=resnet_time_scale_shift,
                    temb_channels=temb_channels,
                    dropout=dropout,
                )
            )
        self.resnets = nn.ModuleList(resnets)

    def forward(
        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        if self.attention is not None:
            hidden_states = self.attention(hidden_states)
        hidden_states = self.resnets[1](hidden_states, temb)
        return hidden_states
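For reference, the tensor bookkeeping that `AttentionLayer.forward` performs on a 4D feature map can be reproduced with plain torch; a sketch using `torch.nn.functional.scaled_dot_product_attention` in place of sharktank's op, with shapes and the single-head setup assumed from the mid block:

```
# Plain-torch sketch of AttentionLayer.forward's reshaping (single head, as in UNetMidBlock2D).
# torch.nn.functional.scaled_dot_product_attention stands in for ops.scaled_dot_product_attention.
import torch
import torch.nn.functional as F

batch, channels, height, width = 1, 512, 64, 64
heads = 1                       # UNetMidBlock2D constructs the attention with heads=1
head_dim = channels // heads

x = torch.randn(batch, channels, height, width)
residual = x

# (B, C, H, W) -> (B, H*W, C): attention runs over the flattened spatial positions.
seq = x.view(batch, channels, height * width).transpose(1, 2)

to_q = torch.nn.Linear(channels, channels)
to_k = torch.nn.Linear(channels, channels)
to_v = torch.nn.Linear(channels, channels)
to_out = torch.nn.Linear(channels, channels)

def split_heads(t: torch.Tensor) -> torch.Tensor:
    # (B, S, C) -> (B, heads, S, head_dim)
    return t.view(batch, -1, heads, head_dim).transpose(1, 2)

attn = F.scaled_dot_product_attention(
    split_heads(to_q(seq)), split_heads(to_k(seq)), split_heads(to_v(seq))
)

# (B, heads, S, head_dim) -> (B, S, C), project, then restore (B, C, H, W) with the residual.
seq = to_out(attn.transpose(1, 2).reshape(batch, -1, heads * head_dim))
out = seq.transpose(-1, -2).reshape(batch, channels, height, width) + residual
print(out.shape)  # torch.Size([1, 512, 64, 64])
```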
Review discussion:

"Just realized. This should probably be broken up and put in sharktank/sharktank/layers/."

"Yeah, it should. I kept it separate following the punet layers and was planning to clean up both in a follow-up PR."