Support lazy model init #60

Draft · wants to merge 5 commits into main
2 changes: 1 addition & 1 deletion accessory/model/LLM/llama.py
@@ -318,7 +318,7 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.cache_image_words = 0 # for inference
if with_visual:
print("build llama model with clip")
with default_tensor_type(dtype=torch.half):
with default_tensor_type(dtype=torch.half, is_meta=False):
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
for name, param in self.clip.named_parameters():
param.requires_grad = False
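A note on the change above (and the analogous ones below): the new `is_meta` flag suggests that `default_tensor_type` can create parameters on PyTorch's meta device for lazy initialization, and that submodules which load pretrained weights (CLIP, ConvNeXt, DINOv2, BLIP-2) must opt out so their tensors are materialized immediately. The following is only a hypothetical sketch of such a context manager; the real implementation lives in accessory/util/tensor_type.py and is not part of this diff.

import contextlib

import torch


@contextlib.contextmanager
def default_tensor_type(dtype=None, is_meta=False):
    # Hypothetical sketch, not the repository's implementation.
    # Temporarily change the default dtype and, when is_meta=True, create new
    # tensors on the meta device so that parameters get no real storage.
    old_dtype = torch.get_default_dtype()
    if dtype is not None:
        torch.set_default_dtype(dtype)
    try:
        if is_meta:
            # torch.device works as a context manager in PyTorch >= 2.0 and
            # sets the default device for tensors created inside the block.
            with torch.device("meta"):
                yield
        else:
            yield
    finally:
        torch.set_default_dtype(old_dtype)

Under that assumption, `with default_tensor_type(dtype=torch.half, is_meta=False):` keeps the CLIP weights in real fp16 storage even when the surrounding LLaMA parameters are created lazily on the meta device.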
2 changes: 1 addition & 1 deletion accessory/model/LLM/llama_adapter.py
@@ -315,7 +315,7 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.image_words = 0
if with_visual:
print("build llama model with clip")
with default_tensor_type(dtype=torch.half):
with default_tensor_type(dtype=torch.half, is_meta=False):
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
for name, param in self.clip.named_parameters():
param.requires_grad = False
39 changes: 20 additions & 19 deletions accessory/model/LLM/llama_ens.py
@@ -30,7 +30,7 @@
default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))

from .llama import precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb, repeat_kv

from util.tensor_type import default_tensor_type

@dataclass
class ModelArgs:
@@ -286,24 +286,25 @@ def __init__(self, params: ModelArgs, with_visual=False):
nn.LayerNorm(params.dim)
)

print("build llama model with clip")
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
self.clip.transformer = None

print("build llama model with openclip")
self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms(
"convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup"
)
self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk
self.openclip_convnext_xxl.head.global_pool = nn.Identity()
self.openclip_convnext_xxl.head.flatten = nn.Identity()

print("build llama model with dinov2")
import os.path
if os.path.exists("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main"):
self.dinov2_vitg14 = torch.hub.load("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main", "dinov2_vitg14", source="local")
else:
self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
with default_tensor_type(is_meta=False):
print("build llama model with clip")
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
self.clip.transformer = None

print("build llama model with openclip")
self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms(
"convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup"
)
self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk
self.openclip_convnext_xxl.head.global_pool = nn.Identity()
self.openclip_convnext_xxl.head.flatten = nn.Identity()

print("build llama model with dinov2")
import os.path
if os.path.exists("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main"):
self.dinov2_vitg14 = torch.hub.load("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main", "dinov2_vitg14", source="local")
else:
self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")

self.visual_proj = nn.Sequential(
nn.Linear(3072 + 1024 + 1536, params.dim),
2 changes: 1 addition & 1 deletion accessory/model/LLM/llama_peft.py
@@ -289,7 +289,7 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.cache_image_words = 0 # for inference
if with_visual:
print("build llama model with clip")
with default_tensor_type(dtype=torch.half):
with default_tensor_type(dtype=torch.half, is_meta=False):
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
for name, param in self.clip.named_parameters():
param.requires_grad = False
4 changes: 3 additions & 1 deletion accessory/model/LLM/llama_qformerv2.py
@@ -27,6 +27,7 @@
default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))

from .llama import precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb, repeat_kv
from util.tensor_type import default_tensor_type


@dataclass
@@ -273,7 +274,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.cache_image_words = 0 # for inference
if with_visual:
print("build llama model with qformerv2")
self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
with default_tensor_type(is_meta=False):
self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)

self.qformer.language_projection = None
self.qformer.language_model = None
5 changes: 3 additions & 2 deletions accessory/model/LLM/llama_qformerv2_peft.py
@@ -28,7 +28,7 @@
default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))

from .llama import precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb, repeat_kv

from util.tensor_type import default_tensor_type

@dataclass
class ModelArgs:
@@ -288,7 +288,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.cache_image_words = 0 # for inference
if with_visual:
print("build llama model with qformerv2")
self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
with default_tensor_type(is_meta=False):
self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)

self.qformer.language_projection = None
self.qformer.language_model = None
7 changes: 7 additions & 0 deletions accessory/model/layers/__init__.py
@@ -0,0 +1,7 @@
from .linear import Linear
from .tensor_parallel import (
ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
)

__all__ = ["Linear", "ColumnParallelLinear", "RowParallelLinear",
"ParallelEmbedding"]
91 changes: 91 additions & 0 deletions accessory/model/layers/linear.py
@@ -0,0 +1,91 @@
import functools
import math
from typing import Any, Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_default_linear_weight_init_fn():
return functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))


def get_default_linear_bias_init_fn(fan_in):
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
return functools.partial(nn.init.uniform_, a=-bound, b=bound)


class Linear(nn.Module):
r"""A Linear module mostly compatible with the PyTorch v2.0.1 builtin one,
with additional initializer args ``*_init_method``. This is to support
deferred custom weight initialization: We expect that parameters are
materialized and set by calling ``Module.reset_parameters()`` in deferred
initialization, but the ``reset_parameters`` of the builtin ``Linear``
layer always uses default initialization, making custom initialization
(e.g., ``xavier_uniform`` or zero initialization) impossible. We
reimplement a Linear module whose ``reset_parameter`` method respects the
initializers passed in by the user.

Args:
in_features (int): Input feature dimension.
out_features (int): Output feature dimension.
        bias (bool): Whether a learnable bias is added. Default is ``True``.
weight_init_fn (Callable[[torch.Tensor], Any], optional): Initializer
function of the ``weight`` parameter. If not set, follows the
default initialization of the builtin ``nn.Linear``.
bias_init_fn (Callable[[torch.Tensor], Any], optional): Initializer
function of the ``bias`` parameter. If not set, follows the default
initialization of the builtin ``nn.Linear``.
device: The device to be passed into the factory function when creating
the parameter tensors.
dtype: The dtype to be passed into the factory function when creating
the parameter tensors.
"""

def __init__(
self, in_features: int, out_features: int, bias: bool = True,
weight_init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
bias_init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
device=None, dtype=None,
) -> None:
super().__init__()

factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self.weight_init_fn = weight_init_fn
self.bias_init_fn = bias_init_fn

self.weight = nn.Parameter(
torch.empty([out_features, in_features], **factory_kwargs)
)
if bias:
self.bias = nn.Parameter(
torch.empty([out_features], **factory_kwargs)
)
else:
self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
if not self.weight.is_meta:
weight_init_fn = (
self.weight_init_fn or get_default_linear_weight_init_fn()
)
weight_init_fn(self.weight.data)
if self.bias is not None and not self.bias.is_meta:
bias_init_fn = (
self.bias_init_fn
or get_default_linear_bias_init_fn(self.in_features)
)
bias_init_fn(self.bias.data)

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight, self.bias)

def extra_repr(self) -> str:
return "in_features={}, out_features={}, bias={}".format(
self.in_features, self.out_features, self.bias is not None
)
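Because this layer exists so that `reset_parameters()` re-applies user-supplied initializers after deferred materialization, a short usage sketch may help. The shapes are made up, and it assumes PyTorch >= 2.0 (for the `torch.device` context manager) with the package root on the import path.

import torch
import torch.nn as nn

from accessory.model.layers import Linear

# Built on the meta device: no storage is allocated, and reset_parameters()
# skips initialization because the tensors are meta.
with torch.device("meta"):
    proj = Linear(4096, 1024, bias=True, weight_init_fn=nn.init.zeros_)

# Materialize later (e.g. per rank, before FSDP wrapping), then re-run the
# custom initializer, which the builtin nn.Linear.reset_parameters() would
# have ignored.
proj.to_empty(device="cpu")
proj.reset_parameters()
assert torch.all(proj.weight == 0)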
4 changes: 4 additions & 0 deletions accessory/model/layers/tensor_parallel/__init__.py
@@ -0,0 +1,4 @@
from .linear import ColumnParallelLinear, RowParallelLinear
from .embedding import ParallelEmbedding

__all__ = ["ColumnParallelLinear", "RowParallelLinear", "ParallelEmbedding"]
128 changes: 128 additions & 0 deletions accessory/model/layers/tensor_parallel/embedding.py
@@ -0,0 +1,128 @@
from typing import Any, Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn.model_parallel.initialize import (
get_model_parallel_world_size,
)
from fairscale.nn.model_parallel.mappings import (
copy_to_model_parallel_region,
gather_from_model_parallel_region,
)
from .utils import init_tensor_parallel_weights


class ParallelEmbedding(nn.Module):
r"""A tensor-parallel embedding layer. The output feature dimensions are
divided among the tensor parallel ranks. Each part of the embeddings is
calculated separately on each rank and gathered to form the complete
embeddings.

Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If specified, the entries at
:attr:`padding_idx` do not contribute to the gradient; therefore,
the embedding vector at :attr:`padding_idx` is not updated during
training, i.e. it remains as a fixed "pad". For a newly
constructed Embedding, the embedding vector at :attr:`padding_idx`
will default to all zeros, but can be updated to another value to
be used as the padding vector.
scale_grad_by_freq (bool, optional): If given, this will scale
gradients by the inverse of frequency of the words in the
mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight`
matrix will be a sparse tensor. See Notes for more details
regarding sparse gradients.
        init_fn (Callable[[torch.Tensor], Any], optional): Initializer function
            of the ``weight`` parameter. If set to ``None``, follows the default
            initialization of the PyTorch builtin ``torch.nn.Embedding`` layer.

Attributes:
weight (Tensor): the learnable weights of the module of shape
(num_embeddings, embedding_dim) initialized from
:math:`\mathcal{N}(0, 1)`

Shape:
- Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape
containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and
:math:`H=\text{embedding\_dim}`

.. note::
Keep in mind that only a limited number of optimizers support
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad`
(`CPU`)

.. note::
The default initialization of the ``weight`` parameter is different in
PyTorch and fairscale: The former uses ``torch.nn.init.normal_`` while
        the latter uses ``torch.nn.init.xavier_normal_``. We follow the PyTorch
default behavior.
"""

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
scale_grad_by_freq: bool = False,
sparse: bool = False,
init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
) -> None:
super().__init__()
        self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self.init_fn = init_fn

tp_world_size = get_model_parallel_world_size()
        assert self.embedding_dim % tp_world_size == 0, (
"ParallelEmbedding currently requires that the embedding "
"dimension is evenly divisible by the tensor parallel world size."
)
        self.local_embedding_dim = embedding_dim // tp_world_size

self.weight = nn.Parameter(
            torch.empty([num_embeddings, self.local_embedding_dim])
)

self.reset_parameters()

def reset_parameters(self) -> None:
init_fn = self.init_fn or nn.init.normal_
init_tensor_parallel_weights(self.weight, init_fn, 1)
self._fill_padding_idx_with_zero()

def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)

def forward(self, input_: torch.Tensor) -> torch.Tensor:
input_parallel = copy_to_model_parallel_region(input_)
output_parallel = F.embedding(
input_parallel,
self.weight,
self.padding_idx,
None, 2.0, # max_norm and norm_type, non-trivial to impl for tp.
self.scale_grad_by_freq,
self.sparse,
)
output = gather_from_model_parallel_region(output_parallel)
return output

def extra_repr(self) -> str:
s = "{num_embeddings}, {embedding_dim}"
if self.padding_idx is not None:
s += ", padding_idx={padding_idx}"
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
if self.sparse is not False:
s += ', sparse=True'
return s.format(**self.__dict__)
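To make the split-then-gather scheme concrete, here is a stand-alone illustration with made-up sizes; it simulates the per-rank shards with plain tensor ops and involves no fairscale process group.

import torch
import torch.nn.functional as F

vocab, dim, tp_world_size = 32000, 4096, 4
local_dim = dim // tp_world_size                    # 1024 columns per rank

full_weight = torch.randn(vocab, dim)
shards = full_weight.split(local_dim, dim=1)        # one [vocab, 1024] slice per rank

tokens = torch.randint(0, vocab, (2, 16))           # (batch, seq)
partial = [F.embedding(tokens, w) for w in shards]  # each (2, 16, 1024)
gathered = torch.cat(partial, dim=-1)               # (2, 16, 4096)

# Gathering the per-rank outputs along the feature dimension reproduces the
# result of a single full-width embedding lookup.
assert torch.equal(gathered, F.embedding(tokens, full_weight))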