diff --git a/.github/workflows/test_fx_automatic_parallel.yml b/.github/workflows/test_fx_automatic_parallel.yml new file mode 100644 index 00000000000..3c913e3f7ed --- /dev/null +++ b/.github/workflows/test_fx_automatic_parallel.yml @@ -0,0 +1,65 @@ +name: Automatic Model Parallelism Test on GPUs + +on: + pull_request: + branches: + - main + paths: + - 'optimum/fx/parallelization/**.py' + push: + branches: + - main + paths: + - 'optimum/fx/parallelization/**.py' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + run_gpu_tests: + strategy: + fail-fast: false + matrix: + config: + - name: GPU-enabled Optimum Test Suite + image: nvidia/cuda:12.4.1-devel-ubuntu22.04 + gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"] + + name: ${{ matrix.config.name }} + runs-on: + group: "${{matrix.gpu_target}}" + + container: + image: ${{ matrix.config.image }} + options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/ + env: + NCCL_DEBUG: INFO + HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} + defaults: + run: + shell: bash + + steps: + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Checkout optimum + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run nvidia-smi + run: | + nvidia-smi + + - name: Install dependencies + run: | + python3 -m pip install -U pip + python3 -m pip install torch transformers + python3 -m pip install .[tests] + + - name: Run automatic model parallelism tests + run: | + pytest -s -v -o log_cli=true tests/fx/parallelization diff --git a/optimum/fx/parallelization/__init__.py b/optimum/fx/parallelization/__init__.py new file mode 100644 index 00000000000..701badd4d59 --- /dev/null +++ b/optimum/fx/parallelization/__init__.py @@ -0,0 +1,16 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .api import parallelize_backend, parallelize_model +from .core import Config, ParallelExecutionCtx diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py new file mode 100644 index 00000000000..bd307bd93c1 --- /dev/null +++ b/optimum/fx/parallelization/api.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def parallelize_model(
    model: Union[torch.nn.Module, str],
    parallel_ctx: ParallelExecutionCtx,
    *model_args,
    **kwargs,
):
    """
    API for automatic model parallelism through Pytorch FX.

    Args:
        model (Union[torch.nn.Module, str]):
            Model to parallelize, could either be a module or a model id on the Huggingface Hub.
            A string pointing to an existing local directory is used as-is instead of being downloaded.
        parallel_ctx (ParallelExecutionCtx):
            Parallel execution context containing process groups the current process belongs to.
        *model_args (Any):
            Additional positional arguments for initializing the model if a model id is passed.
        revision (str, defaults to `main`):
            Model revision for weights downloading if a model id is passed.
        cache_dir (Optional[str], defaults to `None`):
            Cache directory to store downloaded weights. Defaults to None.
        local_files_only (bool, defaults to `False`):
            Whether to use local files only, will avoid downloading from remote if set to `True`.
        skip_load_weights (bool, defaults to `False`):
            Whether to skip loading weights from disk to model.
        **kwargs (Dict[str, Any]):
            Additional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`.

    Returns:
        The model wrapped by `torch.compile` with `parallelize_backend` as the compilation backend.

    Raises:
        ValueError: If a model id/path is passed and its config does not list any architecture, in which
            case the model class cannot be inferred.
    """
    revision = kwargs.pop("revision", "main")
    cache_dir = kwargs.pop("cache_dir", None)
    local_files_only = kwargs.pop("local_files_only", False)
    skip_load_weights = kwargs.pop("skip_load_weights", False)

    parallel_config = Config()
    # iterate over a snapshot because matching entries are removed from `kwargs` inside the loop
    for k, v in dict(kwargs).items():
        if k in parallel_config.__dict__:
            setattr(parallel_config, k, v)
            kwargs.pop(k)

    if isinstance(model, str):
        from transformers import AutoConfig

        is_local = os.path.isdir(model)
        if not is_local:
            hf_folder = download_model_from_hf(
                model_name_or_path=model,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=local_files_only,
                skip_download_weights=skip_load_weights,
            )
        else:
            hf_folder = model

        # should be able to load config using only local files
        model_config, kwargs = AutoConfig.from_pretrained(
            hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
        )

        # try getting model class info from config
        model_arch = model_config.architectures
        if not model_arch:
            # fail with an actionable message instead of an opaque `TypeError` on `None[0]`
            raise ValueError(
                f"Can't infer the model class of `{model}` because its config does not define `architectures`"
            )
        model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

        if not skip_load_weights:
            parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)

        torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
        if torch_dtype is not None:
            dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

        try:
            with MetaAwareMethodsPatcher():
                model = model_cls(model_config, *model_args, **kwargs)
                # TODO: remove this once support training-time trace
                model.eval()
        finally:
            # the default dtype is process-wide state, restore it even when model construction fails
            if dtype_orig is not None:
                torch.set_default_dtype(dtype_orig)

    move_model_to_device(model, device=parallel_ctx.current_device)
    initialize_parameter_meta(model)
    backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config)
    model = torch.compile(model, fullgraph=True, backend=backend)
    return model
@dataclass
class ParameterSlice:
    """
    A slice of parameter which corresponds to a tensor in weight dict. Only support slicing
    along a specific axis (the potential parallel axis) right now.

    Attributes:
        - source (`Optional[str]`, defaults to `None`):
            Original parameter name which can be found in the weight dict.

        - shape (`Optional[Tuple]`, defaults to `None`):
            Shape of parameter tensor corresponding to `source`.

        - index (`slice`, defaults to `slice(None, None, None)`):
            Index to slice the tensor on the parallel axis. Assume tensors in weight dict have the same
            layout as their counterparts in memory.
    """

    source: Optional[str] = None
    shape: Optional[Tuple] = None
    # `slice` objects are unhashable before Python 3.12, and since Python 3.11 `dataclass` refuses any
    # unhashable default value ("mutable default ... is not allowed"), so the default full slice has to
    # be produced by a factory for the class to be importable on every supported interpreter.
    index: slice = field(default_factory=lambda: slice(None, None, None))
@dataclass
class Config:
    """
    Static options of the parallelization pipeline, they are not expected to change at runtime.

    Attributes:
        - lint_and_recompile (`bool`, defaults to `True`):
            When set, the graph is linted and the graph module recompiled each time a pass has run.

        - clean_markers_after_all_passes (`bool`, defaults to `True`):
            When set, markers left by analytical passes are cleaned once all passes have run.

        - weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`)
            Function used to initialize weights of `nn.Linear` and `nn.Embedding` layers when no
            weights loading path is provided.
    """

    lint_and_recompile: bool = field(default=True)
    clean_markers_after_all_passes: bool = field(default=True)
    weight_init_fn: Callable = field(default=partial(nn.init.normal_, std=0.02))
