Support lazy model init #60

Draft · wants to merge 5 commits into main
Conversation

linziyi96 (Contributor) commented Aug 28, 2023

This PR aims to add support for lazy model initialization. This is one of two steps toward lowering the CPU memory usage of quantized models. Quantization is currently implemented by replacing regular linear layers with quantized linear layers. Without lazy init, the full-precision model that exists before the replacement causes a huge peak memory usage, making both training and inference hard to run on commodity hardware even with aggressive quantization. For example, the 4-bit 13B model, which theoretically needs only 6.5GB of memory and fits comfortably in any mainstream PC, currently requires 52GB (full-precision model plus full-precision checkpoint); and the 4-bit 70B model, which theoretically needs 35GB and fits on two 3090s, currently requires 280GB, which is only feasible on some expensive HEDT and server platforms.

With lazy init, the model creation steps become: (1) create a placeholder model without allocating any actual storage, (2) replace layers with quantized ones, and (3) instantiate all tensors. This way, we do not need to manually re-implement a quantized version of each (current or future) model, and only the post-quantization amount of storage is ever allocated.
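As a rough illustration of the three steps (QuantLinear here is a hypothetical stand-in for the real quantized layers, and the sketch assumes PyTorch >= 2.0, where torch.device works as a context manager):

import torch
import torch.nn as nn

class QuantLinear(nn.Module):
    # Hypothetical 4-bit layer: packed weights take in*out/2 bytes
    # instead of the in*out*4 bytes of a full-precision nn.Linear.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.packed = nn.Parameter(
            torch.empty(out_features * in_features // 2, dtype=torch.uint8),
            requires_grad=False,
        )

def swap_linear(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, QuantLinear(child.in_features, child.out_features))
        else:
            swap_linear(child)

with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))  # (1) placeholder, no storage
    swap_linear(model)                                                   # (2) replace while still on meta
model = model.to_empty(device="cpu")  # (3) only the quantized storage is actually allocated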

However, supporting lazy init turns out to be complicated, as PyTorch currently provides no good way to decouple model creation from weight initialization. Although tensors can be created on the meta device, there seems to be no reliable way to initialize them afterwards: the fairscale layers initialize their weights in __init__ and simply do not provide a separate method to initialize them after creation; and even though most PyTorch built-in layers do provide reset_parameters methods as of v2.0.1, these usually do not support custom initialization (e.g., LoRA needs zero init, but torch.nn.Linear.reset_parameters always initializes the weights randomly from a uniform distribution).
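To make the limitation concrete, a small demonstration (again assuming PyTorch >= 2.0): after materializing a meta module, reset_parameters offers no way to request anything other than its built-in random init.

import torch
import torch.nn as nn

with torch.device("meta"):
    # A LoRA "B" projection that should start at zero.
    lora_b = nn.Linear(16, 4096, bias=False)

lora_b = lora_b.to_empty(device="cpu")  # allocates uninitialized storage
lora_b.reset_parameters()               # always Kaiming-uniform; zeros cannot be requested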

Facing this dilemma, I am trying to follow the lazy-init implementation of PyTorch FSDP: rely on a reset_parameters method on each module that directly manages parameters or buffers, with the heavy lifting being to implement reset_parameters for every module we use that does not already have a working one.
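One possible shape for such a layer (a sketch, not the final code in this PR): stash the desired initializer at construction time so that reset_parameters can replay it after materialization.

import torch
import torch.nn as nn

class Linear(nn.Linear):
    # nn.Linear with a customizable, replayable reset_parameters.

    def __init__(self, *args, init_fn=None, **kwargs):
        # Record the initializer before __init__ runs (it calls reset_parameters).
        self.init_fn = init_fn
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.weight.is_meta:
            return  # nothing to initialize yet; materialization happens later
        if self.init_fn is not None:
            self.init_fn(self)          # e.g., zero init for a LoRA branch
        else:
            super().reset_parameters()  # default Kaiming-uniform init

# Usage: a LoRA "B" projection that must start at zero.
lora_b = Linear(16, 4096, bias=False, init_fn=lambda m: nn.init.zeros_(m.weight))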

The model creation process is supposed to be like the following after the change:

# All weights on the meta device, including quantized layers but excluding the vision encoder.
# Quantized layer replacement happens inside MetaModel.
with default_tensor_type(..., meta=True):
    model = MetaModel(...)

# All tensors found in the checkpoints are materialized. If a quantized layer sees full-precision states, quantize them before materializing.
utils.tensor_parallel.load_tensor_parallel_model_list(...)

# Materialize the remaining weights (those unseen in any loaded checkpoint) using their reset_parameters methods.
model.materialize()
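A hedged sketch of what materialize() could do (the names here are illustrative, not the final API): walk the module tree, allocate real storage for anything still on meta, and replay each affected module's reset_parameters. Partially loaded modules would be fully re-initialized in this simplification; the real logic needs to be more careful.

import torch
import torch.nn as nn

def materialize(model: nn.Module, device: str = "cpu") -> None:
    # Materialize any tensors still on the meta device after checkpoint loading.
    for module in model.modules():
        needs_init = False
        for name, param in list(module.named_parameters(recurse=False)):
            if param.is_meta:
                module.register_parameter(
                    name,
                    nn.Parameter(torch.empty_like(param, device=device),
                                 requires_grad=param.requires_grad),
                )
                needs_init = True
        for name, buf in list(module.named_buffers(recurse=False)):
            if buf.is_meta:
                module.register_buffer(name, torch.empty_like(buf, device=device))
                needs_init = True
        # Storage is allocated but uninitialized; replay the module's own init.
        if needs_init and hasattr(module, "reset_parameters"):
            module.reset_parameters()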

Following this plan, the proposed code change is roughly organized into the following parts:

  • Equip each layer with a correct and flexible (i.e., supporting custom initialization) reset_parameters method.
  • Extend the default_tensor_type context manager to support meta tensor creation (see the sketch after this list). Disable meta tensor creation around visual backbones, whose weights are loaded immediately (this may be problematic for large vision models, which we can discuss later).
  • Make the quantization layers support meta creation and quantized materialization (i.e., quantizing from full-precision weights).
  • Implement the materialization logic for full-precision tensors (materialize from checkpoints first; if a tensor is not found in any checkpoint, fall back to reset_parameters).
  • Change the defined models to use the layers with reset_parameters implemented, and change the training / inference entry scripts to use the new model creation logic.
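For the default_tensor_type item above, one way the meta switch could look (a sketch assuming PyTorch >= 2.0, where torch.device works as a context manager and redirects factory functions; the real default_tensor_type signature in this repo may differ):

import contextlib
import torch

@contextlib.contextmanager
def default_tensor_type(dtype=torch.float32, device="cuda", meta=False):
    # Route all factory calls (torch.empty, torch.zeros, ...) either to the
    # requested device or, if meta=True, to the meta device.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device("meta" if meta else device):
            yield
    finally:
        torch.set_default_dtype(old_dtype)

# Usage: every tensor created inside the block lives on the meta device.
with default_tensor_type(meta=True):
    layer = torch.nn.Linear(8, 8)
assert layer.weight.is_meta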

This PR involves an extensive code refactor and needs thorough testing, so it is marked as a draft for now.
