Commit

Add cache_dir for other models (#184)
* done

* Update docstrings

* Lint

---------

Co-authored-by: RR4787 <[email protected]>
Co-authored-by: Cory Stephenson <[email protected]>
3 people authored Jan 10, 2025
1 parent b7e5029 commit 49282fd
Showing 2 changed files with 66 additions and 23 deletions.
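
Both new arguments are passed straight through to the underlying Hugging Face from_pretrained / get_config_dict calls, so they behave the way those loaders do. A rough sketch of what the two knobs control (the model name below is illustrative, not something this commit touches):

from transformers import AutoTokenizer

# cache_dir picks where Hugging Face stores downloaded files; local_files_only=True
# skips all network access and raises if the files are not already in that cache.
tokenizer = AutoTokenizer.from_pretrained(
    'google/t5-v1_1-xxl',       # illustrative repo id; any Hugging Face model works
    cache_dir='/tmp/hf_files',  # the default this commit uses
    local_files_only=False,
)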
diffusion/models/models.py (31 additions, 9 deletions)
@@ -307,6 +307,8 @@ def stable_diffusion_xl(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    cache_dir: str = '/tmp/hf_files',
+    local_files_only: bool = False,
 ):
     """Stable diffusion 2 training setup + SDXL UNet and VAE.
@@ -364,6 +366,9 @@ def stable_diffusion_xl(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        cache_dir (str): Directory to cache local files in. Default: `'/tmp/hf_files'`.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

@@ -377,10 +382,14 @@ def stable_diffusion_xl(
     val_metrics = [MeanSquaredError()]

     # Make the tokenizer and text encoder
-    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
+    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names,
+                               cache_dir=cache_dir,
+                               local_files_only=local_files_only)
     text_encoder = MultiTextEncoder(model_names=text_encoder_names,
                                     encode_latents_in_fp16=encode_latents_in_fp16,
-                                    pretrained_sdxl=pretrained)
+                                    pretrained_sdxl=pretrained,
+                                    cache_dir=cache_dir,
+                                    local_files_only=local_files_only)

     precision = torch.float16 if encode_latents_in_fp16 else None
     # Make the autoencoder
@@ -408,9 +417,15 @@ def stable_diffusion_xl(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)

     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
     if pretrained:
-        unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
+        unet = UNet2DConditionModel.from_pretrained(unet_model_name,
+                                                    subfolder='unet',
+                                                    cache_dir=cache_dir,
+                                                    local_files_only=local_files_only)
         if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
@@ -612,6 +627,7 @@ def precomputed_text_latent_diffusion(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    local_files_only: bool = False,
 ):
     """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.
@@ -662,6 +678,7 @@ def precomputed_text_latent_diffusion(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

@@ -695,7 +712,10 @@ def precomputed_text_latent_diffusion(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)

     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]

     if isinstance(vae, AutoEncoder):
         # Adapt the unet config to account for differing number of latent channels if necessary
@@ -792,20 +812,22 @@ def precomputed_text_latent_diffusion(
     if include_text_encoders:
         dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
         dtype = dtype_map[text_encoder_dtype]
-        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
+        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl',
+                                                     cache_dir=cache_dir,
+                                                     local_files_only=local_files_only)
         clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                        subfolder='tokenizer',
                                                        cache_dir=cache_dir,
-                                                       local_files_only=False)
+                                                       local_files_only=local_files_only)
         t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
                                                torch_dtype=dtype,
                                                cache_dir=cache_dir,
-                                               local_files_only=False).encoder.eval()
+                                               local_files_only=local_files_only).encoder.eval()
         clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                      subfolder='text_encoder',
                                                      torch_dtype=dtype,
                                                      cache_dir=cache_dir,
-                                                     local_files_only=False).cuda().eval()
+                                                     local_files_only=local_files_only).cuda().eval()
     # Make the composer model
     model = PrecomputedTextLatentDiffusion(
         unet=unet,
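Net effect in this file: stable_diffusion_xl gains both cache_dir and local_files_only, and precomputed_text_latent_diffusion (which already took cache_dir) gains local_files_only, so every tokenizer, text-encoder, config, and UNet load above honors the same cache settings. Note that the T5 tokenizer previously hard-coded local_files_only=True and now follows the shared flag, which defaults to False. A minimal usage sketch, assuming the remaining arguments keep their defaults (the cache path is illustrative):

from diffusion.models.models import precomputed_text_latent_diffusion, stable_diffusion_xl

# Sketch only: point the Hugging Face cache at a persistent location and
# forbid downloads so missing weights fail fast instead of hitting the network.
sdxl_model = stable_diffusion_xl(
    cache_dir='/mnt/hf_cache',
    local_files_only=True,
)

# local_files_only now also reaches the T5-XXL and CLIP loads gated by
# include_text_encoders in precomputed_text_latent_diffusion.
ptl_model = precomputed_text_latent_diffusion(
    include_text_encoders=True,
    cache_dir='/mnt/hf_cache',
    local_files_only=True,
)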
diffusion/models/text_encoder.py (35 additions, 14 deletions)
@@ -25,13 +25,13 @@ class MultiTextEncoder(torch.nn.Module):
             the projected output from a CLIPTextModelWithProjection. Default: ``False``.
     """

-    def __init__(
-        self,
-        model_names: Union[str, Tuple[str, ...]],
-        model_dim_keys: Optional[Union[str, List[str]]] = None,
-        encode_latents_in_fp16: bool = True,
-        pretrained_sdxl: bool = False,
-    ):
+    def __init__(self,
+                 model_names: Union[str, Tuple[str, ...]],
+                 model_dim_keys: Optional[Union[str, List[str]]] = None,
+                 encode_latents_in_fp16: bool = True,
+                 pretrained_sdxl: bool = False,
+                 cache_dir: str = '/tmp/hf_files',
+                 local_files_only: bool = False):
         super().__init__()
         self.pretrained_sdxl = pretrained_sdxl

@@ -50,7 +50,10 @@ def __init__(
             name_split = model_name.split('/')
             base_name = '/'.join(name_split[:2])
             subfolder = '/'.join(name_split[2:])
-            text_encoder_config = PretrainedConfig.get_config_dict(base_name, subfolder=subfolder)[0]
+            text_encoder_config = PretrainedConfig.get_config_dict(base_name,
+                                                                   subfolder=subfolder,
+                                                                   cache_dir=cache_dir,
+                                                                   local_files_only=local_files_only)[0]

             # Add text_encoder output dim to total dim
             dim_found = False
@@ -70,14 +73,25 @@ def __init__(
             architectures = text_encoder_config['architectures']
             if architectures == ['CLIPTextModel']:
                 self.text_encoders.append(
-                    CLIPTextModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    CLIPTextModel.from_pretrained(base_name,
+                                                  subfolder=subfolder,
+                                                  torch_dtype=torch_dtype,
+                                                  cache_dir=cache_dir,
+                                                  local_files_only=local_files_only))
             elif architectures == ['CLIPTextModelWithProjection']:
                 self.text_encoders.append(
-                    CLIPTextModelWithProjection.from_pretrained(base_name, subfolder=subfolder,
-                                                                torch_dtype=torch_dtype))
+                    CLIPTextModelWithProjection.from_pretrained(base_name,
+                                                                subfolder=subfolder,
+                                                                torch_dtype=torch_dtype,
+                                                                cache_dir=cache_dir,
+                                                                local_files_only=local_files_only))
             else:
                 self.text_encoders.append(
-                    AutoModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    AutoModel.from_pretrained(base_name,
+                                              subfolder=subfolder,
+                                              torch_dtype=torch_dtype,
+                                              cache_dir=cache_dir,
+                                              local_files_only=local_files_only))
             self.architectures += architectures

     @property
@@ -125,7 +139,10 @@ class MultiTokenizer:
             "org_name/repo_name/subfolder" where the subfolder is excluded if it is not used in the repo.
     """

-    def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
+    def __init__(self,
+                 tokenizer_names_or_paths: Union[str, Tuple[str, ...]],
+                 cache_dir: str = '/tmp/hf_files',
+                 local_files_only: bool = False):
         if isinstance(tokenizer_names_or_paths, str):
             tokenizer_names_or_paths = (tokenizer_names_or_paths,)

@@ -134,7 +151,11 @@ def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
             path_split = tokenizer_name_or_path.split('/')
             base_name = '/'.join(path_split[:2])
             subfolder = '/'.join(path_split[2:])
-            self.tokenizers.append(AutoTokenizer.from_pretrained(base_name, subfolder=subfolder))
+            self.tokenizers.append(
+                AutoTokenizer.from_pretrained(base_name,
+                                              subfolder=subfolder,
+                                              cache_dir=cache_dir,
+                                              local_files_only=local_files_only))

         self.model_max_length = min([t.model_max_length for t in self.tokenizers])
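For code that builds these classes directly, the new keyword arguments follow the same pattern; a brief sketch using the org_name/repo_name/subfolder convention from the docstrings (the SDXL repo names are illustrative):

from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer

tokenizer = MultiTokenizer(
    tokenizer_names_or_paths=('stabilityai/stable-diffusion-xl-base-1.0/tokenizer',
                              'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'),
    cache_dir='/tmp/hf_files',
    local_files_only=False,
)

text_encoder = MultiTextEncoder(
    model_names=('stabilityai/stable-diffusion-xl-base-1.0/text_encoder',
                 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'),
    cache_dir='/tmp/hf_files',
    local_files_only=False,
)

Keeping local_files_only=False as the default preserves the previous download-on-demand behavior for existing callers.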