def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = -1) -> torch.Tensor:
    """
    Gather `tensor` from all ranks in `group` and join the pieces along `gather_dim`.

    Args:
        group (`dist.ProcessGroup`): process group to gather across.
        tensor (`torch.Tensor`): local tensor contributed by the current rank.
        gather_dim (`int`, defaults to `-1`): axis along which the gathered pieces are joined,
            negative values count from the last axis.

    Returns:
        `tensor` itself when the group holds a single rank, otherwise a new contiguous tensor whose
        size along `gather_dim` is `world_size` times the local one.
    """
    num_ranks = dist.get_world_size(group)
    if num_ranks == 1:
        return tensor

    # map a possibly negative axis into [0, ndim)
    axis = (gather_dim + tensor.ndim) % tensor.ndim
    needs_swap = axis != 0

    # the collective below fills an output which is `num_ranks` times larger on axis 0, so the gather
    # axis is temporarily moved to the front and moved back afterwards
    local = (tensor.transpose(0, axis) if needs_swap else tensor).contiguous()
    gathered_shape = [local.size(0) * num_ranks, *local.shape[1:]]
    gathered = torch.empty(*gathered_shape, dtype=tensor.dtype, device=tensor.device)

    dist.all_gather_into_tensor(gathered, local, group=group)
    return gathered.transpose(0, axis).contiguous() if needs_swap else gathered
