From 49282fd7b2ecde2d5bafb1ca6476d9239865665f Mon Sep 17 00:00:00 2001
From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com>
Date: Fri, 10 Jan 2025 11:04:13 -0800
Subject: [PATCH] Add cache_dir for other models (#184)

* done

* Update docstrings

* Lint

---------

Co-authored-by: RR4787
Co-authored-by: Cory Stephenson
---
 diffusion/models/models.py       | 40 ++++++++++++++++++++------
 diffusion/models/text_encoder.py | 49 +++++++++++++++++++++++---------
 2 files changed, 66 insertions(+), 23 deletions(-)

diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 6be52f5a..4fc62aa7 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -307,6 +307,8 @@ def stable_diffusion_xl(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    cache_dir: str = '/tmp/hf_files',
+    local_files_only: bool = False,
 ):
     """Stable diffusion 2 training setup + SDXL UNet and VAE.

@@ -364,6 +366,9 @@ def stable_diffusion_xl(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        cache_dir (str): Directory to cache local files in. Default: `'/tmp/hf_files'`.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
+
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

@@ -377,10 +382,14 @@ def stable_diffusion_xl(
         val_metrics = [MeanSquaredError()]

     # Make the tokenizer and text encoder
-    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
+    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names,
+                               cache_dir=cache_dir,
+                               local_files_only=local_files_only)
     text_encoder = MultiTextEncoder(model_names=text_encoder_names,
                                     encode_latents_in_fp16=encode_latents_in_fp16,
-                                    pretrained_sdxl=pretrained)
+                                    pretrained_sdxl=pretrained,
+                                    cache_dir=cache_dir,
+                                    local_files_only=local_files_only)

     precision = torch.float16 if encode_latents_in_fp16 else None
     # Make the autoencoder
@@ -408,9 +417,15 @@ def stable_diffusion_xl(
         downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)

     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
     if pretrained:
-        unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
+        unet = UNet2DConditionModel.from_pretrained(unet_model_name,
+                                                    subfolder='unet',
+                                                    cache_dir=cache_dir,
+                                                    local_files_only=local_files_only)
         if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
@@ -612,6 +627,7 @@ def precomputed_text_latent_diffusion(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    local_files_only: bool = False,
 ):
     """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.

@@ -662,6 +678,7 @@ def precomputed_text_latent_diffusion(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

@@ -695,7 +712,10 @@ def precomputed_text_latent_diffusion(
         downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)

     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]

     if isinstance(vae, AutoEncoder):
         # Adapt the unet config to account for differing number of latent channels if necessary
@@ -792,20 +812,22 @@ def precomputed_text_latent_diffusion(
     if include_text_encoders:
         dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
         dtype = dtype_map[text_encoder_dtype]
-        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
+        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl',
+                                                     cache_dir=cache_dir,
+                                                     local_files_only=local_files_only)
         clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                        subfolder='tokenizer',
                                                        cache_dir=cache_dir,
-                                                       local_files_only=False)
+                                                       local_files_only=local_files_only)
         t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
                                                torch_dtype=dtype,
                                                cache_dir=cache_dir,
-                                               local_files_only=False).encoder.eval()
+                                               local_files_only=local_files_only).encoder.eval()
         clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                      subfolder='text_encoder',
                                                      torch_dtype=dtype,
                                                      cache_dir=cache_dir,
-                                                     local_files_only=False).cuda().eval()
+                                                     local_files_only=local_files_only).cuda().eval()

     # Make the composer model
     model = PrecomputedTextLatentDiffusion(
         unet=unet,
diff --git a/diffusion/models/text_encoder.py b/diffusion/models/text_encoder.py
index df78e7d3..fe8ff42e 100644
--- a/diffusion/models/text_encoder.py
+++ b/diffusion/models/text_encoder.py
@@ -25,13 +25,13 @@ class MultiTextEncoder(torch.nn.Module):
             the projected output from a CLIPTextModelWithProjection. Default: ``False``.
     """

-    def __init__(
-        self,
-        model_names: Union[str, Tuple[str, ...]],
-        model_dim_keys: Optional[Union[str, List[str]]] = None,
-        encode_latents_in_fp16: bool = True,
-        pretrained_sdxl: bool = False,
-    ):
+    def __init__(self,
+                 model_names: Union[str, Tuple[str, ...]],
+                 model_dim_keys: Optional[Union[str, List[str]]] = None,
+                 encode_latents_in_fp16: bool = True,
+                 pretrained_sdxl: bool = False,
+                 cache_dir: str = '/tmp/hf_files',
+                 local_files_only: bool = False):
         super().__init__()
         self.pretrained_sdxl = pretrained_sdxl

@@ -50,7 +50,10 @@ def __init__(
             name_split = model_name.split('/')
             base_name = '/'.join(name_split[:2])
             subfolder = '/'.join(name_split[2:])
-            text_encoder_config = PretrainedConfig.get_config_dict(base_name, subfolder=subfolder)[0]
+            text_encoder_config = PretrainedConfig.get_config_dict(base_name,
+                                                                   subfolder=subfolder,
+                                                                   cache_dir=cache_dir,
+                                                                   local_files_only=local_files_only)[0]

             # Add text_encoder output dim to total dim
             dim_found = False
@@ -70,14 +73,25 @@ def __init__(
             architectures = text_encoder_config['architectures']
             if architectures == ['CLIPTextModel']:
                 self.text_encoders.append(
-                    CLIPTextModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    CLIPTextModel.from_pretrained(base_name,
+                                                  subfolder=subfolder,
+                                                  torch_dtype=torch_dtype,
+                                                  cache_dir=cache_dir,
+                                                  local_files_only=local_files_only))
             elif architectures == ['CLIPTextModelWithProjection']:
                 self.text_encoders.append(
-                    CLIPTextModelWithProjection.from_pretrained(base_name, subfolder=subfolder,
-                                                                torch_dtype=torch_dtype))
+                    CLIPTextModelWithProjection.from_pretrained(base_name,
+                                                                subfolder=subfolder,
+                                                                torch_dtype=torch_dtype,
+                                                                cache_dir=cache_dir,
+                                                                local_files_only=local_files_only))
             else:
                 self.text_encoders.append(
-                    AutoModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    AutoModel.from_pretrained(base_name,
+                                              subfolder=subfolder,
+                                              torch_dtype=torch_dtype,
+                                              cache_dir=cache_dir,
+                                              local_files_only=local_files_only))
             self.architectures += architectures

     @property
@@ -125,7 +139,10 @@ class MultiTokenizer:
             "org_name/repo_name/subfolder" where the subfolder is excluded if it is not used in the repo.
     """

-    def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
+    def __init__(self,
+                 tokenizer_names_or_paths: Union[str, Tuple[str, ...]],
+                 cache_dir: str = '/tmp/hf_files',
+                 local_files_only: bool = False):
         if isinstance(tokenizer_names_or_paths, str):
             tokenizer_names_or_paths = (tokenizer_names_or_paths,)

@@ -134,7 +151,11 @@ def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
             path_split = tokenizer_name_or_path.split('/')
             base_name = '/'.join(path_split[:2])
             subfolder = '/'.join(path_split[2:])
-            self.tokenizers.append(AutoTokenizer.from_pretrained(base_name, subfolder=subfolder))
+            self.tokenizers.append(
+                AutoTokenizer.from_pretrained(base_name,
+                                              subfolder=subfolder,
+                                              cache_dir=cache_dir,
+                                              local_files_only=local_files_only))

         self.model_max_length = min([t.model_max_length for t in self.tokenizers])