optimizations.txt
All Ops:
  Timesteps (Triton)
  TimestepEmbedding (Triton)
  GroupNorm (Triton)
  Conv2D (Torch) (CUDA maybe? From https://github.com/chengzeyi/stable-fast)
  Dropout (Triton)
  Linear (Triton)
  SiLU (Triton)
  Attention (Triton, FA2/FA1 based on available GPU)
  GeGLU (Triton)
  LayerNorm (Triton)
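As a reference for the SiLU op above: the Triton kernel only needs to reproduce `x * sigmoid(x)` elementwise. A minimal NumPy sketch of the semantics (the function name `silu` is illustrative, not from the repo):

```python
import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    """SiLU (a.k.a. swish): x * sigmoid(x), applied elementwise."""
    return x * (1.0 / (1.0 + np.exp(-x)))
```

An elementwise op like this is the easiest fusion target, since it can be appended to the epilogue of any preceding kernel.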
Fused Kernels:
  ResNet:
    TimestepEmbedding:
      Linear; SiLU
      Linear
    GroupNorm; SiLU
    GroupNorm; SiLU
    Dropout
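The "GroupNorm; SiLU" fusion in the ResNet block above means computing the activation in the same pass as the normalization, instead of launching a second elementwise kernel. A NumPy sketch of the fused semantics, assuming NCHW layout (the function name `group_norm_silu` is illustrative):

```python
import numpy as np

def group_norm_silu(x, num_groups, gamma, beta, eps=1e-5):
    """GroupNorm followed by SiLU, expressed as one pass.

    x: (N, C, H, W); gamma, beta: per-channel affine parameters.
    """
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    y = g.reshape(n, c, h, w) * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)
    # Fused epilogue: apply SiLU before writing the result back out.
    return y * (1.0 / (1.0 + np.exp(-y)))
```

In a Triton kernel the win comes from writing `y` to global memory once instead of twice; the math is unchanged.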
  Attention:
    QKV Proj can be fused into one kernel
    FA2/FA1 depending on available GPU
    Linear
    Dropout
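The QKV-projection fusion above is just three matmuls sharing one input, so they collapse into a single matmul against the concatenated weight matrix. A NumPy sketch showing the equivalence (dimensions are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
x = rng.standard_normal((4, d))
wq, wk, wv = (rng.standard_normal((d, d)) for _ in range(3))

# Three separate projections...
q, k, v = x @ wq, x @ wk, x @ wv

# ...vs one fused matmul with the weights concatenated along the output dim.
qkv = x @ np.concatenate([wq, wk, wv], axis=1)
q2, k2, v2 = np.split(qkv, 3, axis=1)
```

The fused form reads `x` from memory once and gives the GEMM a larger, better-utilized output tile.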
  FeedForward:
    GeGLU:
      Gated Linear; GeGLU (the gated part might be difficult to implement in Triton)
      GeGLU
    Dropout
    Linear
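For reference, the "Gated Linear; GeGLU" step above projects to twice the output width, splits the result, and gates one half with GELU of the other. A NumPy sketch of the semantics, using the exact (erf-based) GELU (the function name `geglu` is illustrative):

```python
import math
import numpy as np

_erf = np.vectorize(math.erf)

def geglu(x, w, b):
    """Gated linear + GeGLU: project to 2*d_out, split, gate with GELU."""
    h = x @ w + b                      # (N, 2 * d_out)
    hidden, gate = np.split(h, 2, axis=-1)
    gelu = 0.5 * gate * (1.0 + _erf(gate / math.sqrt(2.0)))
    return hidden * gelu
```

Fusing this in Triton means computing the split and the gate in the matmul epilogue, which is the part flagged as possibly difficult above.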
  BasicTransformerBlock:
    LayerNorm (all LayerNorms are left unfused for now; a possible entry point for better performance)
    Attention (See above)
    FeedForward (See above)
  Transformer2DModel:
    GroupNorm
    Linear
    BasicTransformerBlock (See above)
    Linear
  Downsample2D (left untouched since its only op is a Conv2D)
  Upsample2D (left untouched since its main op is a Conv2D)
  DownBlock2D:
    ResNet (See above)
    Downsample2D (See above)
  CrossAttnDownBlock2D:
    ResNet (See above)
    Transformer2DModel (See above)
    Downsample2D (See above)
  CrossAttnUpBlock2D:
    ResNet (See above)
    Transformer2DModel (See above)
    Upsample2D (See above)
  UpBlock2D:
    ResNet (See above)
  UNetMidBlock2DCrossAttn:
    Transformer2DModel (See above)
    ResNet (See above)
  UNet2DConditionModel:
    TimestepEmbedding (See above)
    Timesteps (See above)
    DownBlock2D (See above)
    CrossAttnDownBlock2D (See above)
    CrossAttnUpBlock2D (See above)
    UpBlock2D (See above)
    GroupNorm; SiLU
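For reference, the Timesteps op feeding the UNet above computes a sinusoidal embedding of the diffusion step. A NumPy sketch using the common DDPM convention (half sin, half cos, log-spaced frequencies); the exact ordering and scaling in a given implementation may differ (e.g. flip/scale flags), and the function name `timestep_embedding` is illustrative:

```python
import math
import numpy as np

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal timestep embedding: (len(t), dim) array, half sin / half cos."""
    half = dim // 2
    # Frequencies decay geometrically from 1 down to ~1/max_period.
    freqs = np.exp(-math.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)
```

The result is then passed through the TimestepEmbedding MLP (Linear; SiLU; Linear) listed in the fused-kernel plan.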