class DifferentiableScatter(torch.autograd.Function):
    """
    Autograd function which keeps the current rank's chunk of a tensor in forward (through `split`)
    and hands the incoming gradient to `DifferentiableAllGather` in backward.
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor:
        # remember the communication arguments, backward has to use the same group and axis
        ctx.group = group
        ctx.dim = dim
        return split(group=group, tensor=tensor, split_dim=dim)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        # `torch.autograd.Function.apply` only accepts positional arguments, passing `group=`/`dim=`
        # as keywords raises `TypeError` as soon as the backward pass runs.
        # One gradient is returned per forward input, `group` and `dim` are not differentiable.
        return DifferentiableAllGather.apply(grad_output, ctx.group, ctx.dim), None, None
def differentiable_scatter(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1) -> torch.Tensor:
    """
    Functional wrapper around `DifferentiableScatter`.

    Args:
        tensor (`torch.Tensor`): tensor forwarded to `DifferentiableScatter`.
        group (`dist.ProcessGroup`): process group forwarded to `DifferentiableScatter`.
        dim (`int`, defaults to `-1`): axis forwarded to `DifferentiableScatter`.

    Returns:
        Whatever `DifferentiableScatter.apply` returns for these arguments.
    """
    # arguments are passed positionally, in the order expected by `DifferentiableScatter.forward`
    scatter_fn = DifferentiableScatter.apply
    return scatter_fn(tensor, group, dim)
class VocabParallelEmbedding(nn.Module):
    """
    Embedding layer parallelized in vocabulary dimension.

    Each rank is responsible for the token ids in `[vocab_start_idx, vocab_end_idx)`: ids outside of
    this range are masked out in `forward` and the partial results of all ranks are summed.

    Arguments:
        ctx(`ParallelExecutionCtx`): parallel execution context which contains runtime information.
        embedding(`torch.nn.Embedding`): the original embedding module being replaced.
    """

    def __init__(self, ctx: ParallelExecutionCtx, embedding: nn.Embedding):
        super(VocabParallelEmbedding, self).__init__()
        self.process_group = ctx.tp_group
        world_size = dist.get_world_size(self.process_group)
        tp_rank = dist.get_rank(self.process_group)
        ensure_divisibility(embedding.num_embeddings, world_size)

        num_embeddings = embedding.num_embeddings // world_size

        self.padding_idx = embedding.padding_idx
        self.max_norm = embedding.max_norm
        self.norm_type = embedding.norm_type
        self.scale_grad_by_freq = embedding.scale_grad_by_freq
        self.sparse = embedding.sparse
        self.vocab_start_idx = tp_rank * num_embeddings
        self.vocab_end_idx = (tp_rank + 1) * num_embeddings

        # `padding_idx` refers to the full vocabulary whereas `forward` looks ids up relatively to
        # `vocab_start_idx`, so translate it into the range owned by this rank and drop it on ranks
        # which do not own the padding row. Passing the global index to `F.embedding` would either
        # trip its "Padding_idx must be within num_embeddings" check or designate an unrelated row.
        if self.padding_idx is not None and self.vocab_start_idx <= self.padding_idx < self.vocab_end_idx:
            self.shard_padding_idx = self.padding_idx - self.vocab_start_idx
        else:
            self.shard_padding_idx = None

        # modify meta information
        weight_meta = getattr(embedding.weight, "meta", None)
        assert isinstance(
            weight_meta, ParameterMeta
        ), "should have run `initialize_parameter_meta` after moving model to current device"
        if weight_meta.is_modified_meta:
            assert weight_meta.is_tied, "only tied parameters could already have modified meta"
        else:
            weight_meta.need_initialize = True
            weight_meta.is_parallel = True
            weight_meta.dim = 0
            # every source tensor of this parameter is restricted to the rows owned by this rank
            for param_slice in weight_meta.mapping.values():
                param_slice.index = slice(self.vocab_start_idx, self.vocab_end_idx)
            weight_meta.is_modified_meta = True

        # skip creating actual parameters
        self.weight = embedding.weight

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # ids which belong to other ranks
        input_mask = (input < self.vocab_start_idx) | (input >= self.vocab_end_idx)
        # the subtraction already allocates a new tensor, `input` itself is left untouched
        masked_input = input - self.vocab_start_idx
        # point foreign ids to a valid row, their output is zeroed out below
        masked_input[input_mask] = 0

        output = F.embedding(
            masked_input,
            self.weight,
            self.shard_padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        output[input_mask, :] = 0.0
        # the summation across ranks assembles the embeddings of all ids
        output = differentiable_all_reduce_sum(output, self.process_group)
        return output
mode 100644 index 00000000000..62d5894dacf --- /dev/null +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from ..core import ( + ParallelExecutionCtx, + ParameterMeta, +) +from ..distributed import ( + differentiable_all_gather, + differentiable_all_reduce_sum, + differentiable_identity, + differentiable_scatter, +) +from ..utils import ensure_divisibility + + +class ColumnParallelLinear(nn.Module): + """ + Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + ctx(`ParallelExecutionCtx`): parallel execution context which contains runtime information. + linear(`torch.nn.Linear`): the original linear module being replaced. + gather_output(`bool`, defaults to `True`): whether gathering output in the end of forward. 
+ """ + + def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True) -> None: + super(ColumnParallelLinear, self).__init__() + self.process_group = ctx.tp_group + world_size = dist.get_world_size(self.process_group) + tp_rank = dist.get_rank(self.process_group) + ensure_divisibility(linear.out_features, world_size) + + out_features = linear.out_features // world_size + bias = linear.bias is not None + + # modify meta information + weight_meta = getattr(linear.weight, "meta", None) + assert isinstance( + weight_meta, ParameterMeta + ), "should have run `initialize_parameter_meta` after moving model to current device" + + if weight_meta.is_modified_meta: + assert weight_meta.is_tied, "only tied parameters could already have modified meta" + else: + weight_meta.need_initialize = True + weight_meta.is_parallel = True + weight_meta.dim = 0 + for _, Slice in weight_meta.mapping.items(): + Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features) + weight_meta.is_modified_meta = True + + # skip creating actual parameters + self.weight = linear.weight + self.gather_output = gather_output + + if bias: + bias_meta = getattr(linear.bias, "meta", None) + assert isinstance( + bias_meta, ParameterMeta + ), "should have run `initialize_parameter_meta` after moving model to current device" + + if bias_meta.is_modified_meta: + assert bias_meta.is_tied, "only tied parameters could already have modified meta" + else: + bias_meta.need_initialize = True + bias_meta.is_parallel = True + bias_meta.init_fn = torch.zero_ + bias_meta.dim = 0 + for _, Slice in bias_meta.mapping.items(): + Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features) + bias_meta.is_modified_meta = True + self.bias = linear.bias + else: + self.register_parameter("bias", None) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = differentiable_identity(input, self.process_group) + output = F.linear(input, self.weight, self.bias) + if 
class RowParallelLinear(nn.Module):
    """
    Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        ctx(`ParallelExecutionCtx`): parallel execution context which contains runtime information.
        linear(`torch.nn.Linear`): the original linear module being replaced.
        input_is_parallel(`bool`, defaults to `False`): whether the input tensor has already been parallelized,
            if not it is scattered across the process group at the beginning of forward.
    """

    def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parallel: bool = False) -> None:
        super().__init__()
        self.process_group = ctx.tp_group
        group_size = dist.get_world_size(self.process_group)
        rank = dist.get_rank(self.process_group)
        ensure_divisibility(linear.in_features, group_size)

        shard_width = linear.in_features // group_size
        has_bias = linear.bias is not None

        # modify meta information
        weight_meta = getattr(linear.weight, "meta", None)
        assert isinstance(
            weight_meta, ParameterMeta
        ), "should have run `initialize_parameter_meta` after moving model to current device"

        if not weight_meta.is_modified_meta:
            weight_meta.need_initialize = True
            weight_meta.is_parallel = True
            # the weight is sliced along its second axis, which is indexed by input features
            weight_meta.dim = 1
            for param_slice in weight_meta.mapping.values():
                param_slice.index = slice(rank * shard_width, (rank + 1) * shard_width)
            weight_meta.is_modified_meta = True
        else:
            assert weight_meta.is_tied, "only tied parameters could already have modified meta"

        # skip creating actual parameters
        self.weight = linear.weight
        self.input_is_parallel = input_is_parallel

        if not has_bias:
            self.register_parameter("bias", None)
            return

        bias_meta = getattr(linear.bias, "meta", None)
        assert isinstance(
            bias_meta, ParameterMeta
        ), "should have run `initialize_parameter_meta` after moving model to current device"
        if not bias_meta.is_modified_meta:
            # the bias is not marked as parallel here, it is added after the reduction in forward
            bias_meta.need_initialize = True
            bias_meta.init_fn = torch.zero_
            bias_meta.is_modified_meta = True
        else:
            assert bias_meta.is_tied, "only tied parameters could already have modified meta"
        self.bias = linear.bias

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        local_input = input if self.input_is_parallel else differentiable_scatter(input, self.process_group)

        # partial product of the current rank, summed across the process group afterwards
        partial_output = F.linear(local_input, self.weight)
        reduced = differentiable_all_reduce_sum(partial_output, self.process_group)

        # bias is added after the reduction so that it is only counted once
        return reduced if self.bias is None else reduced + self.bias
class PassBase(ABC):
    """
    Common ancestor of all parallelization passes.

    Subclasses implement `run`, instances are invoked like functions and take care of skipping
    themselves on recompilation as well as of linting/recompiling the processed graph module.
    """

    # set to `False` in subclasses which only have to run during the first compilation
    need_rerun_when_recompile: bool = True

    @classmethod
    def signature(cls) -> str:
        """Name identifying the pass, which is the name of its class."""
        return cls.__name__

    @abstractmethod
    def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
        """
        Args:
            graph_module (`GraphModule`):
                graph module to process.
            ctx (`ParallelExecutionCtx`):
                dynamic execution context which gathers and preserves information along processing.
            config (`Config`):
                static config holding instructions which stay the same for the whole process.

        Returns:
            GraphModule: graph module after being processed by the current pass.
        """
        raise NotImplementedError

    def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
        # a pass is executed on the first compilation, later on only if it asks to be rerun
        if self.need_rerun_when_recompile or ctx.compile_times <= 0:
            graph_module = self.run(graph_module, ctx=ctx, config=config)
            if config.lint_and_recompile:
                graph_module.graph.lint()
                graph_module.recompile()
        return graph_module
Analytical passes are often prerequisite passes which provide information + for passes later on to actually change the graph. + + Passes inheriting from `AnalyzeBase` places the class signature as a meta key in `node.meta`, + which is a dict storing meta information related with a fx Node, such as the shape and dtype of + output. Look-up APIs are exposed as classmethod so that passes using them won't need to create + concrete instances. + """ + + @classmethod + def meta_key(cls) -> str: + # place class-wise unique meta_key in `meta` to prevent duplicate fields + return cls.signature() + + @classmethod + def get_stored_field_info(cls, node: Node, field: Any, must_have: bool = False) -> Any: + if not cls.already_executed_per_node(node): + if not must_have: + return None + else: + raise RuntimeError( + f"Can't find information related with {cls.__name__} in the current node `{node}` " + f"make sure {cls.__name__} has run and marked it" + ) + + info: Dict[Any, Any] = node.meta[cls.meta_key()] + if field not in info: + if must_have: + raise KeyError(f"Invalid query field {field} for {cls.__name__}, valid fields are {list(info.keys())}") + return None + + return info[field] + + @classmethod + def already_executed_per_node(cls, node: Node) -> bool: + return cls.meta_key() in node.meta + + def place_marker_per_node(self, node: Node, info: Dict[Any, Any]) -> None: + if self.already_executed_per_node(node): + raise RuntimeError( + f"Node {node} has already been marked by the current pass, check if " + "the current pass has already been executed in the pipeline" + ) + + node.meta[self.meta_key()] = info + + def clear_marker_per_node(self, node: Node) -> None: + key = self.meta_key() + if key in node.meta: + node.meta.pop(key) + + def clean_all(self, graph_module: GraphModule) -> None: + g: Graph = graph_module.graph + for node in g.nodes: + self.clear_marker_per_node(node) + + +class ParallelLayerAnnotatePass(AnalyzeBase): + """ + A pass which tries to automatically identify 
parallel layers in the graph. Note that for simplicity + we only consider classical ways of parallelizing layers in transformers architecture for now, we are not + solving an optimization problem which tries to give a best solution of parallelizing any model under + memory/hardware constraints. + + For `nn.Embedding` layers, we parallelize them on the vocabulary dim by default, because they are often tied + to the `lm_head` of the model, which is usually a `ColumnLinear`(parallelized on vocab dim). + + For `nn.Linear` layers, we parallelize them by grouping them as `upstream` nodes and `downstream` nodes, and + `upstream` nodes are marked as `ColumnLinear`, `downstream` nodes are marked as `RowLinear`. + + Typical examples in transformer models: + + Attention Bert-style MLP Llama-style MLP + __________________________________________________________________________ + Linear Linear Linear Linear + \\ / | \\ --> upstream + Matmul Linear Activation Activation Linear + __________________________________________________________________________ + \\ / | \\ / + \\ / ___________ \\ / + Matmul / Linear \ Mul + | / \ | + _______________________________/ \___________________________ + Linear Linear --> downstream + + Note that there are some patterns that can not be clearly marked, like this one: + + Linear + | \\ + | Linear <-- which label should we mark for the intermediate linear, `upstream` or `downstream` + | / + Add + | + Linear + + For patterns like this we will be conservative and raise errors directly because we don't know how to parallelize + it. Another concern is about the correctness, it's possible that we might end up with a wrong parallelization solution + even if the pattern itself is clear, but for now we are mainly targeting on transformer models and the current solution + should work fairly well. 
def try_form_parallel_linear_groups(self, linear: Node) -> None:
    """
    We try to form linears by forming closures in a greedy way, we start with an unmarked linear node, and traverses down
    recursively to find all the potential `downstream` linears, note that once we have reached a linear, the recursion stops.
    And the newly found `downstream` linears are used as new seeds to traverse upwards to find all the potential `upstream`
    linears, the process goes on until number of linears on both sides converges.

    Args:
        linear (Node): the first linear node used as `upstream` node seed to form closure.

    Raises:
        RuntimeError:
            raises runtime error when the pattern itself is not clear, there are no clear boundaries that can be drawn.
    """
    upstream_nodes, downstream_nodes = {linear}, set()

    seeds, next_seeds = [(linear, "down")], []

    def traverse(start: Node, cur: Node, direction: str, visited: set) -> None:
        # A node reachable through several paths (e.g. both operands of an `add`) only needs
        # to be expanded once per seed; without this check the walk is repeated for every path.
        if cur in visited:
            return
        visited.add(cur)

        if is_linear(cur) and cur is not start:
            if direction == "up" and cur not in upstream_nodes:
                upstream_nodes.add(cur)
                next_seeds.append((cur, "down"))
            elif direction == "down" and cur not in downstream_nodes:
                downstream_nodes.add(cur)
                next_seeds.append((cur, "up"))
            # a linear is a boundary of the closure, never walk through it
            return

        next_nodes = cur.all_input_nodes if direction == "up" else cur.users
        for node in next_nodes:
            # we should ignore shape-related dependencies
            if is_shape_generator(node):
                continue
            traverse(start, node, direction, visited)

    while seeds:
        # `traverse` appends to whatever list `next_seeds` is currently bound to
        next_seeds = []
        for node, direction in seeds:
            traverse(node, node, direction, set())
        seeds = next_seeds

    if any(self.already_executed_per_node(node) for node in (upstream_nodes | downstream_nodes)) or (
        upstream_nodes & downstream_nodes
    ):
        raise RuntimeError(
            "Failed to automatically group and parallelize ops in graph in greedy way: "
            "no clear boudaries between `upstream` and `downstream` ops."
        )

    for node in upstream_nodes:
        # without any `downstream` linear nothing reduces the sharded output, so it has to be gathered
        self.place_marker_per_node(node, {"axis": "column", "gather_output": not downstream_nodes})

    for node in downstream_nodes:
        self.place_marker_per_node(node, {"axis": "row", "input_is_parallel": True})
