Commit 0699212

support vram management in flux
1 parent 46d4616 commit 0699212

8 files changed: +246 -6 lines changed


diffsynth/models/model_manager.py

Lines changed: 4 additions & 1 deletion
@@ -80,7 +80,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
 def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
     loaded_model_names, loaded_models = [], []
     for model_name, model_class in zip(model_names, model_classes):
-        model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
+        if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
+        else:
+            model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
         if torch_dtype == torch.float16 and hasattr(model, "half"):
             model = model.half()
         try:
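
The new branch only passes `torch_dtype` to `from_pretrained` for the standard floating-point types; other dtypes (e.g. `torch.float8_e4m3fn`) are applied by casting after loading. A minimal sketch of the same pattern, where `load_with_dtype`, `model_class`, and `file_path` are illustrative placeholders rather than part of the commit:

import torch

def load_with_dtype(model_class, file_path, torch_dtype):
    # Mirrors the branch above: standard dtypes go straight to from_pretrained,
    # anything else (e.g. torch.float8_e4m3fn) is cast onto the weights after loading.
    if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
        model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
    else:
        model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
    return model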

diffsynth/models/sd3_text_encoder.py

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@ def __init__(self, vocab_size=49408):
         super().__init__(vocab_size=vocab_size)
 
     def forward(self, input_ids, clip_skip=2, extra_mask=None):
-        embeds = self.token_embedding(input_ids) + self.position_embeds
+        embeds = self.token_embedding(input_ids)
+        embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
         attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
         if extra_mask is not None:
             attn_mask[:, extra_mask[0]==0] = float("-inf")

diffsynth/pipelines/base.py

Lines changed: 12 additions & 2 deletions
@@ -101,12 +101,22 @@ def load_models_to_device(self, loadmodel_names=[]):
             if model_name not in loadmodel_names:
                 model = getattr(self, model_name)
                 if model is not None:
-                    model.cpu()
+                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+                        for module in model.modules():
+                            if hasattr(module, "offload"):
+                                module.offload()
+                    else:
+                        model.cpu()
         # load the needed models to device
         for model_name in loadmodel_names:
             model = getattr(self, model_name)
             if model is not None:
-                model.to(self.device)
+                if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
+                    for module in model.modules():
+                        if hasattr(module, "onload"):
+                            module.onload()
+                else:
+                    model.to(self.device)
         # fresh the cuda cache
         torch.cuda.empty_cache()
 
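
With this change, a component that has been wrapped by the new VRAM-management layers is no longer moved wholesale with `.cpu()` / `.to(device)`; each wrapped submodule handles its own placement via `offload()` / `onload()`. A condensed sketch of the dispatch (the helper name `move_off_device` is made up for illustration; the pipeline inlines this logic in `load_models_to_device`):

def move_off_device(model):
    # Illustrative only, not part of the commit.
    if getattr(model, "vram_management_enabled", False):
        # Fine-grained path: every wrapped submodule decides how to offload itself.
        for module in model.modules():
            if hasattr(module, "offload"):
                module.offload()
    else:
        # Plain models are moved to CPU in one piece, as before.
        model.cpu()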

diffsynth/pipelines/flux_image.py

Lines changed: 104 additions & 2 deletions
@@ -11,6 +11,9 @@
 from ..models.tiler import FastTileWorker
 from transformers import SiglipVisionModel
 from copy import deepcopy
+from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
+from ..models.flux_dit import RMSNorm
+from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
 
 
 class FluxImagePipeline(BasePipeline):
@@ -31,6 +34,105 @@ def __init__(self, device="cuda", torch_dtype=torch.float16):
         self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
 
 
+    def enable_vram_management(self, num_persistent_param_in_dit=None):
+        dtype = next(iter(self.text_encoder_1.parameters())).dtype
+        enable_vram_management(
+            self.text_encoder_1,
+            module_map = {
+                torch.nn.Linear: AutoWrappedLinear,
+                torch.nn.Embedding: AutoWrappedModule,
+                torch.nn.LayerNorm: AutoWrappedModule,
+            },
+            module_config = dict(
+                offload_dtype=dtype,
+                offload_device="cpu",
+                onload_dtype=dtype,
+                onload_device="cpu",
+                computation_dtype=self.torch_dtype,
+                computation_device=self.device,
+            ),
+        )
+        dtype = next(iter(self.text_encoder_2.parameters())).dtype
+        enable_vram_management(
+            self.text_encoder_2,
+            module_map = {
+                torch.nn.Linear: AutoWrappedLinear,
+                torch.nn.Embedding: AutoWrappedModule,
+                T5LayerNorm: AutoWrappedModule,
+                T5DenseActDense: AutoWrappedModule,
+                T5DenseGatedActDense: AutoWrappedModule,
+            },
+            module_config = dict(
+                offload_dtype=dtype,
+                offload_device="cpu",
+                onload_dtype=dtype,
+                onload_device="cpu",
+                computation_dtype=self.torch_dtype,
+                computation_device=self.device,
+            ),
+        )
+        dtype = next(iter(self.dit.parameters())).dtype
+        enable_vram_management(
+            self.dit,
+            module_map = {
+                RMSNorm: AutoWrappedModule,
+                torch.nn.Linear: AutoWrappedLinear,
+            },
+            module_config = dict(
+                offload_dtype=dtype,
+                offload_device="cpu",
+                onload_dtype=dtype,
+                onload_device="cuda",
+                computation_dtype=self.torch_dtype,
+                computation_device=self.device,
+            ),
+            max_num_param=num_persistent_param_in_dit,
+            overflow_module_config = dict(
+                offload_dtype=dtype,
+                offload_device="cpu",
+                onload_dtype=dtype,
+                onload_device="cpu",
+                computation_dtype=self.torch_dtype,
+                computation_device=self.device,
+            ),
+        )
+        dtype = next(iter(self.vae_decoder.parameters())).dtype
+        enable_vram_management(
+            self.vae_decoder,
+            module_map = {
+                torch.nn.Linear: AutoWrappedLinear,
+                torch.nn.Conv2d: AutoWrappedModule,
+                torch.nn.GroupNorm: AutoWrappedModule,
+            },
+            module_config = dict(
+                offload_dtype=dtype,
+                offload_device="cpu",
+                onload_dtype=dtype,
+                onload_device="cpu",
+                computation_dtype=self.torch_dtype,
+                computation_device=self.device,
+            ),
+        )
+        dtype = next(iter(self.vae_encoder.parameters())).dtype
+        enable_vram_management(
+            self.vae_encoder,
+            module_map = {
+                torch.nn.Linear: AutoWrappedLinear,
+                torch.nn.Conv2d: AutoWrappedModule,
+                torch.nn.GroupNorm: AutoWrappedModule,
+            },
+            module_config = dict(
+                offload_dtype=dtype,
+                offload_device="cpu",
+                onload_dtype=dtype,
+                onload_device="cpu",
+                computation_dtype=self.torch_dtype,
+                computation_device=self.device,
+            ),
+        )
+        self.enable_cpu_offload()
+
+
     def denoising_model(self):
         return self.dit
 
@@ -62,10 +164,10 @@ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: Lis
 
 
     @staticmethod
-    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
+    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
         pipe = FluxImagePipeline(
             device=model_manager.device if device is None else device,
-            torch_dtype=model_manager.torch_dtype,
+            torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
         )
         pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
         return pipe

diffsynth/vram_management/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .layers import *

diffsynth/vram_management/layers.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import torch, copy
+from ..models.utils import init_weights_on_device
+
+
+def cast_to(weight, dtype, device):
+    r = torch.empty_like(weight, dtype=dtype, device=device)
+    r.copy_(weight)
+    return r
+
+
+class AutoWrappedModule(torch.nn.Module):
+    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
+        super().__init__()
+        self.module = module.to(dtype=offload_dtype, device=offload_device)
+        self.offload_dtype = offload_dtype
+        self.offload_device = offload_device
+        self.onload_dtype = onload_dtype
+        self.onload_device = onload_device
+        self.computation_dtype = computation_dtype
+        self.computation_device = computation_device
+        self.state = 0
+
+    def offload(self):
+        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.module.to(dtype=self.offload_dtype, device=self.offload_device)
+            self.state = 0
+
+    def onload(self):
+        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.module.to(dtype=self.onload_dtype, device=self.onload_device)
+            self.state = 1
+
+    def forward(self, *args, **kwargs):
+        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
+            module = self.module
+        else:
+            module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
+        return module(*args, **kwargs)
+
+
+class AutoWrappedLinear(torch.nn.Linear):
+    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
+        with init_weights_on_device(device=torch.device("meta")):
+            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
+        self.weight = module.weight
+        self.bias = module.bias
+        self.offload_dtype = offload_dtype
+        self.offload_device = offload_device
+        self.onload_dtype = onload_dtype
+        self.onload_device = onload_device
+        self.computation_dtype = computation_dtype
+        self.computation_device = computation_device
+        self.state = 0
+
+    def offload(self):
+        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.to(dtype=self.offload_dtype, device=self.offload_device)
+            self.state = 0
+
+    def onload(self):
+        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.to(dtype=self.onload_dtype, device=self.onload_device)
+            self.state = 1
+
+    def forward(self, x, *args, **kwargs):
+        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
+            weight, bias = self.weight, self.bias
+        else:
+            weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
+            bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
+        return torch.nn.functional.linear(x, weight, bias)
+
+
+def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
+    for name, module in model.named_children():
+        for source_module, target_module in module_map.items():
+            if isinstance(module, source_module):
+                num_param = sum(p.numel() for p in module.parameters())
+                if max_num_param is not None and total_num_param + num_param > max_num_param:
+                    module_config_ = overflow_module_config
+                else:
+                    module_config_ = module_config
+                module_ = target_module(module, **module_config_)
+                setattr(model, name, module_)
+                total_num_param += num_param
+                break
+        else:
+            total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
+    return total_num_param
+
+
+def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
+    enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
+    model.vram_management_enabled = True
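
As a rough illustration of how these wrappers are meant to be used (a standalone sketch, not part of the commit; `TinyNet` and the chosen dtypes/devices are arbitrary):

import torch
from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear

class TinyNet(torch.nn.Module):
    # Made-up two-layer model, used only to show the wrapping behavior.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)
        self.norm = torch.nn.LayerNorm(16)

    def forward(self, x):
        return self.norm(self.proj(x))

model = TinyNet()
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear, torch.nn.LayerNorm: AutoWrappedModule},
    module_config=dict(
        offload_dtype=torch.float32, offload_device="cpu",
        onload_dtype=torch.float32, onload_device="cpu",
        computation_dtype=torch.float32, computation_device="cpu",
    ),
)
# Wrapped modules keep their weights on the offload device; onload()/offload()
# move them between devices, and forward() casts to the computation dtype/device
# only when the onload and computation settings differ.
for module in model.modules():
    if hasattr(module, "onload"):
        module.onload()
print(model(torch.randn(1, 16)).shape)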

examples/vram_management/README.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# VRAM Management
+
+Experimental feature. Still under development.

examples/vram_management/ (new example script)

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import torch
+from diffsynth import ModelManager, FluxImagePipeline
+
+
+model_manager = ModelManager(
+    file_path_list=[
+        "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
+        "models/FLUX/FLUX.1-dev/text_encoder_2",
+        "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
+        "models/FLUX/FLUX.1-dev/ae.safetensors",
+    ],
+    torch_dtype=torch.float8_e4m3fn,
+    device="cpu"
+)
+pipe = FluxImagePipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+
+# Enable VRAM management
+# `num_persistent_param_in_dit` indicates the number of parameters that reside persistently in VRAM within the DiT model.
+# When `num_persistent_param_in_dit=None`, it means all parameters reside persistently in memory.
+# When `num_persistent_param_in_dit=7*10**9`, it indicates that 7 billion parameters reside persistently in memory.
+# When `num_persistent_param_in_dit=0`, it means no parameters reside persistently in memory, and they are loaded layer by layer during inference.
+pipe.enable_vram_management(num_persistent_param_in_dit=None)
+
+image = pipe(prompt="a beautiful orange cat", seed=0)
+image.save("image.jpg")
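
As a back-of-the-envelope check (an assumption based on the configuration above, where the persistent DiT weights stay on the GPU in the dtype they were loaded in), the resident footprint is roughly `num_persistent_param_in_dit` times the bytes per parameter of that dtype:

# Rough estimate of the persistent DiT footprint under the settings above
# (float8_e4m3fn weights -> 1 byte per parameter; bfloat16 would be 2 bytes).
num_persistent = 7 * 10**9
bytes_per_param = 1
print(f"~{num_persistent * bytes_per_param / 1024**3:.1f} GB kept in VRAM")  # ~6.5 GB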
