(fake*) FP8 training support (#184)
* remove mixed_precision

* update

* make style

* update

* better defaults for experimenting

* fix train continuation after validation error

* update READMEs

* remove granularity

* update hook implementation to latest diffusers

* update

* update

* remove unused patches

* remove mixed precision in tests

* add changes lost in merge conflict resolution

* update README date
a-r-r-o-w authored Jan 14, 2025
1 parent f5f9cc0 commit d220aac
Showing 19 changed files with 513 additions and 99 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -12,6 +12,7 @@ FineTrainers is a work-in-progress library to support (accessible) training of v

## News

- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions.
- 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to @ArEnSc for taking up the initiative!
- 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
@@ -83,7 +84,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--train_steps 3000 \
--rank 128 \
@@ -140,14 +140,14 @@ For inference, refer [here](./docs/training/ltx_video.md#inference). For docs re

| **Model Name** | **Tasks** | **Min. LoRA VRAM<sup>*</sup>** | **Min. Full Finetuning VRAM<sup>^</sup>** |
|:------------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 11 GB | 21 GB |
| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 42 GB | OOM |
| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 21 GB | 53 GB |
| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 5 GB | 21 GB |
| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |

</div>

<sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using fp8 weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
<sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using bf16 weights & gradient checkpointing.</sub>
<sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
<sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing.</sub>

If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).

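As a rough, back-of-envelope check on the FP8 numbers above (illustrative only: the ~13B parameter count is an assumed figure for a large video transformer, and activations, gradients and optimizer states are ignored), halving per-parameter storage from 2 bytes (BF16) to 1 byte (FP8) saves roughly 10-12 GB of weight memory, in line with the drop in the LoRA column:

```python
# Back-of-envelope weight-memory estimate (illustrative; parameter count is assumed).
num_params = 13e9  # assumed ~13B-parameter video transformer
gib = 2**30

print(f"BF16 weights: {2 * num_params / gib:.1f} GiB")  # ~24.2 GiB
print(f"FP8  weights: {1 * num_params / gib:.1f} GiB")  # ~12.1 GiB
```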
2 changes: 1 addition & 1 deletion accelerate_configs/compiled_1.yaml
@@ -11,7 +11,7 @@ enable_cpu_affinity: false
gpu_ids: '3'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
2 changes: 1 addition & 1 deletion accelerate_configs/uncompiled_1.yaml
@@ -6,7 +6,7 @@ enable_cpu_affinity: false
gpu_ids: '3'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
7 changes: 6 additions & 1 deletion docs/training/cogvideox.md
@@ -37,7 +37,6 @@ dataloader_cmd="--dataloader_num_workers 4"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--precompute_conditions \
--train_steps 1000 \
@@ -88,6 +87,12 @@ echo -ne "-------------------- Finished executing script --------------------\n\

### LoRA

<!-- TODO(aryan): Update these numbers for 49x512x768 -->

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x480x720` resolutions, **with precomputation**:

```
5 changes: 4 additions & 1 deletion docs/training/hunyuan_video.md
@@ -42,7 +42,6 @@ diffusion_cmd=""
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--train_steps 500 \
--rank 128 \
@@ -91,6 +90,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\

### LoRA

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:

```
5 changes: 4 additions & 1 deletion docs/training/ltx_video.md
@@ -41,7 +41,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--train_steps 3000 \
--rank 128 \
@@ -90,6 +89,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\

### LoRA

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:

```
11 changes: 7 additions & 4 deletions docs/training/optimization.md
@@ -1,9 +1,12 @@
# Memory optimizations

To lower memory requirements during training:

- `--precompute_conditions`: this precomputes the conditions and latents, and loads them as required during training, which saves a significant amount of time and memory.
- `--gradient_checkpointing`: this saves memory by recomputing activations during the backward pass.
- `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`. This halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`). A conceptual sketch follows this file's diff.
- `--use_8bit_bnb`: this is only applicable to Adam and AdamW optimizers, and makes use of 8-bit precision to store optimizer states.
- Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example).
- Pass `--precompute_conditions` when launching training.
- Pass `--gradient_checkpointing` when launching training.
- Pass `--use_8bit_bnb` when launching training. Note that this is only applicable to Adam and AdamW optimizers.
- Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.

We will continue to add more features that help to reduce memory consumption.
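To make the `--layerwise_upcasting_modules` bullet above concrete, here is a minimal, hedged sketch of the general idea in plain PyTorch: weights are stored in FP8, and each leaf module is upcast to the compute dtype just before its forward pass and downcast again afterwards. This is not the actual hook implementation used by finetrainers/diffusers, and it assumes the affected weights are frozen (as in LoRA training), since FP8 here is storage-only:

```python
import torch
import torch.nn as nn


def naive_layerwise_upcasting(
    module: nn.Module,
    storage_dtype: torch.dtype = torch.float8_e4m3fn,
    compute_dtype: torch.dtype = torch.bfloat16,
) -> None:
    """Store weights in FP8, but run each leaf module's forward in `compute_dtype`."""
    for submodule in module.modules():
        is_leaf = next(submodule.children(), None) is None
        has_params = next(submodule.parameters(recurse=False), None) is not None
        if not (is_leaf and has_params):
            continue
        submodule.to(storage_dtype)  # FP8 storage (weights assumed frozen)

        def upcast(mod, args):
            mod.to(compute_dtype)  # materialize compute-precision weights for this forward

        def downcast(mod, args, output):
            mod.to(storage_dtype)  # return to FP8 storage after the forward

        submodule.register_forward_pre_hook(upcast)
        submodule.register_forward_hook(downcast)


# Usage sketch: a toy stand-in for a transformer whose weights live in FP8 between forwards.
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).requires_grad_(False)
naive_layerwise_upcasting(model)
with torch.no_grad():
    out = model(torch.randn(2, 64, dtype=torch.bfloat16))
print(out.dtype, model[0].weight.dtype)  # torch.bfloat16 torch.float8_e4m3fn
```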
76 changes: 58 additions & 18 deletions finetrainers/args.py
@@ -43,6 +43,14 @@ class Args:
Data type for the transformer model.
vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type for the VAE model.
layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer'].
layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when cast to fp8 precision
naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
by default, and recommend adding more layers to the default list based on the model architecture.
DATASET ARGUMENTS
-----------------
@@ -126,8 +134,6 @@ class Args:
Type of training to perform. Choose between ['lora'].
seed (`int`, defaults to `42`):
A seed for reproducible training.
mixed_precision (`str`, defaults to `None`):
Whether to use mixed precision. Choose between ['no', 'fp8', 'fp16', 'bf16'].
batch_size (`int`, defaults to `1`):
Per-device batch size.
train_epochs (`int`, defaults to `1`):
@@ -243,6 +249,18 @@
text_encoder_3_dtype: torch.dtype = torch.bfloat16
transformer_dtype: torch.dtype = torch.bfloat16
vae_dtype: torch.dtype = torch.bfloat16
layerwise_upcasting_modules: List[str] = []
layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
layerwise_upcasting_skip_modules_pattern: List[str] = [
"patch_embed",
"pos_embed",
"x_embedder",
"context_embedder",
"time_embed",
"^proj_in$",
"^proj_out$",
"norm",
]

# Dataset arguments
data_root: str = None
@@ -277,9 +295,6 @@
# Training arguments
training_type: str = None
seed: int = 42
mixed_precision: str = (
None # TODO: consider removing later https://github.com/a-r-r-o-w/finetrainers/pull/139#discussion_r1897438414
)
batch_size: int = 1
train_epochs: int = 1
train_steps: int = None
@@ -347,6 +362,9 @@ def to_dict(self) -> Dict[str, Any]:
"text_encoder_3_dtype": self.text_encoder_3_dtype,
"transformer_dtype": self.transformer_dtype,
"vae_dtype": self.vae_dtype,
"layerwise_upcasting_modules": self.layerwise_upcasting_modules,
"layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
"layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
},
"dataset_arguments": {
"data_root": self.data_root,
@@ -381,7 +399,6 @@ def to_dict(self) -> Dict[str, Any]:
"training_arguments": {
"training_type": self.training_type,
"seed": self.seed,
"mixed_precision": self.mixed_precision,
"batch_size": self.batch_size,
"train_epochs": self.train_epochs,
"train_steps": self.train_steps,
@@ -464,6 +481,7 @@ def parse_arguments() -> Args:


def validate_args(args: Args):
_validated_model_args(args)
_validate_training_args(args)
_validate_validation_args(args)

@@ -506,6 +524,28 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
parser.add_argument(
"--layerwise_upcasting_modules",
type=str,
default=[],
nargs="+",
choices=["transformer"],
help="Modules that should have fp8 storage weights but higher precision computation.",
)
parser.add_argument(
"--layerwise_upcasting_storage_dtype",
type=str,
default="float8_e4m3fn",
choices=["float8_e4m3fn", "float8_e5m2"],
help="Data type for the layerwise upcasting storage.",
)
parser.add_argument(
"--layerwise_upcasting_skip_modules_pattern",
type=str,
default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
nargs="+",
help="Modules to skip for layerwise upcasting.",
)


def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
@@ -688,16 +728,6 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
help="Type of training to perform. Choose between ['lora', 'full-finetune']",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp8", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Defaults to the value of accelerate config of the current system or the "
"flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--batch_size",
type=int,
@@ -979,8 +1009,9 @@ def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
"float8_e4m3fn": torch.float8_e4m3fn,
"float8_e5m2": torch.float8_e5m2,
}
_INVERSE_DTYPE_MAP = {v: k for k, v in _DTYPE_MAP.items()}


def _map_to_args_type(args: Dict[str, Any]) -> Args:
Expand All @@ -997,6 +1028,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern

# Dataset arguments
if args.data_root is None and args.dataset_file is None:
@@ -1034,7 +1068,6 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
# Training arguments
result_args.training_type = args.training_type
result_args.seed = args.seed
result_args.mixed_precision = args.mixed_precision
result_args.batch_size = args.batch_size
result_args.train_epochs = args.train_epochs
result_args.train_steps = args.train_steps
@@ -1117,6 +1150,13 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
return result_args


def _validated_model_args(args: Args):
if args.training_type == "full-finetune":
assert (
"transformer" not in args.layerwise_upcasting_modules
), "Layerwise upcasting is not supported for full-finetune training"


def _validate_training_args(args: Args):
if args.training_type == "lora":
assert args.rank is not None, "Rank is required for LoRA training"
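The default for `--layerwise_upcasting_skip_modules_pattern` above mixes plain substrings (e.g. `norm`) with anchored regexes (e.g. `^proj_in$`). The snippet below only illustrates the matching semantics this implies; treating each entry as a regular expression applied to the module name via `re.search` is an assumption, not a statement about the exact rule finetrainers uses:

```python
import re

# Assumed rule: a module is skipped if any pattern matches its (qualified) name via re.search.
skip_patterns = ["patch_embed", "pos_embed", "^proj_in$", "^proj_out$", "norm"]


def should_skip(module_name: str) -> bool:
    return any(re.search(pattern, module_name) for pattern in skip_patterns)


print(should_skip("transformer_blocks.0.norm1"))    # True: "norm" matches as a substring
print(should_skip("proj_in"))                       # True: the anchored pattern matches the full name
print(should_skip("transformer_blocks.0.proj_in"))  # False: "^proj_in$" is anchored to the whole name
```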
8 changes: 7 additions & 1 deletion finetrainers/cogvideox/lora.py
@@ -65,6 +65,7 @@ def initialize_pipeline(
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
is_training: bool = False,
**kwargs,
) -> CogVideoXPipeline:
component_name_pairs = [
@@ -81,9 +82,14 @@

pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
pipe.vae = pipe.vae.to(dtype=vae_dtype)

# The transformer should already be in the correct dtype when training, so we don't need to cast it here.
# If we cast while the fp8 layerwise upcasting hooks are active, it will lead to an error during
# the DDP optimizer step.
if not is_training:
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)

if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
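A tiny illustration of why the training path now skips the blanket cast: once the layerwise-upcasting hooks have put the weights into FP8 storage, an unconditional `.to(transformer_dtype)` re-materializes them in BF16 and silently discards that FP8 storage, leaving the hooks (and DDP) operating on weights in an unexpected dtype. This sketch only shows the dtype round-trip, not the DDP error mentioned in the comment above:

```python
import torch
import torch.nn as nn

# A frozen layer standing in for a transformer block whose weights were placed in
# FP8 storage by layerwise upcasting.
layer = nn.Linear(64, 64, bias=False).requires_grad_(False)
layer.to(torch.float8_e4m3fn)
print(layer.weight.dtype)  # torch.float8_e4m3fn (FP8 storage in place)

# What the old code path did unconditionally in initialize_pipeline():
layer.to(torch.bfloat16)
print(layer.weight.dtype)  # torch.bfloat16 (the FP8 storage has been silently undone)
```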
1 change: 1 addition & 0 deletions finetrainers/hooks/__init__.py
@@ -0,0 +1 @@
from .layerwise_upcasting import apply_layerwise_upcasting