def propagate_transpose(self, node: Node, parallel_axis: int) -> bool:
    """
    Mark `node` with the parallel axis resulting from a `transpose` of a tensor parallelized on `parallel_axis`.

    Args:
        node (Node): the transpose node, `node.meta["example_value"]` must expose `dim()`.
        parallel_axis (int): axis being parallelized in the input of the transpose.

    Returns:
        bool: `True` if the transpose moved the parallel axis and a marker was placed, `False` if the
        parallel axis is not one of the swapped axes (nothing is marked in that case).

    Raises:
        ValueError: if the two swapped dims can not be recovered from the args/kwargs of the node.
    """
    dims = node.meta["example_value"].dim()

    # `args[0]` is the tensor itself. Each dim may independently be given by keyword or by
    # position, e.g. `x.transpose(1, dim1=2)`, so they are resolved one by one.
    if "dim0" in node.kwargs:
        dim0 = node.kwargs["dim0"]
    elif len(node.args) >= 2:
        dim0 = node.args[1]
    else:
        raise ValueError(f"Can't find `dim0` in the arguments of transpose node `{node}`")

    if "dim1" in node.kwargs:
        dim1 = node.kwargs["dim1"]
    elif len(node.args) >= 3:
        dim1 = node.args[2]
    else:
        raise ValueError(f"Can't find `dim1` in the arguments of transpose node `{node}`")

    # normalize negative dims
    dim0 = (dim0 + dims) % dims
    dim1 = (dim1 + dims) % dims

    if dim0 == parallel_axis:
        self.place_marker_per_node(node, {"parallel_axis": dim1})
        return True
    elif dim1 == parallel_axis:
        self.place_marker_per_node(node, {"parallel_axis": dim0})
        return True
    return False
ParallelLayerAnnotatePass.already_executed_per_node(node): + # start propagating at ColumnLinear, marking the beginning of parallelized region + axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis", must_have=True) + gather_output = ParallelLayerAnnotatePass.get_stored_field_info(node, field="gather_output") + if axis == "column" and not gather_output: + self.place_marker_per_node(node, {"parallel_axis": 2}) + # stop propagating at RowLinear, concluding the ending of parallelized region + else: + continue + else: + already_marked_args, parallel_axis = [], None + for arg in node.all_input_nodes: + if not self.already_executed_per_node(arg): + continue + if parallel_axis is None: + parallel_axis = self.get_stored_field_info(arg, field="parallel_axis", must_have=True) + else: + assert parallel_axis == self.get_stored_field_info( + arg, field="parallel_axis", must_have=True + ), "`parallel_axis` should be equal for all arguments in any related ops" + already_marked_args.append(arg) + + if not already_marked_args: + continue + + marked = False + if is_transpose(node): + marked = self.propagate_transpose(node, parallel_axis) + elif is_permute(node): + marked = self.propagate_permute(node, parallel_axis) + + # fall back + if not marked: + self.place_marker_per_node(node, {"parallel_axis": parallel_axis}) + return graph_module + + +class ParallelLayerReplacePass(PassBase): + """ + A pass which modifies graph according to information provided by previous analytical passes, + in general it does two things for now: + 1. replaces linears and embedding layers with their parallel counterparts. + 2. modifies hard-coded arguments like the number of attention heads in the graph by dividing it by parallelism level. 
@staticmethod
def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None:
    """
    Replace the `nn.Linear` called by `node` with its parallel counterpart.

    Nothing happens when `ParallelLayerAnnotatePass` stored no `axis` for the node. Otherwise a
    `ColumnParallelLinear` or `RowParallelLinear` is built (or fetched from `ctx.parallel_layer_cache`,
    keyed by the node target) and assigned in place of the original submodule.
    """
    root_module = node.graph.owning_module
    axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis")
    if axis is None:
        return

    assert axis in {"column", "row"}

    # "a.b.c" -> owner "a.b" and attribute "c"; a target without dots lives on the root module
    owner_path, dot, leaf = node.target.rpartition(".")
    if dot:
        parent_mod, field = root_module.get_submodule(owner_path), leaf
    else:
        parent_mod, field = root_module, node.target

    linear: nn.Linear = root_module.get_submodule(node.target)
    cache = ctx.parallel_layer_cache
    if node.target not in cache:
        if axis == "column":
            gather_output = ParallelLayerAnnotatePass.get_stored_field_info(
                node, field="gather_output", must_have=True
            )
            cache[node.target] = ColumnParallelLinear(ctx, linear, gather_output)
        else:
            input_is_parallel = ParallelLayerAnnotatePass.get_stored_field_info(
                node, field="input_is_parallel", must_have=True
            )
            cache[node.target] = RowParallelLinear(ctx, linear, input_is_parallel)

    setattr(parent_mod, field, cache[node.target])
