Adds SDXL's VAE decoder implementation #653

Merged 7 commits on Dec 19, 2024
2 changes: 2 additions & 0 deletions .github/workflows/ci-sharktank.yml
@@ -146,9 +146,11 @@ jobs:
--with-clip-data \
--with-flux-data \
--with-t5-data \
--with-vae-data \
sharktank/tests/models/clip/clip_test.py \
sharktank/tests/models/t5/t5_test.py \
sharktank/tests/models/flux/flux_test.py \
sharktank/tests/models/vae/vae_test.py \
--durations=0


9 changes: 9 additions & 0 deletions sharktank/conftest.py
@@ -116,6 +116,15 @@ def pytest_addoption(parser):
),
)

parser.addoption(
"--with-vae-data",
action="store_true",
default=False,
help=(
"Enable tests that use vae data such as models not part of the source code."
),
)

# TODO: Remove all hardcoded paths in CI tests
parser.addoption(
"--llama3-8b-tokenizer-path",
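For context, a test gated on this flag would typically read it back with `request.config.getoption`; the sketch below is illustrative only, and the fixture and test names are not part of this PR:

```
# Illustrative sketch (not from this PR): consuming --with-vae-data in a test.
import pytest


@pytest.fixture
def with_vae_data(request) -> bool:
    # Reads the flag registered by pytest_addoption above.
    return request.config.getoption("--with-vae-data")


def test_vae_decoder_smoke(with_vae_data):
    if not with_vae_data:
        pytest.skip("Pass --with-vae-data to enable tests that need external VAE data.")
    # ... test body would load the VAE IRPA and run the decoder ...
```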
1 change: 1 addition & 0 deletions sharktank/requirements-tests.txt
@@ -2,3 +2,4 @@ datasets==3.0.0
parameterized
pytest==8.0.0
pytest-html
diffusers
1 change: 0 additions & 1 deletion sharktank/sharktank/models/punet/layers.py
@@ -571,7 +571,6 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
hidden_states = ops.elementwise(self.nonlinearity, hidden_states)
hidden_states = self.conv1(hidden_states)

assert self.time_emb_proj is not None
if self.time_emb_proj is not None:
temb = ops.elementwise(self.nonlinearity, temb)
temb = self.time_emb_proj(temb)[:, :, None, None]
28 changes: 28 additions & 0 deletions sharktank/sharktank/models/vae/README.md
@@ -0,0 +1,28 @@
# VAE decoder

This is a VAE decoder implemented in the style used for SDXL and referenced from the diffusers implementation.

## Preparing dataset
If not sharding or quantizing, the official model can be imported directly from Hugging Face:

```
model_dir=$(huggingface-cli download \
stabilityai/stable-diffusion-xl-base-1.0 \
vae/config.json vae/diffusion_pytorch_model.safetensors)

python -m sharktank.models.punet.tools.import_hf_dataset \
--params $model_dir/vae/diffusion_pytorch_model.safetensors \
--config-json $model_dir/vae/config.json --output-irpa-file ~/models/vae.irpa
```

## Run the VAE decoder model in eager mode
```
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu
```
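
For a rough eager-mode reference, the upstream diffusers decoder (which `requirements-tests.txt` now pulls in) can be run on dummy latents as in the sketch below; the latent shape and the scaling step are illustrative assumptions, not part of this PR:

```
# Illustrative reference decode with the upstream diffusers VAE (not part of this PR).
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
)
# Assumed SDXL latent layout: 4 channels at 1/8 of the 1024x1024 image resolution.
latents = torch.randn(1, 4, 128, 128)
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor).sample  # (1, 3, 1024, 1024)
```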

## License

Significant portions of this implementation were derived from diffusers,
licensed under Apache 2.0: https://github.com/huggingface/diffusers.
While much was a simple reverse engineering of the config.json and parameters,
code was taken where appropriate.
61 changes: 61 additions & 0 deletions sharktank/sharktank/models/vae/config.py
@@ -0,0 +1,61 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# Significant portions of this implementation were derived from diffusers,
# licensed under Apache2: https://github.com/huggingface/diffusers
# While much was a simple reverse engineering of the config.json and parameters,
# code was taken where appropriate.

from typing import List, Optional, Sequence, Tuple, Union

from dataclasses import dataclass
import inspect
import warnings

__all__ = [
"HParams",
]


@dataclass
class HParams:
# Per block sequences. These are normalized from either an int (duplicated
# to the number of down_blocks) or a list.
layers_per_block: Tuple[int]

act_fn: str = "silu"
block_out_channels: Sequence[int] = (128, 256, 512, 512)
in_channels: int = 3
up_block_types: Sequence[str] = (
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
)
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.13025

def assert_default_values(self, attr_names: Sequence[str]):
for name in attr_names:
actual = getattr(self, name)
required = getattr(HParams, name)
if actual != required:
raise ValueError(
f"NYI: HParams.{name} != {required!r} (got {actual!r})"
)

@classmethod
def from_dict(cls, d: dict):
if "layers_per_block" not in d:
d["layers_per_block"] = 2

allowed = inspect.signature(cls).parameters
declared_kwargs = {k: v for k, v in d.items() if k in allowed}
extra_kwargs = [k for k in d.keys() if k not in allowed]
if extra_kwargs:
warnings.warn(f"Unhandled vae.HParams: {extra_kwargs}")
return cls(**declared_kwargs)
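
As an illustration of `from_dict`, the keys of the SDXL VAE `config.json` can be passed straight through; keys not declared on `HParams` (e.g. `down_block_types`) are dropped with a warning. The dict below is a trimmed, assumed subset of that config, not part of this PR:

```
# Illustrative sketch (not from this PR): building HParams from a config.json-style dict.
from sharktank.models.vae.config import HParams

config = {
    "act_fn": "silu",
    "block_out_channels": [128, 256, 512, 512],
    "in_channels": 3,
    "layers_per_block": 2,
    "norm_num_groups": 32,
    "scaling_factor": 0.13025,
    "up_block_types": ["UpDecoderBlock2D"] * 4,
    "down_block_types": ["DownEncoderBlock2D"] * 4,  # not declared on HParams -> warning
}
hparams = HParams.from_dict(config)
assert hparams.scaling_factor == 0.13025
```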
261 changes: 261 additions & 0 deletions sharktank/sharktank/models/vae/layers.py
@@ -0,0 +1,261 @@
# Copyright 2024 Advanced Micro Devices, Inc.
Contributor: Just realized. This should probably be broken up and put in sharktank/sharktank/layers/

Contributor Author: Yeah, it should. I kept it separate following the punet layers and was planning to clean up both in a follow-up PR.

#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional, Sequence, Tuple

import math

import torch
import torch.nn as nn

from sharktank import ops
from sharktank.layers import *
from sharktank.types import *
from sharktank.models.punet.layers import (
ResnetBlock2D,
Upsample2D,
GroupNormLayer,
AttentionLayer,
)
from .config import *


__all__ = ["UNetMidBlock2D", "UpDecoderBlock2D", "AttentionLayer"]

# TODO Remove and integrate with punet AttentionLayer
class AttentionLayer(ThetaLayer):
def __init__(
self,
theta: Theta,
heads: int, # in_channels // attention_head_dim
dim_head,
rescale_output_factor: float,
eps: float,
norm_num_groups: int,
residual_connection: bool,
):
super().__init__(theta)
self.heads = heads
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection

if norm_num_groups is not None:
self.group_norm = GroupNormLayer(
theta("group_norm"), num_groups=norm_num_groups, eps=eps
)
else:
self.group_norm = None

self.norm_q = None
self.norm_k = None

self.norm_cross = None
self.to_q = LinearLayer(theta("to_q"))
self.to_k = LinearLayer(theta("to_k"))
self.to_v = LinearLayer(theta("to_v"))

self.added_proj_bias = True
self.to_out = LinearLayer(theta("to_out")(0))

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
residual = hidden_states

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)

if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)

query = self.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads

query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

if self.norm_q is not None:
query = self.norm_q(query)
if self.norm_k is not None:
key = self.norm_k(key)

hidden_states = ops.scaled_dot_product_attention(
query, key, value, a=attention_mask
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, self.heads * head_dim
)

# linear proj
hidden_states = self.to_out(hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)

if self.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
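
For intuition on the 4D/3D reshaping in `forward` above, here is a plain-torch sketch of the same data movement on dummy tensors; the shapes are assumptions, and the real layer reads its `to_q`/`to_k`/`to_v`/`to_out` projections from a `Theta` and calls `ops.scaled_dot_product_attention`:

```
# Illustrative sketch (not from this PR): the 4D <-> 3D attention data movement.
import torch
import torch.nn.functional as F

# Small assumed shapes; the SDXL VAE mid-block actually runs 512 channels at 128x128.
heads, b, c, h, w = 1, 1, 512, 32, 32
x = torch.randn(b, c, h, w)

# (B, C, H, W) -> (B, H*W, C): every spatial position becomes one attention token.
tokens = x.view(b, c, h * w).transpose(1, 2)

# Identity stand-ins for the learned to_q/to_k/to_v projections.
head_dim = c // heads
q = tokens.reshape(b, -1, heads, head_dim).transpose(1, 2)
k = tokens.reshape(b, -1, heads, head_dim).transpose(1, 2)
v = tokens.reshape(b, -1, heads, head_dim).transpose(1, 2)

attn = F.scaled_dot_product_attention(q, k, v)

# (B, heads, H*W, head_dim) -> (B, H*W, C) -> back to the (B, C, H, W) image layout.
out = attn.transpose(1, 2).reshape(b, -1, heads * head_dim)
out = out.transpose(-1, -2).reshape(b, c, h, w)
assert out.shape == x.shape
```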


class UpDecoderBlock2D(ThetaLayer):
def __init__(
self,
theta: Theta,
*,
num_layers: int,
resnet_eps: float,
resnet_act_fn: str,
resnet_groups: int,
resnet_out_scale_factor: Optional[float],
resnet_time_scale_shift: str,
temb_channels: int,
dropout: float,
add_upsample: bool,
):
super().__init__(theta)
resnets = []

for i in range(num_layers):
resnets.append(
ResnetBlock2D(
theta("resnets")(i),
groups=resnet_groups,
eps=resnet_eps,
non_linearity=resnet_act_fn,
output_scale_factor=resnet_out_scale_factor,
time_embedding_norm=resnet_time_scale_shift,
temb_channels=temb_channels,
dropout=dropout,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList(
[Upsample2D(theta("upsamplers")("0"), padding=1)]
)
else:
self.upsamplers = None

def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states


class UNetMidBlock2D(ThetaLayer):
def __init__(
self,
theta: Theta,
temb_channels: int,
dropout: float,
num_layers: int,
resnet_eps: float,
resnet_time_scale_shift: str,
resnet_act_fn: str,
resnet_groups: int,
resnet_pre_norm: bool,
add_attention: bool,
attention_head_dim: int,
output_scale_factor: float,
attn_groups: Optional[int] = None,
):
super().__init__(theta)

resnet_groups = resnet_groups if resnet_time_scale_shift == "default" else None

# there is always at least one resnet
if resnet_time_scale_shift == "spatial":
# TODO Implement ResnetBlockCondNorm2d block for spatial time scale shift
raise AssertionError(f"ResnetBlockCondNorm2d not yet implemented")
else:
resnets = [
ResnetBlock2D(
theta("resnets")(0),
groups=resnet_groups,
eps=resnet_eps,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
time_embedding_norm=resnet_time_scale_shift,
temb_channels=temb_channels,
dropout=dropout,
)
]
# TODO: loop over num_layers properly. This is hardcoded to the SDXL VAE case; the loop was removed for export reasons.
if add_attention:
self.attention = AttentionLayer(
theta("attentions")(0),
heads=1,
dim_head=attention_head_dim,
rescale_output_factor=1.0,
eps=resnet_eps,
norm_num_groups=attn_groups,
residual_connection=True,
)
else:
self.attention = None

if resnet_time_scale_shift == "spatial":
# TODO Implement ResnetBlock2D for spatial time scale shift support
raise AssertionError(
"ResnetBlock2D spatial time scale shift not yet implemented"
)
else:
resnets.append(
ResnetBlock2D(
theta("resnets")(1),
groups=resnet_groups,
eps=resnet_eps,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
time_embedding_norm=resnet_time_scale_shift,
temb_channels=temb_channels,
dropout=dropout,
)
)
self.resnets = nn.ModuleList(resnets)

def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
if self.attention is not None:
hidden_states = self.attention(hidden_states)
hidden_states = self.resnets[1](hidden_states, temb)
return hidden_states