Skip to content

Commit

Permalink
Split and simplify model code
Browse files Browse the repository at this point in the history
Signed-off-by: Akhil Goel <[email protected]>
  • Loading branch information
akhilg-nv committed Aug 29, 2024
1 parent a8aacaf commit c7c81bd
Show file tree
Hide file tree
Showing 7 changed files with 678 additions and 602 deletions.
123 changes: 123 additions & 0 deletions tripy/examples/diffusion/clip_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tripy as tp

import tripy as tp
from dataclasses import dataclass

from examples.diffusion.helper import scaled_dot_product_attention

@dataclass
class CLIPConfig:
vocab_size: int = 49408
embedding_size: int = 768
num_heads: int = 12
max_seq_len: int = 77
num_hidden_layers: int = 12

class CLIPMLP(tp.Module):
def __init__(self, config: CLIPConfig):
self.fc1 = tp.Linear(config.embedding_size, config.embedding_size * 4)
self.fc2 = tp.Linear(config.embedding_size * 4, config.embedding_size)

def __call__(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = tp.sigmoid(1.702 * hidden_states) * hidden_states # quick GELU
hidden_states = self.fc2(hidden_states)
return hidden_states


class CLIPAttention(tp.Module):
def __init__(self, config: CLIPConfig):
self.embed_dim = config.embedding_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = tp.Linear(self.embed_dim, self.embed_dim)
self.v_proj = tp.Linear(self.embed_dim, self.embed_dim)
self.q_proj = tp.Linear(self.embed_dim, self.embed_dim)
self.out_proj = tp.Linear(self.embed_dim, self.embed_dim)

def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2]
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q, k, v = [
tp.transpose(
tp.reshape(x, (bsz, tgt_len, self.num_heads, self.head_dim)),
1,
2,
)
for x in (q, k, v)
]
attn_output = scaled_dot_product_attention(
q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask
)
out = self.out_proj(tp.reshape(tp.transpose(attn_output, 1, 2), (bsz, tgt_len, embed_dim)))
return out


class CLIPEncoderLayer(tp.Module):
def __init__(self, config: CLIPConfig):
self.self_attn = CLIPAttention(config)
self.layer_norm1 = tp.LayerNorm(config.embedding_size)
self.mlp = CLIPMLP(config)
self.layer_norm2 = tp.LayerNorm(config.embedding_size)

def __call__(self, hidden_states, causal_attention_mask):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states


class CLIPEncoder(tp.Module):
def __init__(self, config: CLIPConfig):
self.layers = [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]

def __call__(self, hidden_states, causal_attention_mask):
for l in self.layers:
hidden_states = l(hidden_states, causal_attention_mask)
return hidden_states


class CLIPTextEmbeddings(tp.Module):
def __init__(self, config: CLIPConfig):
self.token_embedding = tp.Embedding(config.vocab_size, config.embedding_size)
self.position_embedding = tp.Embedding(config.max_seq_len, config.embedding_size)

def __call__(self, input_ids, position_ids):
return self.token_embedding(input_ids) + self.position_embedding(position_ids)


class CLIPTextTransformer(tp.Module):
def __init__(self, config: CLIPConfig):
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = tp.LayerNorm(config.embedding_size)
self.max_seq_len = config.max_seq_len

def __call__(self, input_ids):
x = self.embeddings(input_ids, tp.reshape(tp.iota((input_ids.shape[1],), dtype=tp.int32), (1, -1)))
x = self.encoder(x, tp.triu(tp.full((1, 1, self.max_seq_len, self.max_seq_len), float("-inf")), 1))
return self.final_layer_norm(x)
39 changes: 20 additions & 19 deletions tripy/examples/diffusion/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
import torch
import cupy as cp
import numpy as np
import tripy as tp