@staticmethod
def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None:
    """
    Divide the hard-coded size on the parallel axis of a shape-consuming node by the tensor parallel world size.

    The target shape is read from the `size`/`shape` keyword argument, from a tuple/list given as
    second positional argument, or from the positional arguments following the tensor. The node is
    left untouched when `ParallelAxisPropagationPass` stored no `parallel_axis` for it, or when the
    entry on that axis is not a plain integer or is `-1`.
    """

    def extract_shape_from_node(node: Node) -> List[Any]:
        if "size" in node.kwargs:
            return list(node.kwargs["size"])
        elif "shape" in node.kwargs:
            return list(node.kwargs["shape"])
        elif isinstance(node.args[1], (tuple, list)):
            # `x.view((a, b))` as well as `x.view([a, b])`
            return list(node.args[1])
        else:
            # `x.view(a, b)`, every dimension is its own positional argument
            return list(node.args[1:])

    def update(node: Node, new_shape: List[Any], parallel_axis: int):
        if "size" in node.kwargs:
            node.update_kwarg("size", tuple(new_shape))
        elif "shape" in node.kwargs:
            node.update_kwarg("shape", tuple(new_shape))
        elif isinstance(node.args[1], (tuple, list)):
            node.update_arg(1, tuple(new_shape))
        else:
            # +1 skips the tensor argument; only the parallel dimension changed
            node.update_arg(parallel_axis + 1, new_shape[parallel_axis])

    parallel_axis = ParallelAxisPropagationPass.get_stored_field_info(node, field="parallel_axis")
    if parallel_axis is None:
        return

    shape = extract_shape_from_node(node)
    assert parallel_axis < len(shape)
    # dynamic sizes (other nodes) and the inferred `-1` dimension need no correction
    if not isinstance(shape[parallel_axis], int) or shape[parallel_axis] == -1:
        return
    world_size = ctx.tp_group.size()
    assert shape[parallel_axis] % world_size == 0
    shape[parallel_axis] = shape[parallel_axis] // world_size
    update(node, shape, parallel_axis)
is_linear(node): + self.handle_linear(node, ctx) + elif is_embedding(node): + self.handle_embedding(node, ctx) + # correct the attention head num in parallel setting + elif is_shape_consumer(node): + self.handle_hard_coded_axis_param(node, ctx) + return graph_module + + +class InitializeOrLoadWeightsPass(PassBase): + """ + Make weights loading/initialization a seperate pass for cleaner logic and easier extensibility. This + pass will only run once in the very first compilation step. + """ + + need_rerun_when_recompile = False + + def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + world_size = dist.get_world_size(ctx.tp_group) + tp_rank = dist.get_rank(ctx.tp_group) + + new_parameters, tied_parameters = [], {} + for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): + param_meta: ParameterMeta = getattr(param, "meta") + # skip already initialized/loaded tied parameters + if param_meta.is_tied and id(param) in tied_parameters: + new_parameters.append((name, tied_parameters[id(param)])) + continue + + shape = [ + param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim) + for dim in range(param.ndim) + ] + + if not param_meta.is_parallel and param.device == ctx.current_device: + new_param = param + else: + new_param = nn.Parameter( + torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), + requires_grad=param.requires_grad, + ) + + # load weights if possible + for source, target in sorted(param_meta.mapping.items()): + if target.source in ctx.weight_map: + from safetensors import safe_open + + with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp: + tensor_slice = fp.get_slice(target.source) + source_index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + load_index = [ + target.index if dim == param_meta.dim else slice(None, None, None) + for 
dim in range(param.ndim) + ] + + tensor = tensor_slice[load_index].contiguous() + tensor = torch.empty_like(tensor).copy_(tensor) + with torch.no_grad(): + new_param.data[source_index].copy_(tensor) + + # weights initialization + if param_meta.need_initialize: + for source, target in sorted(param_meta.mapping.items()): + if target.source in ctx.weight_map: + continue + if not param_meta.is_parallel or tp_rank == 0: + # initialize weight on master rank + weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu") + init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn + init_fn(weight) + weight = weight.to(ctx.current_device) + else: + weight = None + index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + with torch.no_grad(): + if param_meta.is_parallel: + scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) + else: + new_param.data[index].copy_(weight) + setattr(new_param, "meta", param_meta) + + if id(new_param) != id(param): + new_parameters.append((name, new_param)) + if param_meta.is_tied: + tied_parameters[id(param)] = new_param + + for name, new_param in new_parameters: + prefix_and_field = name.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = name + setattr(parent_mod, field, new_param) + + return graph_module + + +def build_parallel_pass_pipeline() -> PassPipeline: + """ + Ensemble a pass pipeline which contains the following passes: + 1. `ParallelLayerAnnotatePass` to annoate which linears are `ColumnLinear`, which are `RowLinear` + 2. `ParallelAxisPropagationPass` to propate parallel axis along the data flow + 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes + 4. 
class PassPipeline:
    """
    `PassPipeline` ensembles a list of passes and execute them one by one as provided in the list,
    it can be iterated and appended after initialization for flexibility.
    """

    def __init__(self, passes: List[PassBase] | None = None) -> None:
        """
        Args:
            passes (`List[PassBase]`, *optional*):
                passes to execute, in order. The list is stored as is (not copied). When omitted,
                the pipeline starts with its own empty list.
        """
        # `None` instead of a `[]` default: a default list would be shared by every pipeline
        # created without arguments, and `append` would leak passes between them.
        self._passes = passes if passes is not None else []

    def __iter__(
        self,
    ):
        return iter(self._passes)

    def append(self, PASS: PassBase) -> None:
        """Add `PASS` at the end of the pipeline."""
        self._passes.append(PASS)

    def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
        """
        Run every pass in order on `graph_module` and return the final graph module. If
        `config.clean_markers_after_all_passes` is set, the markers placed by the analytical
        passes (`AnalyzeBase` instances) are removed afterwards.
        """
        for PASS in self._passes:
            graph_module = PASS(graph_module=graph_module, ctx=ctx, config=config)

        if config.clean_markers_after_all_passes:
            for PASS in self._passes:
                if isinstance(PASS, AnalyzeBase):
                    PASS.clean_all(graph_module)
        return graph_module
