Add backend functions and classes for Flux implementation; update the way Flux encoders/tokenizers are loaded for prompt encoding; update the way the Flux VAE is loaded
1 parent 53052cf · commit f4f5c46
Showing 19 changed files with 1,340 additions and 197 deletions.
@@ -0,0 +1,30 @@
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
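The functions above implement the rotary position embedding (RoPE) attention used by Flux: rope() builds per-position 2x2 rotation matrices from axis positions, apply_rope() rotates the query and key vectors with those matrices, and attention() runs PyTorch's scaled_dot_product_attention and folds the heads back into the feature dimension. As a rough illustration, the sketch below exercises the three functions on random tensors, assuming they are importable from the new module; the shapes and theta value are illustrative choices, not values taken from this commit.

import torch

# Illustrative shapes only: batch=1, heads=8, sequence length=16, head dim=64 (head dim must be even).
B, H, L, D = 1, 8, 16, 64

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Token positions along a single axis; float64 matches the dtype rope() uses internally.
pos = torch.arange(L, dtype=torch.float64).unsqueeze(0)  # (B, L)

# Per-position rotation matrices, shape (B, L, D/2, 2, 2); unsqueeze(1) broadcasts them over the heads.
pe = rope(pos, dim=D, theta=10_000).unsqueeze(1)

out = attention(q, k, v, pe)
print(out.shape)  # torch.Size([1, 16, 512]) -> (B, L, H * D)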
@@ -0,0 +1,111 @@
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
                                                  MLPEmbedder, SingleStreamBlock,
                                                  timestep_embedding)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
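A Flux instance is configured entirely through FluxParams and driven via forward() with packed image tokens and their position ids, text tokens and their ids, timesteps, a pooled text vector y, and an optional guidance strength. The sketch below wires up a deliberately tiny configuration and a forward pass on random tensors; all sizes are illustrative assumptions, and it presumes the layer modules imported above (EmbedND, DoubleStreamBlock, etc.), added elsewhere in this commit, match the upstream Flux reference implementation.

import torch

# Toy configuration; real Flux checkpoints use much larger hidden sizes and depths.
params = FluxParams(
    in_channels=64,          # packed 2x2 patches of a 16-channel VAE latent
    vec_in_dim=768,          # pooled CLIP embedding size
    context_in_dim=4096,     # T5 hidden size
    hidden_size=256,
    mlp_ratio=4.0,
    num_heads=4,             # pe_dim = 256 / 4 = 64
    depth=2,
    depth_single_blocks=2,
    axes_dim=[16, 24, 24],   # must sum to pe_dim (64)
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,    # no guidance embedding, so forward() needs no guidance tensor
)
model = Flux(params)

B, L_img, L_txt = 1, 64, 16
img = torch.randn(B, L_img, params.in_channels)
img_ids = torch.zeros(B, L_img, 3)   # one position id per axis in axes_dim
txt = torch.randn(B, L_txt, params.context_in_dim)
txt_ids = torch.zeros(B, L_txt, 3)
timesteps = torch.rand(B)
y = torch.randn(B, params.vec_in_dim)

out = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, timesteps=timesteps, y=y)
print(out.shape)  # expected: (B, L_img, in_channels)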