from transformers import CLIPTokenizer
from examples.diffusion.model import CLIPConfig, StableDiffusion, get_alphas_cumprod
from examples.diffusion.clip_model import CLIPConfig
from examples.diffusion.model import StableDiffusion, StableDiffusionConfig, get_alphas_cumprod
from examples.diffusion.weight_loader import load_from_diffusers
import tripy as tp


def compile_model(model, inputs, verbose=False):
Expand Down Expand Up @@ -96,24 +97,24 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
def tripy_diffusion(args):
run_start_time = time.perf_counter()

# if os.path.isdir("engines"):
# print("[I] Loading cached engines from disk...")
# clip_compiled = tp.Executable.load(os.path.join("engines", "clip_executable.json"))
# unet_compiled = tp.Executable.load(os.path.join("engines", "unet_executable.json"))
# vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
# else:
model = StableDiffusion()
print("[I] Loading model weights...", flush=True)
load_from_diffusers(model, tp.float32, debug=True)
clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
unet_compiled = compile_unet(model, verbose=True)
vae_compiled = compile_vae(model.decode, verbose=True)
if os.path.isdir("engines"):
print("[I] Loading cached engines from disk...")
clip_compiled = tp.Executable.load(os.path.join("engines", "clip_executable.json"))
unet_compiled = tp.Executable.load(os.path.join("engines", "unet_executable.json"))
vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
else:
model = StableDiffusion(StableDiffusionConfig)
print("[I] Loading model weights...", flush=True)
load_from_diffusers(model, tp.float32, debug=True)
clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
unet_compiled = compile_unet(model, verbose=True)
vae_compiled = compile_vae(model.decode, verbose=True)

# os.mkdir("engines")
# print("[I] Saving engines to disk...")
# clip_compiled.save(os.path.join("engines", "clip_executable.json"))
# unet_compiled.save(os.path.join("engines", "unet_executable.json"))
# vae_compiled.save(os.path.join("engines", "vae_executable.json"))
os.mkdir("engines")
print("[I] Saving engines to disk...")
clip_compiled.save(os.path.join("engines", "clip_executable.json"))
unet_compiled.save(os.path.join("engines", "unet_executable.json"))
vae_compiled.save(os.path.join("engines", "vae_executable.json"))

# Run through CLIP to get context from prompt
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
Expand Down
43 changes: 43 additions & 0 deletions tripy/examples/diffusion/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import math
from functools import reduce
from typing import List, Callable, Optional

import tripy as tp


def scaled_dot_product_attention(
query: tp.Tensor,
key: tp.Tensor,
value: tp.Tensor,
embedding_dim: Optional[int] = None,
attn_mask: Optional[tp.Tensor] = None,
is_causal: bool = False,
) -> tp.Tensor:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
- Described: https://paperswithcode.com/method/scaled
- Paper: https://arxiv.org/abs/1706.03762v7
"""

if is_causal: # this path is not called in demoDiffusion
target_shape = query.shape[-2:-1] + key.shape[-2:-1]
# TODO: #228: WAR to prevent computing output rank in infer_rank for reshape
target_shape.trace_tensor.shape = (2,)
attn_mask = tp.cast(tp.tril(tp.ones(target_shape)), tp.bool)
if attn_mask is not None and attn_mask.dtype == tp.bool:
attn_mask = tp.where((attn_mask == 0), tp.ones_like(attn_mask) * -float("inf"), tp.zeros_like(attn_mask))
qk = query @ tp.transpose(key, -2, -1) / math.sqrt(embedding_dim)
return tp.cast(tp.softmax((qk + attn_mask) if attn_mask is not None else qk, -1), query.dtype) @ value


def sequential(input: tp.Tensor, ll: List[Callable[[tp.Tensor], tp.Tensor]]):
"""
Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
"""
return reduce(lambda x, f: f(x), ll, input)


def clamp(tensor: tp.Tensor, min: int, max: int):
return tp.minimum(tp.maximum(tensor, tp.ones_like(tensor) * min), tp.ones_like(tensor) * max)
Loading

0 comments on commit c7c81bd

Please sign in to comment.