def ensure_divisibility(numerator: int, denominator: int) -> None:
    """
    Check that `numerator` is a multiple of `denominator`.

    Raises:
        RuntimeError: if the division leaves a remainder.
    """
    remainder = numerator % denominator
    if remainder == 0:
        return

    message = (
        f"{numerator} is not divisible by {denominator}, check if the parallel dimension of weight parameters is divisible "
        "by parallelism level(world size of tensor parallel group)"
    )
    raise RuntimeError(message)
def is_shape_generator(node: Node) -> bool:
    """Tell whether `node` is a `size` method call, i.e. a node producing a shape."""
    if node.op != "call_method":
        return False
    return node.target == "size"
class MetaAwareMethodsPatcher:
    """
    A patcher class which patches `__init__` and `forward` methods on modules which will be put on meta
    devices for memory efficiency purposes during initialization.

    Note that for `__init__` method, it can be unpatched once we have finished the initialization of the
    model, however, for `forward`, we need it to constantly being patched during the whole process in case
    recompile happens and torch dynamo needs meta-aware `forward` to be able to re-capture the graph.
    """

    # (fully qualified name of the attribute to replace, replacement callable) pairs
    methods_to_patch: List[tuple] = [
        ("torch.nn.Linear.__init__", meta_init(nn.Linear.__init__)),
        ("torch.nn.Embedding.__init__", meta_init(nn.Embedding.__init__)),
        ("torch.nn.Linear.forward", meta_aware_linear_forward),
        ("torch.nn.Embedding.forward", meta_aware_embedding_forward),
    ]

    def __init__(self) -> None:
        """Resolve the owner of every entry of `methods_to_patch` and record its current attribute."""
        self.patching_specs = []
        for orig, patch_fn in self.methods_to_patch:
            module_qualified_name, attribute_name = orig.rsplit(".", maxsplit=1)
            try:
                module = importlib.import_module(module_qualified_name)
            except ModuleNotFoundError as e:
                # the owner is not a module but an attribute of one (e.g. the class `torch.nn.Linear`)
                module_qualified_name, module_attribute_name = module_qualified_name.rsplit(".", maxsplit=1)
                module = importlib.import_module(module_qualified_name)
                try:
                    module = getattr(module, module_attribute_name)
                except AttributeError:
                    raise e
            orig_fn = getattr(module, attribute_name)

            # Module, Attribute, Patchee, Patcher, Status
            self.patching_specs.append([module, attribute_name, orig_fn, patch_fn, False])

    def _patch(self, identifier: str):
        """Install the patch of every spec whose attribute name contains `identifier`."""
        for spec in self.patching_specs:
            # already patched
            if spec[-1]:
                continue
            if identifier in spec[1]:
                setattr(spec[0], spec[1], spec[3])
                spec[-1] = True

    def _unpatch(self, identifier: str):
        """Restore the recorded attribute of every patched spec whose attribute name contains `identifier`."""
        for spec in self.patching_specs:
            # not patched, nothing to restore
            if not spec[-1]:
                continue
            if identifier in spec[1]:
                setattr(spec[0], spec[1], spec[2])
                spec[-1] = False

    def patch_meta_init(
        self,
    ):
        self._patch("init")

    def patch_meta_forward(
        self,
    ):
        self._patch("forward")

    def unpatch_meta_init(
        self,
    ):
        self._unpatch("init")

    def unpatch_meta_forward(
        self,
    ):
        self._unpatch("forward")

    def __enter__(
        self,
    ):
        self.patch_meta_init()
        self.patch_meta_forward()
        # returned so that `with MetaAwareMethodsPatcher() as patcher:` binds the patcher
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # only `__init__` is restored here, `forward` stays patched (see class docstring);
        # call `unpatch_meta_forward` explicitly to restore it
        self.unpatch_meta_init()
@torch.no_grad()
def move_model_to_device(model: nn.Module, device: Union[torch.device, str]):
    """
    Move everything except tensors on meta devices on current device
    this function should be called before `intialize_parameter_meta`

    Args:
        model (`nn.Module`): module whose parameters and buffers are moved in place.
        device (`Union[torch.device, str]`): target device.
    """
    # NOTE(review): `named_parameters()` yields a shared (tied) parameter under one name only, so the
    # other names keep pointing at the old tensor - confirm tied weights are re-tied by the caller.
    # The list is materialized because the loop re-assigns attributes of the modules being iterated.
    named_tensors = list(chain(model.named_parameters(), model.named_buffers()))
    meta_device = torch.device("meta")
    for name, tensor in named_tensors:
        if tensor.device == meta_device:
            continue
        parent_name, _, attr_name = name.rpartition(".")
        parent_mod = model.get_submodule(parent_name) if parent_name else model
        new_tensor = tensor.to(device)
        if isinstance(tensor, nn.Parameter):
            # `nn.Parameter` defaults to `requires_grad=True`: carry the flag over so frozen
            # parameters stay frozen (and non-floating-point ones can be wrapped at all)
            new_tensor = nn.Parameter(new_tensor, requires_grad=tensor.requires_grad)
        setattr(parent_mod, attr_name, new_tensor)
