# profile_macs.py (from a fork of xdit-project/xDiT)
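"""Profile the MACs (multiply-accumulate operations) of a single forward pass of
the Stable Diffusion XL UNet with torchprofile, for a given generation size."""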
import argparse

import torch
from diffusers import StableDiffusionXLPipeline
from torchprofile import profile_macs

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_size",
        type=int,
        nargs="*",
        default=1024,
        help="Image size of the generation (one value for a square image, or height and width)",
    )
    args = parser.parse_args()

    # The SDXL VAE downsamples by a factor of 8, so the UNet runs on latents of
    # size image_size // 8.
    if isinstance(args.image_size, int):
        args.image_size = [args.image_size // 8, args.image_size // 8]
    elif len(args.image_size) == 1:
        args.image_size = [args.image_size[0] // 8, args.image_size[0] // 8]
    else:
        assert len(args.image_size) == 2
        args.image_size = [args.image_size[0] // 8, args.image_size[1] // 8]

    # Load the SDXL base pipeline in fp16; only its UNet is profiled.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
    unet = pipeline.unet

    # Dummy inputs with batch size 2 (matching classifier-free guidance):
    # 4 latent channels, text embeddings of shape (2, 77, 2048), pooled text
    # embeddings of shape (2, 1280), and 6 additional time ids.
    latent_model_input = torch.randn(2, 4, *args.image_size, dtype=unet.dtype).to(
        "cuda"
    )
    t = torch.randn(1).to("cuda")
    prompt_embeds = torch.randn(2, 77, 2048, dtype=unet.dtype).to("cuda")
    add_text_embeds = torch.randn(2, 1280, dtype=unet.dtype).to("cuda")
    add_time_ids = torch.randint(0, 1024, (2, 6)).to("cuda")

    # Count MACs for one UNet forward pass. The None placeholders fill the
    # optional positional arguments of the UNet forward call that precede the
    # added_cond_kwargs dict.
    with torch.no_grad():
        macs = profile_macs(
            unet,
            args=(
                latent_model_input,
                t,
                prompt_embeds,
                None,
                None,
                None,
                None,
                {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
            ),
        )
    print(f"MACs: {macs / 1e9:.3f}G")