forked from xdit-project/xDiT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
envs.py
133 lines (111 loc) · 4.51 KB
/
envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import torch
import diffusers
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from packaging import version
from xfuser.logger import init_logger
logger = init_logger(__name__)
# Declarations for static type checkers only. At runtime TYPE_CHECKING is
# False, so none of these names are bound here; attribute access on this
# module is instead resolved lazily by the module-level __getattr__ through
# the `environment_variables` and `variables` registries.
if TYPE_CHECKING:
    MASTER_ADDR: str = ""
    MASTER_PORT: Optional[int] = None
    CUDA_HOME: Optional[str] = None
    LOCAL_RANK: int = 0
    CUDA_VISIBLE_DEVICES: Optional[str] = None
    XDIT_LOGGING_LEVEL: str = "INFO"
    CUDA_VERSION: version.Version
    TORCH_VERSION: version.Version
# Registry of environment-backed settings. Each value is a zero-argument
# getter, so the process environment is read when the setting is accessed
# rather than when this module is imported.
environment_variables: Dict[str, Callable[[], Any]] = {
    # ================== Runtime Env Vars ==================
    # Master address for the distributed environment; empty string if unset.
    "MASTER_ADDR": lambda: os.environ.get("MASTER_ADDR", ""),
    # Manually chosen communication port for the distributed environment;
    # None when the variable is not set at all.
    "MASTER_PORT": lambda: (
        None
        if "MASTER_PORT" not in os.environ
        else int(os.environ.get("MASTER_PORT", "0"))
    ),
    # Path to the cudatoolkit home directory, under which the bin, include
    # and lib directories are expected; None if unset.
    "CUDA_HOME": lambda: os.getenv("CUDA_HOME", None),
    # Local rank of this process in the distributed setting, used to pick
    # the GPU device id; defaults to 0.
    "LOCAL_RANK": lambda: int(os.getenv("LOCAL_RANK", "0")),
    # Controls which devices are visible in the distributed setting.
    "CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
    # Default logging level; "INFO" if unset.
    "XDIT_LOGGING_LEVEL": lambda: os.environ.get("XDIT_LOGGING_LEVEL", "INFO"),
}
# Registry of derived (non-environment) values, evaluated on access.
variables: Dict[str, Callable[[], Any]] = {
    # ================== Other Vars ==================
    # Both entries are used for version checking.
    # Parsed from torch.version.cuda.
    # NOTE(review): torch.version.cuda may be None on non-CUDA builds, in
    # which case parsing would fail on access -- confirm with callers.
    "CUDA_VERSION": lambda: version.parse(torch.version.cuda),
    # torch.__version__ reduced to its base_version and parsed again, so the
    # result carries only the release part of the version string.
    "TORCH_VERSION": lambda: version.parse(version.parse(torch.__version__).base_version),
}
class PackagesEnvChecker:
    """Singleton that probes optional/required packages once and caches the result.

    The first construction runs every check and stores the outcome in
    ``packages_info``; later constructions return the same instance without
    re-running the checks.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(PackagesEnvChecker, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        """Run all package checks and cache their results in ``packages_info``."""
        self.packages_info = {
            "has_flash_attn": self.check_flash_attn(),
            "has_long_ctx_attn": self.check_long_ctx_attn(),
            "diffusers_version": self.check_diffusers_version(),
        }

    def check_flash_attn(self):
        """Return True if ``flash_attn`` >= 2.6.0 is importable and may be used.

        Returns False when CUDA is unavailable, when the GPU name contains
        "Turing", "Tesla" or "T4", when ``flash_attn`` cannot be imported, or
        when the installed ``flash_attn`` is older than 2.6.0.
        """
        # Without a CUDA device, querying the device name raises an exception
        # that is not ImportError, which would abort the import of this
        # module. Report "no flash attention" instead.
        if not torch.cuda.is_available():
            return False

        gpu_name = torch.cuda.get_device_name(torch.device("cuda"))
        if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
            return False

        try:
            # flash_attn_func is imported only to confirm the symbol exists.
            from flash_attn import flash_attn_func  # noqa: F401
            from flash_attn import __version__ as flash_attn_version
        except ImportError:
            logger.warning(
                f'Flash Attention library "flash_attn" not found, '
                f"using pytorch attention implementation"
            )
            return False

        # Compare parsed versions: a plain string comparison orders "2.10.0"
        # before "2.6.0" and would reject newer releases.
        if version.parse(flash_attn_version) < version.parse("2.6.0"):
            logger.warning(
                f'Flash Attention library "flash_attn" version '
                f"{flash_attn_version} is older than the required 2.6.0, "
                f"using pytorch attention implementation"
            )
            return False
        return True

    def check_long_ctx_attn(self):
        """Return True if the ``yunchang`` long-context attention API is importable."""
        try:
            # Imported only to confirm that every required symbol exists.
            from yunchang import (  # noqa: F401
                set_seq_parallel_pg,
                ring_flash_attn_func,
                UlyssesAttention,
                LongContextAttention,
                LongContextAttentionQKVPacked,
            )

            return True
        except ImportError:
            logger.warning(
                f'Ring Flash Attention library "yunchang" not found, '
                f"using pytorch attention implementation"
            )
            return False

    def check_diffusers_version(self):
        """Return the installed diffusers base version as a parsed version.

        Raises:
            RuntimeError: if the installed diffusers is older than 0.30.0.
        """
        # Reduce to base_version before comparing, then parse once and reuse.
        installed = version.parse(version.parse(diffusers.__version__).base_version)
        if installed < version.parse("0.30.0"):
            raise RuntimeError(
                f"Diffusers version: {installed} is not supported, "
                f"please upgrade to version >= 0.30.0"
            )
        return installed

    def get_packages_info(self):
        """Return the dict of cached check results built by ``initialize``."""
        return self.packages_info
# Shared module-level instance. Constructing it here runs the package checks
# when this module is imported, so a diffusers version older than 0.30.0
# raises RuntimeError at import time.
PACKAGES_CHECKER = PackagesEnvChecker()
def __getattr__(name):
    """Resolve a module attribute lazily from the getter registries.

    ``environment_variables`` is consulted first, then ``variables``; the
    matching zero-argument getter is called on every access and its result
    returned.

    Raises:
        AttributeError: if ``name`` is in neither registry.
    """
    for registry in (environment_variables, variables):
        if name in registry:
            return registry[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
    """List every name that the module-level ``__getattr__`` can resolve.

    Includes the keys of both ``environment_variables`` and ``variables`` so
    that ``dir()`` on this module agrees with what attribute access accepts.
    """
    return [*environment_variables, *variables]