local_files_only: bool = False, + skip_download_weights: bool = False, +) -> str: + """Download model weights, index and config files from Hugging Face Hub. + + Args: + model_name_or_path (`str`): The model name or path. + cache_dir (`Optional[str]`): The cache directory to store the model + weights. If None, will use HF defaults. + revision (`Optional[str]`, defaults to `None`): The revision of the model. + local_files_only(`bool`): Should only use local files if True. + skip_download_weights (`bool`, defaults to `False`): Whether to skip downloading weights to disk. + + Returns: + str: The path to the downloaded files. + """ + import huggingface_hub.constants + from huggingface_hub import HfFileSystem, snapshot_download + from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + + allow_patterns = ["*.safetensors", "*.bin"] + + if not skip_download_weights and not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + if skip_download_weights: + # only need to download config file + allow_patterns = [CONFIG_NAME] + elif allow_patterns[0] == "*.safetensors": + allow_patterns = allow_patterns + [CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME] + else: + allow_patterns = allow_patterns + [CONFIG_NAME, WEIGHTS_INDEX_NAME] + + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
# copied from optimum.neuron.utils.misc.py
def _original_filename_to_safetensors_filename(filename: str) -> str:
    """
    Transforms the filename for any kind of checkpoint to a safetensors equivalent.

    The optional shard suffix (`-XXXXX-of-YYYYY`) of `filename` is kept and appended to the
    stem of `SAFE_WEIGHTS_NAME`.

    Raises:
        ValueError: if `filename` has no extension or does not look like a checkpoint filename.
    """
    from transformers.utils import SAFE_WEIGHTS_NAME

    _, dot, extension = filename.rpartition(".")
    if not dot:
        # same error as for any other unsupported name, instead of an opaque unpacking failure
        raise ValueError(f"Could not convert {filename} to a safetensor filename.")
    # the extension is matched literally, whatever characters it contains
    pattern = rf"\w+(-[0-9]*-of-[0-9]*)?\.{re.escape(extension)}"
    match_ = re.match(pattern, filename)
    if not match_:
        raise ValueError(f"Could not convert {filename} to a safetensor filename.")
    group_1 = match_.group(1)
    index_out_of_total_str = group_1 if group_1 is not None else ""
    safetensor_filename, safetensor_extension = SAFE_WEIGHTS_NAME.rsplit(".", maxsplit=1)
    return f"{safetensor_filename}{index_out_of_total_str}.{safetensor_extension}"
save_file(checkpoint, output_file_path) + keys = [key for key, value in weight_map.items() if value == weight_file] + for key in keys: + weight_map[key] = output_file_path + + +def try_collect_weight_map(model_name_or_path: str, cache_dir: Optional[str], folder_path: str) -> Dict[str, str]: + """Try collecting weight mapping information from the model folder.""" + from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + + weight_map = {} + use_safetensors, weight_patterns = False, ["*safetensors", "*.bin"] + for pattern in weight_patterns: + if len(glob.glob(os.path.join(folder_path, pattern))) > 0: + use_safetensors = pattern == "*.safetensors" + break + index_path = os.path.join(folder_path, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME) + weight_files = glob.glob(os.path.join(folder_path, "*.safetensors" if use_safetensors else "*.bin")) + + if os.path.isfile(index_path): + with open(index_path) as f: + index_dict = json.load(f) + weight_map = {k: os.path.join(folder_path, v) for k, v in index_dict["weight_map"].items()} + + # convert bin files to safetensors, modify `weight_map` meanwhile + if not use_safetensors: + convert_bin_to_safetensors(model_name_or_path, cache_dir, weight_files, weight_map) + + # last resort: try directly construct weight_map from weight files + if not weight_map: + from safetensors import safe_open + + # should have safetensors on disk in any case + weight_files = glob.glob(os.path.join(folder_path, "*.safetensors")) + for weight_file in weight_files: + with safe_open(filename=weight_file, framework="pt") as f: + for key in f.keys(): + weight_map[key] = weight_file + return weight_map diff --git a/tests/fx/parallelization/dist_utils.py b/tests/fx/parallelization/dist_utils.py new file mode 100644 index 00000000000..ef35fb33d06 --- /dev/null +++ b/tests/fx/parallelization/dist_utils.py @@ -0,0 +1,77 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. 
SEED = 42
NUM_AVAILABLE_DEVICES = torch.cuda.device_count()


def dist_init(
    rank: int,
    world_size: int,
    backend: str = "nccl",
    master_addr: str = "127.0.0.1",
    master_port: str = "29501",
):
    """Initialize the default process group for this worker and select its CUDA device.

    Args:
        rank (`int`): Rank of the current process; also used as the CUDA device index.
        world_size (`int`): Total number of processes.
        backend (`str`, defaults to `"nccl"`): Backend passed to `init_process_group`.
        master_addr (`str`, defaults to `"127.0.0.1"`): Value exported as `MASTER_ADDR`.
        master_port (`str`, defaults to `"29501"`): Value exported as `MASTER_PORT`.
    """
    # The `env://` init method reads the rendezvous information from these variables.
    os.environ.update(
        {
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": master_port,
        }
    )

    dist.init_process_group(backend=backend, init_method="env://", world_size=world_size, rank=rank)

    torch.cuda.set_device(rank)


def runner(rank: int, fn: Callable, deterministic: bool, *args, **kwargs):
    """Call `fn(rank, *args, **kwargs)`, seeding with `SEED` first when `deterministic` is set."""
    if deterministic:
        set_seed(SEED)
    fn(rank, *args, **kwargs)


def spawn(world_size: int, fn: Callable, *args, deterministic: bool = False):
    """Start `world_size` processes running `fn` through `runner` and wait for them.

    Each process ends up calling `fn(rank, world_size, *args)`.
    """
    worker_args = (fn, deterministic, world_size) + tuple(args)
    mp.spawn(fn=runner, args=worker_args, nprocs=world_size, join=True)


def tearDown(group: Optional[dist.ProcessGroup] = None):
    """Destroy `group` via `dist.destroy_process_group`; `None` is forwarded as is."""
    dist.destroy_process_group(group)


def gather_at_main_process(
    tensor: torch.Tensor, group: dist.ProcessGroup, rank: int, world_size: int
) -> List[torch.Tensor]:
    """Gather `tensor` from every rank of `group` onto rank 0.

    Returns:
        With a single process, `[tensor]` without any communication. Otherwise the list of
        gathered tensors on rank 0 and `None` on every other rank.
    """
    if world_size == 1:
        return [tensor]

    tensor = tensor.contiguous()
    gathered = None
    if rank == 0:
        # Only the destination rank provides receive buffers.
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        gathered[rank] = tensor
    dist.gather(tensor=tensor, gather_list=gathered, dst=0, group=group)
    return gathered


DUMMY_MODEL_KWARGS = {
    "num_hidden_layers": 2,
    "use_cache": False,
    "output_attentions": False,
    "output_hidden_states": False,
    "tie_word_embeddings": True,
}

# Every model under test shares the same keyword arguments.
DUMMY_MODELS_TO_TEST = tuple(
    (model_id, DUMMY_MODEL_KWARGS) for model_id in ("saibo/llama-1B", "PhoenixJie/dummy-mistral")
)


def is_gpu_available():
    """Return whether torch reports CUDA as available."""
    return torch.cuda.is_available()


def is_torch_compile_available():
    """Return whether the installed torch version is at least 2.3.0."""
    installed = version.parse(torch.__version__)
    required = version.parse("2.3.0")
    return installed >= required
def prepare_dummy_inputs(
    model_config: PretrainedConfig,
    batch_size: int = 1,
    seq_len: int = 10,
    device: Union[str, torch.device] = "cuda",
):
    """Build random model inputs of shape `(batch_size, seq_len)` on `device`.

    Returns:
        A dict with random `input_ids` drawn from `[1, model_config.vocab_size)`, an all-ones
        `attention_mask` and `position_ids` counting from 0 to `seq_len - 1` for every row.
    """
    return {
        "input_ids": torch.randint(low=1, high=model_config.vocab_size, size=(batch_size, seq_len), device=device),
        "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64, device=device),
        "position_ids": torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch_size, -1),
    }


