From 1dd80cf8cc29a572f997e2eadb2d96ccefb44182 Mon Sep 17 00:00:00 2001 From: Ziyi Lin Date: Tue, 29 Aug 2023 02:09:21 +0800 Subject: [PATCH 1/5] layer with reset_parameter implemented. --- accessory/model/layers/__init__.py | 7 + accessory/model/layers/linear.py | 91 ++++++++ .../model/layers/tensor_parallel/__init__.py | 4 + .../model/layers/tensor_parallel/embedding.py | 128 +++++++++++ .../model/layers/tensor_parallel/linear.py | 207 ++++++++++++++++++ .../model/layers/tensor_parallel/utils.py | 81 +++++++ 6 files changed, 518 insertions(+) create mode 100644 accessory/model/layers/__init__.py create mode 100644 accessory/model/layers/linear.py create mode 100644 accessory/model/layers/tensor_parallel/__init__.py create mode 100644 accessory/model/layers/tensor_parallel/embedding.py create mode 100644 accessory/model/layers/tensor_parallel/linear.py create mode 100644 accessory/model/layers/tensor_parallel/utils.py diff --git a/accessory/model/layers/__init__.py b/accessory/model/layers/__init__.py new file mode 100644 index 00000000..2804e6a2 --- /dev/null +++ b/accessory/model/layers/__init__.py @@ -0,0 +1,7 @@ +from .linear import Linear +from .tensor_parallel import ( + ColumnParallelLinear, RowParallelLinear, ParallelEmbedding, +) + +__all__ = ["Linear", "ColumnParallelLinear", "RowParallelLinear", + "ParallelEmbedding"] diff --git a/accessory/model/layers/linear.py b/accessory/model/layers/linear.py new file mode 100644 index 00000000..82354c13 --- /dev/null +++ b/accessory/model/layers/linear.py @@ -0,0 +1,91 @@ +import functools +import math +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_default_linear_weight_init_fn(): + return functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) + + +def get_default_linear_bias_init_fn(fan_in): + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + return functools.partial(nn.init.uniform_, a=-bound, b=bound) + + +class Linear(nn.Module): + r"""A Linear module mostly compatible with the PyTorch v2.0.1 builtin one, + with additional initializer args ``*_init_method``. This is to support + deferred custom weight initialization: We expect that parameters are + materialized and set by calling ``Module.reset_parameters()`` in deferred + initialization, but the ``reset_parameters`` of the builtin ``Linear`` + layer always uses default initialization, making custom initialization + (e.g., ``xavier_uniform`` or zero initialization) impossible. We + reimplement a Linear module whose ``reset_parameter`` method respects the + initializers passed in by the user. + + Args: + in_features (int): Input feature dimension. + out_features (int): Output feature dimension. + bias (bool): Whether a learnable bias is added. Default is ``False``. + weight_init_fn (Callable[[torch.Tensor], Any], optional): Initializer + function of the ``weight`` parameter. If not set, follows the + default initialization of the builtin ``nn.Linear``. + bias_init_fn (Callable[[torch.Tensor], Any], optional): Initializer + function of the ``bias`` parameter. If not set, follows the default + initialization of the builtin ``nn.Linear``. + device: The device to be passed into the factory function when creating + the parameter tensors. + dtype: The dtype to be passed into the factory function when creating + the parameter tensors. 
+    """
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = True,
+        weight_init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
+        bias_init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
+        device=None, dtype=None,
+    ) -> None:
+        super().__init__()
+
+        factory_kwargs = {"device": device, "dtype": dtype}
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight_init_fn = weight_init_fn
+        self.bias_init_fn = bias_init_fn
+
+        self.weight = nn.Parameter(
+            torch.empty([out_features, in_features], **factory_kwargs)
+        )
+        if bias:
+            self.bias = nn.Parameter(
+                torch.empty([out_features], **factory_kwargs)
+            )
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if not self.weight.is_meta:
+            weight_init_fn = (
+                self.weight_init_fn or get_default_linear_weight_init_fn()
+            )
+            weight_init_fn(self.weight.data)
+        if self.bias is not None and not self.bias.is_meta:
+            bias_init_fn = (
+                self.bias_init_fn
+                or get_default_linear_bias_init_fn(self.in_features)
+            )
+            bias_init_fn(self.bias.data)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return F.linear(input, self.weight, self.bias)
+
+    def extra_repr(self) -> str:
+        return "in_features={}, out_features={}, bias={}".format(
+            self.in_features, self.out_features, self.bias is not None
+        )
diff --git a/accessory/model/layers/tensor_parallel/__init__.py b/accessory/model/layers/tensor_parallel/__init__.py
new file mode 100644
index 00000000..101d7fbc
--- /dev/null
+++ b/accessory/model/layers/tensor_parallel/__init__.py
@@ -0,0 +1,4 @@
+from .linear import ColumnParallelLinear, RowParallelLinear
+from .embedding import ParallelEmbedding
+
+__all__ = ["ColumnParallelLinear", "RowParallelLinear", "ParallelEmbedding"]
diff --git a/accessory/model/layers/tensor_parallel/embedding.py b/accessory/model/layers/tensor_parallel/embedding.py
new file mode 100644
index 00000000..8cfdfa85
--- /dev/null
+++ b/accessory/model/layers/tensor_parallel/embedding.py
@@ -0,0 +1,128 @@
+from typing import Any, Callable, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairscale.nn.model_parallel.initialize import (
+    get_model_parallel_world_size,
+)
+from fairscale.nn.model_parallel.mappings import (
+    copy_to_model_parallel_region,
+    gather_from_model_parallel_region,
+)
+from .utils import init_tensor_parallel_weights
+
+
+class ParallelEmbedding(nn.Module):
+    r"""A tensor-parallel embedding layer. The output feature dimensions are
+    divided among the tensor parallel ranks. Each part of the embeddings is
+    calculated separately on each rank and gathered to form the complete
+    embeddings.
+
+    Args:
+        num_embeddings (int): size of the dictionary of embeddings
+        embedding_dim (int): the size of each embedding vector
+        padding_idx (int, optional): If specified, the entries at
+            :attr:`padding_idx` do not contribute to the gradient; therefore,
+            the embedding vector at :attr:`padding_idx` is not updated during
+            training, i.e. it remains as a fixed "pad". For a newly
+            constructed Embedding, the embedding vector at :attr:`padding_idx`
+            will default to all zeros, but can be updated to another value to
+            be used as the padding vector.
+        scale_grad_by_freq (bool, optional): If given, this will scale
+            gradients by the inverse of frequency of the words in the
+            mini-batch. Default ``False``.
+        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight`
+            matrix will be a sparse tensor. See Notes for more details
+            regarding sparse gradients.
+        init_fn (Callable[[torch.Tensor], Any], optional): Initializer function
+            of the ``weight`` parameter. If set to ``None``, follows the default
+            initialization of the PyTorch builtin ``torch.nn.Embedding`` layer.
+
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape
+            (num_embeddings, embedding_dim) initialized from
+            :math:`\mathcal{N}(0, 1)`
+
+    Shape:
+        - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape
+          containing the indices to extract
+        - Output: :math:`(*, H)`, where `*` is the input shape and
+          :math:`H=\text{embedding\_dim}`
+
+    .. note::
+        Keep in mind that only a limited number of optimizers support
+        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
+        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad`
+        (`CPU`)
+
+    .. note::
+        The default initialization of the ``weight`` parameter is different in
+        PyTorch and fairscale: The former uses ``torch.nn.init.normal_`` while
+        the latter uses ``torch.nn.init.xavier_normal_``. We follow the PyTorch
+        default behavior.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
+    ) -> None:
+        super().__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.scale_grad_by_freq = scale_grad_by_freq
+        self.sparse = sparse
+        self.init_fn = init_fn
+
+        tp_world_size = get_model_parallel_world_size()
+        assert self.embedding_dim % tp_world_size == 0, (
+            "ParallelEmbedding currently requires that the embedding "
+            "dimension is evenly divisible by the tensor parallel world size."
+        )
+        self.local_embedding_dim = embedding_dim // tp_world_size
+
+        self.weight = nn.Parameter(
+            torch.empty([num_embeddings, self.local_embedding_dim])
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        init_fn = self.init_fn or nn.init.normal_
+        init_tensor_parallel_weights(self.weight, init_fn, 1)
+        self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        input_parallel = copy_to_model_parallel_region(input_)
+        output_parallel = F.embedding(
+            input_parallel,
+            self.weight,
+            self.padding_idx,
+            None, 2.0,  # max_norm and norm_type, non-trivial to impl for tp.
+ self.scale_grad_by_freq, + self.sparse, + ) + output = gather_from_model_parallel_region(output_parallel) + return output + + def extra_repr(self) -> str: + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) diff --git a/accessory/model/layers/tensor_parallel/linear.py b/accessory/model/layers/tensor_parallel/linear.py new file mode 100644 index 00000000..b587b9f0 --- /dev/null +++ b/accessory/model/layers/tensor_parallel/linear.py @@ -0,0 +1,207 @@ +from typing import Any, Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_world_size +) +from fairscale.nn.model_parallel.mappings import ( + copy_to_model_parallel_region, + gather_from_model_parallel_region, + scatter_to_model_parallel_region, + reduce_from_model_parallel_region, +) + +from ..linear import ( + get_default_linear_weight_init_fn, + get_default_linear_bias_init_fn, +) +from .utils import init_tensor_parallel_weights + + +class ColumnParallelLinear(nn.Module): + r"""Linear layer with column-wise tensor parallelism. A column parallel + linear layer expects that the input tensor is replicated among tensor + parallel ranks, and each rank calculate a part of the output dimensions. + + Args: + in_features (int): Input feature dimension. + out_features (int): Output feaature dimension. + bias (bool): Whether a learnable bias is added. Default is ``False``. + weight_init_fn (Callable[[torch.Tensor], Any], optional): Initializer + function of the ``weight`` parameter. If not set, follows the + default initialization of the builtin ``nn.Linear``. The given + function should assume that the input tensor is unsharded and the + distribution of the tensor is taken care of by the outside logic. + bias_init_fn (Callable[[torch.Tensor], Any], optional): Initializer + function of the ``bias`` parameter. If not set, follows the default + initialization of the builtin ``nn.Linear``. The given function + should assume that the input tensor is unsharded and the + distribution of the tensor is taken care of by the outside logic. + gather_output (bool): Whether output should be all-gathered after being + calculated separately on each rank. Default is ``True``. + + .. note:: + The default initialization of the ``bias`` parameter is different in + PyTorch and fairscale: The former uses a uniform distribution while the + latter uses an all-zero constant initialization. We follow the official + PyTorch behavior. To use the fairscale behavior, pass + ``torch.nn.init.zeros_`` as the ``bias_init_fn`` argument. + """ + + def __init__( + self, in_features: int, out_features: int, bias: bool = True, + weight_init_fn: Optional[Callable[[torch.Tensor], Any]] = None, + bias_init_fn: Optional[Callable[[torch.Tensor], Any]] = None, + gather_output: bool = True, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight_init_fn = weight_init_fn + self.bias_init_fn = bias_init_fn + self.gather_output = gather_output + + tp_world_size = get_model_parallel_world_size() + assert self.out_features % tp_world_size == 0, ( + "ColumnParallelLinear currently requires that the output " + "dimension is evenly divisible by the tensor parallel world size." 
+ ) + self.local_out_features = self.out_features // tp_world_size + + self.weight = nn.Parameter( + torch.empty([self.local_out_features, in_features]) + ) + if bias: + self.bias = nn.Parameter(torch.empty([self.local_out_features])) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + weight_init_fn = ( + self.weight_init_fn or get_default_linear_weight_init_fn() + ) + init_tensor_parallel_weights(self.weight, weight_init_fn, 0) + if self.bias is not None: + bias_init_fn = ( + self.bias_init_fn + or get_default_linear_bias_init_fn(self.in_features) + ) + init_tensor_parallel_weights(self.bias, bias_init_fn, 0) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + input_parallel = copy_to_model_parallel_region(input_) + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + output = gather_from_model_parallel_region(output_parallel) + return output + else: + return output_parallel + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, " + f"out_features={self.out_features}, " + f"local_out_features={self.local_out_features}, " + f"bias={self.bias is not None}, " + f"gather_output={self.gather_output}" + ) + + +class RowParallelLinear(nn.Module): + r"""Linear layer with row-wise tensor parallelism. A row parallel linear + layer divides the input feature dimensions among the tensor parallel ranks, + calculates the linear mapping on each part of the dimensions and sum the + results to form the output. + + Args: + in_features (int): Input feature dimension. + out_features (int): Output feaature dimension. + bias (bool): Whether a learnable bias is added. Default is ``False``. + weight_init_fn (Callable[[torch.Tensor], Any], optional): Initializer + function of the ``weight`` parameter. If not set, follows the + default initialization of the builtin ``nn.Linear``. The given + function should assume that the input tensor is unsharded and the + distribution of the tensor is taken care of by the outside logic. + bias_init_fn (Callable[[torch.Tensor], Any], optional): Initializer + function of the ``bias`` parameter. If not set, follows the default + initialization of the builtin ``nn.Linear``. The given function + should assume that the input tensor is unsharded and the + distribution of the tensor is taken care of by the outside logic. + input_is_parallel (bool): If true, assumes that the input tensor is + already sharded (e.g., the output of a ColumnParallelLinear in + which ``gather_output=False``). + + .. note:: + The default initialization of the ``bias`` parameter is different in + PyTorch and fairscale: The former uses a uniform distribution while the + latter uses an all-zero constant initialization. We follow the official + PyTorch behavior. To use the fairscale behavior, pass + ``torch.nn.init.zeros_`` as the ``bias_init_fn`` argument. 
+    """
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = True,
+        weight_init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
+        bias_init_fn: Optional[Callable[[torch.Tensor], Any]] = None,
+        input_is_parallel: bool = False,
+    ) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight_init_fn = weight_init_fn
+        self.bias_init_fn = bias_init_fn
+        self.input_is_parallel = input_is_parallel
+
+        tp_world_size = get_model_parallel_world_size()
+        assert self.in_features % tp_world_size == 0, (
+            "RowParallelLinear currently requires that the input dimension "
+            "is evenly divisible by the tensor parallel world size."
+        )
+        self.local_in_features = in_features // tp_world_size
+
+        self.weight = nn.Parameter(
+            torch.empty([self.out_features, self.local_in_features])
+        )
+        if bias:
+            self.bias = nn.Parameter(torch.empty([self.out_features]))
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weight_init_fn = (
+            self.weight_init_fn or get_default_linear_weight_init_fn()
+        )
+        init_tensor_parallel_weights(self.weight, weight_init_fn, 1)
+        if self.bias is not None:
+            bias_init_fn = (
+                self.bias_init_fn
+                or get_default_linear_bias_init_fn(self.in_features)
+            )
+            init_tensor_parallel_weights(self.bias, bias_init_fn, -1)
+
+    def forward(self, input_: torch.Tensor) -> torch.Tensor:
+        input_parallel = (
+            input_ if self.input_is_parallel else
+            scatter_to_model_parallel_region(input_)
+        )
+        output_parallel = F.linear(input_parallel, self.weight)
+        output = reduce_from_model_parallel_region(output_parallel)
+        if self.bias is not None:
+            output = output + self.bias
+        return output
+
+    def extra_repr(self) -> str:
+        return (
+            f"in_features={self.in_features}, "
+            f"local_in_features={self.local_in_features}, "
+            f"out_features={self.out_features}, "
+            f"bias={self.bias is not None}, "
+            f"input_is_parallel={self.input_is_parallel}"
+        )
diff --git a/accessory/model/layers/tensor_parallel/utils.py b/accessory/model/layers/tensor_parallel/utils.py
new file mode 100644
index 00000000..c617f237
--- /dev/null
+++ b/accessory/model/layers/tensor_parallel/utils.py
@@ -0,0 +1,81 @@
+from typing import Any, Callable, Optional
+
+import torch
+import torch.distributed as dist
+
+from fairscale.nn.model_parallel.initialize import (
+    get_model_parallel_group,
+    get_model_parallel_src_rank,
+    get_model_parallel_rank,
+    get_model_parallel_world_size,
+)
+
+
+def _broadcast_replicated_tensor(tensor: torch.Tensor) -> None:
+    group = get_model_parallel_group()
+    backend = dist.get_backend(group)
+    reduction_device = "cuda" if backend == "nccl" else tensor.device
+
+    bcast_tensor = tensor.to(reduction_device)
+    dist.broadcast(bcast_tensor, get_model_parallel_src_rank(), group)
+    if bcast_tensor is not tensor:
+        tensor.copy_(bcast_tensor)
+
+
+def _scatter_distributed_tensor(
+    tensor: torch.Tensor, master_tensor: Optional[torch.Tensor], shard_dim: int
+) -> None:
+    group = get_model_parallel_group()
+    backend = dist.get_backend(group)
+    reduction_device = "cuda" if backend == "nccl" else tensor.device
+
+    if get_model_parallel_rank() == 0:
+        master_tensor = master_tensor.to(reduction_device)
+        recv_tensor = tensor.to(reduction_device)
+        chunks = list(master_tensor.split(tensor.size(shard_dim), shard_dim))
+        dist.scatter(recv_tensor, chunks, get_model_parallel_src_rank(), group)
+    else:
+        recv_tensor = tensor.to(reduction_device)
+        dist.scatter(recv_tensor, None, get_model_parallel_src_rank(), group)
+    if recv_tensor is not tensor:
+        tensor.copy_(recv_tensor)
+
+
+def init_tensor_parallel_weights(
+    tensor: torch.Tensor, init_fn: Callable[[torch.Tensor], Any],
+    shard_dim: int = -1
+) -> None:
+    r"""This is a helper function that initializes a tensor-parallel tensor
+    from a regular tensor-parallel-unaware ``init_fn``. A typical use case is
+    that ``init_fn`` may calculate the initialization statistics based on the
+    ``fan_in`` or ``fan_out`` measured with the shape of the tensor, which
+    will be incorrect if the tensor is sharded across tensor-parallel ranks.
+    Thus, this helper initializes the tensor as a whole on the source rank
+    and then distributes it across the model parallel ranks.
+
+    Args:
+        tensor (torch.Tensor): The (tensor-parallel-sharded) tensor to
+            initialize.
+        init_fn (Callable[[torch.Tensor], Any]): The tensor-parallel-unaware
+            initializer to be called on the unsharded weights.
+        shard_dim (int): The sharding dimension of the tensor. If < 0, the
+            tensor is treated as replicated. Default is -1.
+    """
+    if tensor.is_meta:
+        return
+
+    if shard_dim < 0:
+        if get_model_parallel_rank() == 0:
+            init_fn(tensor.data)
+        _broadcast_replicated_tensor(tensor.data)
+        return
+
+    if get_model_parallel_rank() == 0:
+        master_tensor_shape = list(tensor.size())
+        master_tensor_shape[shard_dim] *= get_model_parallel_world_size()
+        master_tensor = torch.empty(master_tensor_shape,
+                                    device=tensor.device, dtype=tensor.dtype)
+        init_fn(master_tensor)
+    else:
+        master_tensor = None
+    _scatter_distributed_tensor(tensor.data, master_tensor, shard_dim)

From a4d811e59da6f0500c58fdc7e4d5b2626e1ec07c Mon Sep 17 00:00:00 2001
From: Ziyi Lin
Date: Tue, 29 Aug 2023 21:33:02 +0800
Subject: [PATCH 2/5] update peft modules.

---
 accessory/model/peft.py | 259 ++++++++--------------------------------
 1 file changed, 48 insertions(+), 211 deletions(-)

diff --git a/accessory/model/peft.py b/accessory/model/peft.py
index 93002cc4..46a71427 100644
--- a/accessory/model/peft.py
+++ b/accessory/model/peft.py
@@ -1,233 +1,70 @@
-from typing import Callable, Optional
+import functools
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import Linear, Parameter, init
-from torch import Tensor
+from timm.models.layers import trunc_normal_
 
-from timm.models.layers import trunc_normal_, lecun_normal_, to_2tuple
-
-from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
-from fairscale.nn.model_parallel.mappings import (
-    copy_to_model_parallel_region,
-    gather_from_model_parallel_region,
-    reduce_from_model_parallel_region,
-    scatter_to_model_parallel_region,
-)
-from fairscale.nn.model_parallel.utils import VocabUtility, divide_and_check_no_remainder
-from fairscale.nn.model_parallel.layers import _initialize_affine_weight
-from fairscale.nn.model_parallel.layers import (
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
+from .layers import ColumnParallelLinear, RowParallelLinear, Linear
 
 
 class LoraColumnParallelLinear(ColumnParallelLinear):
-    """Linear layer with column parallelism.
-
-    The linear layer is defined as Y = XA + b. A is parallelized along
-    its second dimension as A = [A_1, ..., A_p].
+    r"""ColumnParallelLinear with LoRA. For unlisted arguments see the
+    documentation for ``ColumnParallelLinear``.
 
-    Arguments:
-        in_features: first dimension of matrix A.
-        out_features: second dimension of matrix A.
- bias: If true, add bias - gather_output: If true, call all-gether on output and make Y avaiable - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. + Args: + lora_rank (int): Bottleneck dimension in the LoRA projections. Default + to ``0``. Only supported as kwargs. """ - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - gather_output: bool = True, - init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, - stride: int = 1, - keep_master_weight_for_test: bool = False, - lora_rank=0 - ) -> None: - nn.Module.__init__(self) - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - world_size = get_model_parallel_world_size() - self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size) - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) - if bias: - self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter("bias", None) + def __init__(self, *args, **kwargs) -> None: + self.lora_rank = kwargs.pop("lora_rank", 0) + super().__init__(*args, **kwargs) - # Initialize weight. - self.master_weight = _initialize_affine_weight( - self.weight, - self.out_features, - self.in_features, - self.output_size_per_partition, - 0, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - ) - - self.lora_rank = lora_rank if self.lora_rank > 0: - # if world_size > 1: - # raise NotImplemented("Lora with model parallel with change the original behavior, not yet supported") - self.lora_a = nn.Linear(self.in_features, self.lora_rank, bias=False) - trunc_normal_(self.lora_a.weight, std=.02) - self.lora_b = ColumnParallelLinear(self.lora_rank, self.out_features, bias=False, gather_output=gather_output) - nn.init.zeros_(self.lora_b.weight) - else: - self.lora_a = None - self.lora_b = None - - def get_master_weight(self) -> torch.Tensor: - return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1) - - def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore - # Set up backprop all-reduce. - input_parallel = copy_to_model_parallel_region(input_) - # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight, self.bias) - if self.lora_a is not None: - modification = self.lora_b(self.lora_a(input_)) - else: - modification = None - - if self.gather_output: - # All-gather across the partitions. 
- output = gather_from_model_parallel_region(output_parallel) - else: - output = output_parallel - - if modification is not None: - output = output + modification + self.lora_a = Linear( + self.in_features, self.lora_rank, bias=False, + weight_init_fn=functools.partial(trunc_normal_, std=.02) + ) + self.lora_b = ColumnParallelLinear( + self.lora_rank, self.out_features, bias=False, + weight_init_fn=nn.init.zeros_, + gather_output=self.gather_output, + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + output = super().forward(input_) + if self.lora_rank > 0: + output = output + self.lora_b(self.lora_a(input_)) return output class LoraRowParallelLinear(RowParallelLinear): - """Linear layer with row parallelism. + r"""RowParallelLinear with LoRA. For unlisted arguments see the + documentation for ``RowParallelLinear``. - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - Arguments: - in_features: first dimension of matrix A. - out_features: second dimension of matrix A. - bias: If true, add bias. Note that bias is not parallelized. - input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. + Args: + lora_rank (int): Bottleneck dimension in the LoRA projections. Default + to ``0``. Only supported as kwargs. """ - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - input_is_parallel: bool = False, - init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, - stride: int = 1, - keep_master_weight_for_test: bool = False, - lora_rank = 0 - ): - nn.Module.__init__(self) - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.input_is_parallel = input_is_parallel - # Divide the weight matrix along the last dimension. - world_size = get_model_parallel_world_size() - self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size) - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition)) - if bias: - self.bias = Parameter(torch.Tensor(self.out_features)) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter("bias", None) - - # Initialize weight. 
- self.master_weight = _initialize_affine_weight( - self.weight, - self.out_features, - self.in_features, - self.input_size_per_partition, - 1, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - ) + def __init__(self, *args, **kwargs) -> None: + self.lora_rank = kwargs.pop("lora_rank", 0) + super().__init__(*args, **kwargs) - self.lora_rank = lora_rank if self.lora_rank > 0: - # if world_size > 1: - # raise NotImplemented("Lora with model parallel with change the original behavior, not yet supported") - self.lora_a = RowParallelLinear(self.in_features, self.lora_rank, bias=False, input_is_parallel=True) - trunc_normal_(self.lora_a.weight, std=.02) - self.lora_b = nn.Linear(self.lora_rank, self.out_features, bias=False) - nn.init.zeros_(self.lora_b.weight) - else: - self.lora_a = None - self.lora_b = None - - def get_master_weight(self) -> torch.Tensor: - return gather_from_model_parallel_region(self.weight.data) - - def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - input_parallel = scatter_to_model_parallel_region(input_) - # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight) - # All-reduce across all the partitions. - output_ = reduce_from_model_parallel_region(output_parallel) - if self.lora_a is not None: - modification = self.lora_b(self.lora_a(input_parallel)) - output_ = output_ + modification - if self.bias is not None: - output = output_ + self.bias - else: - output = output_ - return output \ No newline at end of file + self.lora_a = RowParallelLinear( + self.in_features, self.lora_rank, bias=False, + weight_init_fn=functools.partial(trunc_normal_, std=.02), + input_is_parallel=self.input_is_parallel, + ) + self.lora_b = Linear( + self.lora_rank, self.out_features, bias=False, + weight_init_fn=nn.init.zeros_, + ) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + output = super().forward(input_) + if self.lora_rank > 0: + output = output + self.lora_b(self.lora_a(input_)) + return output From 9852978879697c84200042747017bb1104c410f0 Mon Sep 17 00:00:00 2001 From: Ziyi Lin Date: Fri, 1 Sep 2023 16:19:44 +0800 Subject: [PATCH 3/5] update utils for new tensor parallel layers --- accessory/util/misc.py | 44 ++++++--------------- accessory/util/tensor_parallel.py | 65 +++++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 40 deletions(-) diff --git a/accessory/util/misc.py b/accessory/util/misc.py index b82c72f7..dc498d75 100644 --- a/accessory/util/misc.py +++ b/accessory/util/misc.py @@ -16,6 +16,7 @@ from collections import defaultdict, deque from pathlib import Path import subprocess +from warnings import warn import torch import torch.distributed as dist @@ -582,42 +583,19 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()): def broadcast_nonmp_parameters(model): - if fs_init.get_model_parallel_world_size() == 1: - return - print("starting broadcast non-model-parallel parameters within model parallel group") - memo = set() - modules = model.named_modules(prefix='', remove_duplicate=True) - for module_prefix, module in modules: - members = dict(module._parameters.items()) - for k, v in members.items(): - name = module_prefix + ('.' 
if module_prefix else '') + k - if v is None or v in memo: - continue - if getattr(v, "is_model_parallel", False): - print(f"ignore: {name}") - continue - memo.add(v) - dist.broadcast(v, src=fs_init.get_model_parallel_src_rank(), group=fs_init.get_model_parallel_group()) - print("braodcast done") + warn("util.misc.broadcast_nonmp_parameters is deprecated. " + "Use util.tensor_parallel.broadcast_nonmp_parameters instead.", + DeprecationWarning, stacklevel=2) + from .tensor_parallel import broadcast_nonmp_parameters + broadcast_nonmp_parameters(model) def mark_mp_params(model: torch.nn.Module): - from fairscale.nn.model_parallel.layers import ( - RowParallelLinear, - ColumnParallelLinear, - ParallelEmbedding, - ) - for m in model.modules(): - if isinstance(m, ColumnParallelLinear): - m.weight.is_model_parallel = True - if m.bias is not None: - m.bias.is_model_parallel = True - - if isinstance(m, RowParallelLinear): - m.weight.is_model_parallel = True - - if isinstance(m, ParallelEmbedding): - m.weight.is_model_parallel = True + warn("util.misc.mark_mp_params is deprecated. " + "Use util.tensor_parallel.mark_mp_params instead.", + DeprecationWarning, stacklevel=2) + from .tensor_parallel import mark_mp_params + mark_mp_params(model) def print_param_status(model: torch.nn.Module) -> None: diff --git a/accessory/util/tensor_parallel.py b/accessory/util/tensor_parallel.py index 4a950e7f..78f3ef5a 100644 --- a/accessory/util/tensor_parallel.py +++ b/accessory/util/tensor_parallel.py @@ -5,12 +5,16 @@ import torch import torch.nn as nn +import torch.distributed as dist import fairscale.nn.model_parallel.initialize as fs_init from fairscale.nn.model_parallel.layers import ( - ColumnParallelLinear, - RowParallelLinear, - ParallelEmbedding, + ColumnParallelLinear as fs_ColumnParallelLinear, + RowParallelLinear as fs_RowParallelLinear, + ParallelEmbedding as fs_ParallelEmbedding, +) +from model.layers.tensor_parallel import ( + ColumnParallelLinear, RowParallelLinear, ParallelEmbedding, ) # _MODEL_PARALLEL_MODULES defines a list of module classes that contains @@ -33,6 +37,11 @@ (ColumnParallelLinear, {"weight": 0, "bias": 0}), (RowParallelLinear, {"weight": 1, "bias": -1}), (ParallelEmbedding, {"weight": 1}), + # TODO: fs_* layer registrations are to be removed after the + # migration is completed. + (fs_ColumnParallelLinear, {"weight": 0, "bias": 0}), + (fs_RowParallelLinear, {"weight": 1, "bias": -1}), + (fs_ParallelEmbedding, {"weight": 1}), ] FORMAT_FILENAME_PATTERNS: Dict[str, re.Pattern] = { @@ -327,10 +336,10 @@ def infer_checkpoint_format_and_mp_size(path: str) -> str: raise NotImplementedError(f"Multiple matched format detected: " f"{inferred_format} and {format}.") if inferred_format is None: - folder_contents = ", ".join( - [x if os.path.isfile(os.path.join(path, x)) else x + " (not a file)" - for x in os.listdir(path)] - ) + folder_contents = ", ".join([ + x if os.path.isfile(os.path.join(path, x)) else x + " (not a file)" + for x in os.listdir(path) + ]) raise NotImplementedError( f"Files in the given folder do not match any format. " f"Contents in the folder: [{folder_contents}]." @@ -534,3 +543,45 @@ def is_complete(self) -> bool: assert all(x >= 0 and x < self._num_shards for x in self._loaded_shards) return len(self._loaded_shards) == self._num_shards + + +def mark_mp_params(model: nn.Module) -> None: + r"""This method marks all parameters in the model that is sharded among + model parallel ranks. 
The mark may be used for various tensor-parallel-
+    related processing (e.g., when synchronizing model parameters among the
+    model parallel workers, only broadcast replicated tensors and skip sharded
+    tensors).
+
+    Args:
+        model (torch.nn.Module): The model whose parameters are to be marked.
+    """
+    for m in model.modules():
+        for module_class, shard_dict in _MODEL_PARALLEL_MODULES:
+            if isinstance(m, module_class):
+                for key, dim in shard_dict.items():
+                    if getattr(m, key, None) is not None and dim >= 0:
+                        getattr(m, key).is_model_parallel = True
+                break
+
+
+def broadcast_nonmp_parameters(model: nn.Module) -> None:
+    r"""This method broadcasts replicated (non-model-parallel) parameters
+    among tensor parallel workers. Sharded parameters are skipped.
+
+    Args:
+        model (torch.nn.Module): The model whose replicated parameters are to
+            be broadcast.
+    """
+    if fs_init.get_model_parallel_world_size() == 1:
+        print("Skip broadcasting parameters in tensor parallel groups as "
+              "group size is 1.")
+        return
+    print("Starting broadcast of non-model-parallel parameters within tensor "
+          "parallel groups.")
+    for name, param in model.named_parameters():
+        if getattr(param, "is_model_parallel", False):
+            print(f"Ignoring sharded parameter: {name}")
+        else:
+            dist.broadcast(param, src=fs_init.get_model_parallel_src_rank(),
+                           group=fs_init.get_model_parallel_group())
+    print("Broadcasting within tensor parallel groups is done.")

From 4266554f34e0d22ebbe4359824521c63959dec1d Mon Sep 17 00:00:00 2001
From: Ziyi Lin
Date: Fri, 1 Sep 2023 17:19:56 +0800
Subject: [PATCH 4/5] default_tensor_type supports meta creation

Also add more docs.
---
 accessory/util/tensor_type.py | 109 +++++++++++++++++++++++++++-------
 1 file changed, 87 insertions(+), 22 deletions(-)

diff --git a/accessory/util/tensor_type.py b/accessory/util/tensor_type.py
index e9ddb9fe..e858d8b9 100644
--- a/accessory/util/tensor_type.py
+++ b/accessory/util/tensor_type.py
@@ -1,36 +1,59 @@
 from types import TracebackType
-from typing import Any, Optional
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
 
 class default_tensor_type:
-    _tensor_type_stack = [(torch.float, "cpu")]
-
+    r"""A context manager that maintains a stack of tensor type states. Each
+    state is a tuple of 3 elements: (1) the default scalar dtype of new
+    tensors; (2) the default real device of the new tensors (i.e., not
+    including the ``meta`` device); and (3) whether new tensors should be
+    created as ``meta``.
+
+    Each argument is optional and will inherit the last value on the stack if
+    passed ``None``.
+
+    .. note::
+        Unlike PyTorch, which manages ``meta`` as a special type of device, we
+        manage ``is_meta`` as a separate dimension in our states. This allows
+        us to maintain the materialization device while entering or exiting
+        the ``meta`` creation state freely.
+
+    Args:
+        dtype (torch.dtype, optional): The scalar data type of the new tensors.
+        device (str, optional): The string representing the real device of the
+            new tensors. ``meta`` device and device ordinals are not supported.
+        is_meta (bool, optional): Whether new tensors should be created as
+            ``meta``.
+    """
+
+    _tensor_type_stack: List[Tuple[torch.dtype, str, bool]] = [
+        (torch.float, "cpu", False)
+    ]
+
     def __init__(
         self,
         dtype: Optional[torch.dtype] = None,
         device: Optional[str] = None,
+        is_meta: Optional[bool] = None,
     ) -> None:
-        # Only limited combinations are supported.
         assert device is None or device in ["cpu", "cuda"]
-        assert dtype is None or dtype in [torch.float, torch.bfloat16, torch.half]
-        self.dtype, self.device = dtype, device
-
+        assert dtype is None or dtype in [torch.float, torch.bfloat16,
+                                          torch.half]
+        self.dtype, self.device, self.is_meta = dtype, device, is_meta
+
     def __enter__(self) -> None:
-        dtype, device = self.dtype, self.device
+        dtype, device, is_meta = self.dtype, self.device, self.is_meta
         if dtype is None:
             dtype = default_tensor_type._tensor_type_stack[-1][0]
         if device is None:
             device = default_tensor_type._tensor_type_stack[-1][1]
-        default_tensor_type._tensor_type_stack.append((dtype, device))
-
-        # We use all 3 calls since the new apis (set_default_device, set_default_dtype)
-        # seems to be ineffective sometimes (e.g., set_default_device is ineffective to
-        # torch.Tensor calls).
-        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
-        torch.set_default_device(device)
-        torch.set_default_dtype(dtype)
+        if is_meta is None:
+            is_meta = default_tensor_type._tensor_type_stack[-1][2]
+        default_tensor_type._tensor_type_stack.append((dtype, device, is_meta))
+        default_tensor_type._set_pytorch_state_by_last_state_tuple()
 
     def __exit__(
         self,
@@ -39,10 +62,27 @@ def __exit__(
         exc_tb: Optional[TracebackType],
     ) -> None:
         default_tensor_type._tensor_type_stack.pop()
-        dtype, device = default_tensor_type._tensor_type_stack[-1]
+        default_tensor_type._set_pytorch_state_by_last_state_tuple()
 
-        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
-        torch.set_default_device(device)
+    @staticmethod
+    def _set_pytorch_state_by_last_state_tuple():
+        dtype, device, is_meta = default_tensor_type._tensor_type_stack[-1]
+
+        # We use all 3 calls since the new apis (set_default_device,
+        # set_default_dtype) seem to be ineffective sometimes (e.g.,
+        # set_default_device is ineffective for torch.Tensor calls).
+        #
+        # We are aware that the torch.Tensor creator is deprecated as of
+        # PyTorch v2.0.1. This is a 'catch-all' for some third-party libraries
+        # (e.g., fairscale) which still use the old torch.Tensor API but are
+        # out of our control.
+        #
+        # Also, torch.set_default_tensor_type seems to not support the new
+        # meta tensor feature, so we have to fall back to the real device.
+        torch.set_default_tensor_type(
+            default_tensor_type.get_tensor_type(dtype, device)
+        )
+        torch.set_default_device(device if not is_meta else "meta")
         torch.set_default_dtype(dtype)
 
     @staticmethod
@@ -56,11 +96,36 @@ def get_tensor_type(dtype: torch.dtype, device: str) -> Any:
             (torch.half, "cuda"): torch.cuda.HalfTensor,
         }[(dtype, device)]
 
+    @staticmethod
+    def get_current_materialization_device() -> torch.device:
+        r"""Get the current 'real' device on the default tensor type stack,
+        regardless of the ``is_meta`` state.
+        """
+        return torch.device(default_tensor_type._tensor_type_stack[-1][1])
+
 
 def promote_trainable_params_to_fp32(model: nn.Module) -> None:
+    r"""This method promotes each parameter of a given model with
+    ``requires_grad=True`` to at least FP32, following the common practice in
+    mixed precision training of keeping a copy of FP32 master weights for
+    optimization even though each forward and backward pass uses the
+    down-casted low precision weights (16-bit, or even 8-bit on newer
+    hardware).
+
+    .. note::
+        The method handles both floating point (real) and complex scalar types.
+ For complex type, both the real and the imaginary parts are promoted to + FP32 (resulting in the ``torch.complex64`` scalar type). + + Args: + model (torch.nn.Module): The model whose ``requires_grad`` parameters + are promoted to FP32. + """ for param in model.parameters(): if param.requires_grad: - if param.is_floating_point() and torch.finfo(param.dtype).bits < 32: - param.data = param.data.float() - if param.is_complex() and torch.finfo(param.dtype).bits < 32: - param.data = param.data.to(torch.complex64) \ No newline at end of file + if param.is_floating_point(): + if torch.finfo(param.dtype).bits < 32: + param.data = param.data.float() + elif param.is_complex(): + if torch.finfo(param.dtype).bits < 32: + param.data = param.data.to(torch.complex64) From af37bbd99f2111bb459cb4294ff7b83d1f78113f Mon Sep 17 00:00:00 2001 From: Ziyi Lin Date: Fri, 1 Sep 2023 20:00:27 +0800 Subject: [PATCH 5/5] visual models disable meta creation --- accessory/model/LLM/llama.py | 2 +- accessory/model/LLM/llama_adapter.py | 2 +- accessory/model/LLM/llama_ens.py | 39 +++++++++++---------- accessory/model/LLM/llama_peft.py | 2 +- accessory/model/LLM/llama_qformerv2.py | 4 ++- accessory/model/LLM/llama_qformerv2_peft.py | 5 +-- 6 files changed, 29 insertions(+), 25 deletions(-) diff --git a/accessory/model/LLM/llama.py b/accessory/model/LLM/llama.py index 71934f29..bcdf602b 100644 --- a/accessory/model/LLM/llama.py +++ b/accessory/model/LLM/llama.py @@ -318,7 +318,7 @@ def __init__(self, params: ModelArgs, with_visual=False): self.cache_image_words = 0 # for inference if with_visual: print("build llama model with clip") - with default_tensor_type(dtype=torch.half): + with default_tensor_type(dtype=torch.half, is_meta=False): self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') for name, param in self.clip.named_parameters(): param.requires_grad = False diff --git a/accessory/model/LLM/llama_adapter.py b/accessory/model/LLM/llama_adapter.py index 122c85b7..7ebd7377 100644 --- a/accessory/model/LLM/llama_adapter.py +++ b/accessory/model/LLM/llama_adapter.py @@ -315,7 +315,7 @@ def __init__(self, params: ModelArgs, with_visual=False): self.image_words = 0 if with_visual: print("build llama model with clip") - with default_tensor_type(dtype=torch.half): + with default_tensor_type(dtype=torch.half, is_meta=False): self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') for name, param in self.clip.named_parameters(): param.requires_grad = False diff --git a/accessory/model/LLM/llama_ens.py b/accessory/model/LLM/llama_ens.py index ead647da..6a477b02 100644 --- a/accessory/model/LLM/llama_ens.py +++ b/accessory/model/LLM/llama_ens.py @@ -30,7 +30,7 @@ default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) from .llama import precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb, repeat_kv - +from util.tensor_type import default_tensor_type @dataclass class ModelArgs: @@ -286,24 +286,25 @@ def __init__(self, params: ModelArgs, with_visual=False): nn.LayerNorm(params.dim) ) - print("build llama model with clip") - self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') - self.clip.transformer = None - - print("build llama model with openclip") - self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms( - "convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup" - ) - self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk - 
self.openclip_convnext_xxl.head.global_pool = nn.Identity() - self.openclip_convnext_xxl.head.flatten = nn.Identity() - - print("build llama model with dinov2") - import os.path - if os.path.exists("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main"): - self.dinov2_vitg14 = torch.hub.load("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main", "dinov2_vitg14", source="local") - else: - self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14") + with default_tensor_type(is_meta=False): + print("build llama model with clip") + self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') + self.clip.transformer = None + + print("build llama model with openclip") + self.openclip_convnext_xxl, _, _ = open_clip.create_model_and_transforms( + "convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup" + ) + self.openclip_convnext_xxl = self.openclip_convnext_xxl.visual.trunk + self.openclip_convnext_xxl.head.global_pool = nn.Identity() + self.openclip_convnext_xxl.head.flatten = nn.Identity() + + print("build llama model with dinov2") + import os.path + if os.path.exists("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main"): + self.dinov2_vitg14 = torch.hub.load("/mnt/petrelfs/gaopeng/.cache/torch/hub/facebookresearch_dinov2_main", "dinov2_vitg14", source="local") + else: + self.dinov2_vitg14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14") self.visual_proj = nn.Sequential( nn.Linear(3072 + 1024 + 1536, params.dim), diff --git a/accessory/model/LLM/llama_peft.py b/accessory/model/LLM/llama_peft.py index bda48bff..747e2d95 100644 --- a/accessory/model/LLM/llama_peft.py +++ b/accessory/model/LLM/llama_peft.py @@ -289,7 +289,7 @@ def __init__(self, params: ModelArgs, with_visual=False): self.cache_image_words = 0 # for inference if with_visual: print("build llama model with clip") - with default_tensor_type(dtype=torch.half): + with default_tensor_type(dtype=torch.half, is_meta=False): self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') for name, param in self.clip.named_parameters(): param.requires_grad = False diff --git a/accessory/model/LLM/llama_qformerv2.py b/accessory/model/LLM/llama_qformerv2.py index cd2cfe74..6ff21031 100644 --- a/accessory/model/LLM/llama_qformerv2.py +++ b/accessory/model/LLM/llama_qformerv2.py @@ -27,6 +27,7 @@ default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) from .llama import precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb, repeat_kv +from util.tensor_type import default_tensor_type @dataclass @@ -273,7 +274,8 @@ def __init__(self, params: ModelArgs, with_visual=False): self.cache_image_words = 0 # for inference if with_visual: print("build llama model with qformerv2") - self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) + with default_tensor_type(is_meta=False): + self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) self.qformer.language_projection = None self.qformer.language_model = None diff --git a/accessory/model/LLM/llama_qformerv2_peft.py b/accessory/model/LLM/llama_qformerv2_peft.py index 004468a0..9d6e915c 100644 --- a/accessory/model/LLM/llama_qformerv2_peft.py +++ b/accessory/model/LLM/llama_qformerv2_peft.py @@ -28,7 +28,7 @@ default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) from .llama import precompute_freqs_cis, reshape_for_broadcast, 
apply_rotary_emb, repeat_kv - +from util.tensor_type import default_tensor_type @dataclass class ModelArgs: @@ -288,7 +288,8 @@ def __init__(self, params: ModelArgs, with_visual=False): self.cache_image_words = 0 # for inference if with_visual: print("build llama model with qformerv2") - self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) + with default_tensor_type(is_meta=False): + self.qformer = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) self.qformer.language_projection = None self.qformer.language_model = None
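
Usage sketch (illustrative only, not part of the patch series): how the deferred-initialization pieces introduced above are intended to compose. It assumes torch.distributed and fairscale model parallelism have already been initialized, the repo-style import paths (model.layers, util.tensor_type) are on the path, and the to_empty() materialization step and the 4096/11008 dimensions are hypothetical caller-side choices rather than code added by these patches.

# Hypothetical driver code, not part of the patches above.
import torch

from model.layers import ColumnParallelLinear
from util.tensor_type import default_tensor_type

with default_tensor_type(dtype=torch.bfloat16, device="cuda", is_meta=True):
    # Parameters are created on the meta device; the reset_parameters() call
    # inside __init__ is a no-op for meta tensors.
    proj = ColumnParallelLinear(
        4096, 11008, bias=False,
        weight_init_fn=torch.nn.init.xavier_uniform_,
        gather_output=False,
    )
    # The "real" device is still tracked on the stack for later use.
    device = default_tensor_type.get_current_materialization_device()

# Materialize storage on the real device, then run the deferred custom
# initialization: the weight is initialized at its unsharded shape on the
# source rank and scattered to the tensor parallel group.
proj = proj.to_empty(device=device)
proj.reset_parameters()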