def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]):
    """Worker: check that every rank produces the same logits for the same inputs."""
    # initialize default group
    dist_init(rank, world_size)
    tp_group = dist.new_group()

    # prepare config and context
    device = torch.device(type="cuda", index=torch.cuda.current_device())
    ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
    inputs = prepare_dummy_inputs(model.config)
    logits = model(**inputs)[0]
    tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size)

    # check results at main worker process
    if rank == 0:
        assert len(tensors) == world_size
        # comparing each rank with its predecessor covers all ranks
        for i in range(1, world_size):
            torch.testing.assert_close(tensors[i - 1].cpu(), tensors[i].cpu(), rtol=1e-4, atol=1e-4)

    dist.barrier(tp_group)
    # Destroy every group, the default one included, like the other workers do.
    tearDown()


def run_test_parameters_persist_bewteen_recompile(
    rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]
):
    """Worker: check that recompiling for new input shapes keeps the same parameter objects."""
    # initialize default group
    dist_init(rank, world_size)
    tp_group = dist.new_group()

    # prepare config and context
    device = torch.device(type="cuda", index=torch.cuda.current_device())
    ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
    inputs = prepare_dummy_inputs(model.config)

    # different shape to trigger recompile
    another_inputs = prepare_dummy_inputs(model.config, seq_len=11)
    yet_another_inputs = prepare_dummy_inputs(model.config, batch_size=2, seq_len=12)

    model(**inputs)
    parameter_ids = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()}

    model(**another_inputs)
    # check second compilation has been triggered
    assert ctx.compile_times == 2
    parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()}
    assert parameter_ids == parameter_ids_after_recompile

    model(**yet_another_inputs)
    # check third compilation has been triggered
    assert ctx.compile_times == 3
    parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()}
    assert parameter_ids == parameter_ids_after_recompile

    dist.barrier(tp_group)
    # Destroy every group, the default one included, like the other workers do.
    tearDown()


def run_test_parallel_results_matches_non_parallel(
    rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]
):
    """Worker: check that logits from a single-rank group match those from the full group."""
    # initialize default group
    dist_init(rank, world_size)

    # `dist.new_group` has to be entered by every process of the default group, with the
    # groups created in the same order everywhere, so each rank creates all the single-rank
    # groups and keeps the one it belongs to.
    tp_group = None
    for member in range(world_size):
        single_rank_group = dist.new_group(ranks=[member])
        if member == rank:
            tp_group = single_rank_group

    # prepare config and context
    device = torch.device(type="cuda", index=torch.cuda.current_device())
    ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
    inputs = prepare_dummy_inputs(model.config)

    set_seed(SEED)
    logits = model(**inputs)[0]

    torch._dynamo.reset()
    del model

    # same model and inputs again, this time over a group spanning all ranks
    tp_group = dist.new_group()
    set_seed(SEED)
    ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
    parallel_logits = model(**inputs)[0]

    torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4)

    dist.barrier(tp_group)
    tearDown()


def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, model_kwargs: Dict[str, Any]):
    """Worker: check that the embedding and the last column-parallel linear share one weight."""
    dist_init(rank, world_size)
    tp_group = dist.new_group()

    # prepare config and context
    device = torch.device(type="cuda", index=torch.cuda.current_device())
    ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)

    inputs = prepare_dummy_inputs(model.config)
    model(**inputs)

    embedding_weight, lm_head_weight = None, None
    graph_module = ctx.last_optimized_graph_module
    stable_topological_sort(graph_module.graph)
    # first `VocabParallelEmbedding` in graph order
    for node in graph_module.graph.nodes:
        if node.op == "call_module":
            mod = graph_module.get_submodule(node.target)
            if isinstance(mod, VocabParallelEmbedding):
                embedding_weight = mod.weight
                break
    # last `ColumnParallelLinear` in graph order
    for node in reversed(graph_module.graph.nodes):
        if node.op == "call_module":
            mod = graph_module.get_submodule(node.target)
            if isinstance(mod, ColumnParallelLinear):
                lm_head_weight = mod.weight
                break

    # Separate assertions so that a failure tells which condition was violated.
    assert embedding_weight is not None, "no VocabParallelEmbedding found in the optimized graph"
    assert lm_head_weight is not None, "no ColumnParallelLinear found in the optimized graph"
    assert embedding_weight is lm_head_weight
    assert hasattr(embedding_weight, "meta") and embedding_weight.meta.is_tied

    dist.barrier(tp_group)
    tearDown()
@parameterized.expand(DUMMY_MODELS_TO_TEST)
@unittest.skipIf(
    not (is_gpu_available() and is_torch_compile_available()),
    "requires gpu and torch version >= 2.3.0 to run",
)
def test_all_rank_results_match(
    model_id,
    model_kwargs,
):
    # Run the worker for every listed world size that fits on the available devices.
    runnable_world_sizes = [size for size in (1, 2, 4, 8) if size <= NUM_AVAILABLE_DEVICES]
    for world_size in runnable_world_sizes:
        spawn(world_size, run_test_all_rank_results_match, model_id, model_kwargs, deterministic=True)


@parameterized.expand(DUMMY_MODELS_TO_TEST)
@unittest.skipIf(
    not (is_gpu_available() and is_torch_compile_available()),
    "requires gpu and torch version >= 2.3.0 to run",
)
def test_parameters_persist_bewteen_recompile(
    model_id,
    model_kwargs,
):
    # Run the worker for every listed world size that fits on the available devices.
    runnable_world_sizes = [size for size in (1, 2) if size <= NUM_AVAILABLE_DEVICES]
    for world_size in runnable_world_sizes:
        spawn(world_size, run_test_parameters_persist_bewteen_recompile, model_id, model_kwargs, deterministic=False)


@parameterized.expand(DUMMY_MODELS_TO_TEST)
@unittest.skipIf(
    not (is_gpu_available() and is_torch_compile_available()) or NUM_AVAILABLE_DEVICES < 2,
    "requires more than one gpu and torch version >= 2.3.0 to run",
)
def test_parallel_results_matches_non_parallel(
    model_id,
    model_kwargs,
):
    # world_size == 2 is enough
    world_size = 2
    spawn(world_size, run_test_parallel_results_matches_non_parallel, model_id, model_kwargs, deterministic=True)


@parameterized.expand(DUMMY_MODELS_TO_TEST)
@unittest.skipIf(
    not (is_gpu_available() and is_torch_compile_available()),
    "requires gpu and torch version >= 2.3.0 to run",
)
def test_tie_word_embeddings(
    model_id,
    model_kwargs,
):
    # Run the worker for every listed world size that fits on the available devices.
    runnable_world_sizes = [size for size in (1, 2) if size <= NUM_AVAILABLE_DEVICES]
    for world_size in runnable_world_sizes:
        spawn(world_size, run_test_tie_word_embeddings, model_id, model_kwargs, deterministic=False)