From ed433d11b78b9c7bdd7e5dfccc786cbf8a5701de Mon Sep 17 00:00:00 2001
From: xrsrke
Date: Mon, 27 Nov 2023 09:16:01 +0700
Subject: [PATCH] [Refactor] Apply pre-commit to model partitioner

---
 pipegoose/nn/pipeline_parallel/partitioner.py | 95 ++++++-------------
 .../nn/pipeline_parallel/test_partitioner.py  | 27 ++----
 2 files changed, 40 insertions(+), 82 deletions(-)

diff --git a/pipegoose/nn/pipeline_parallel/partitioner.py b/pipegoose/nn/pipeline_parallel/partitioner.py
index 13fd8c0..391c0f2 100644
--- a/pipegoose/nn/pipeline_parallel/partitioner.py
+++ b/pipegoose/nn/pipeline_parallel/partitioner.py
@@ -1,16 +1,15 @@
+import re
 from abc import ABC, abstractclassmethod
+from collections import defaultdict
 from enum import Enum, auto
-from typing import List
-from torch import nn
+from typing import Dict, List, Optional
+
 import torch
-from typing import Dict
-from collections import defaultdict
-import re
-from pipegoose.distributed.parallel_context import ParallelContext
-from pipegoose.distributed.parallel_mode import ParallelMode
+from torch import nn
 from transformers.utils.fx import symbolic_trace
-from typing import Optional
+
+from pipegoose.distributed.parallel_context import ParallelContext
+from pipegoose.distributed.parallel_mode import ParallelMode
 
 
 class PartitionPolicy(Enum):
@@ -31,7 +30,7 @@ def __init__(self, model: nn.Module, parallel_context: ParallelContext):
         self.parallel_context = parallel_context
 
     def _find_transformer_block_prefix(self):
-        # Since there is no standarized way to find the name of the transformer block, we have to find the name first
+        # Since there is no standardized way to find the name of the transformer block, we have to find the name first
        transformer_attr_name = self.model.base_model_prefix
         transformer_block_name = None
 
@@ -42,11 +41,8 @@ def _find_transformer_block_prefix(self):
                 transformer_block_name = attr_name
                 break
 
-        transformer_block_prefix = (
-            rf"{transformer_attr_name}_{transformer_block_name}_(\d+)"
-        )
+        transformer_block_prefix = rf"{transformer_attr_name}_{transformer_block_name}_(\d+)"
 
-        print(f"transformer_block_prefix: {transformer_block_prefix}")
         return transformer_block_prefix
 
     def _get_transformer_block_id(self, node_name: str) -> Optional[str]:
@@ -54,9 +50,7 @@ def _get_transformer_block_id(self, node_name: str) -> Optional[str]:
         match = re.match(pattern, node_name)
         return match.group() if match else None
 
-    def _split_nodes(
-        self, traced_graph_module: torch.fx.GraphModule, shard_count: int = 3
-    ) -> Dict:
+    def _split_nodes(self, traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
         """Utility used to trace a graph and identify shard cutpoints."""
 
         param_count: Dict[str, int] = {}
         exclude_param_count = 0
 
         # Find the total number of params in the model and
         # the number of params per shard we are aiming for.
-        # Note: we need to iterate over named_parameters AND named_modules because
+        # NOTE: we need to iterate over named_parameters AND named_modules because
         # sometimes the parameters of a module is split into weight and bia in the
         # traced graph and sometimes not.
         # Example:
@@ -83,19 +77,18 @@ def _split_nodes(
             # exclude embedding layers
             if isinstance(module, nn.Embedding):
                 name = name.replace(".", "_")
-                exclude_param_count += sum([x.numel() for x in module.parameters()])
+                exclude_param_count += sum(x.numel() for x in module.parameters())
                 param_count[name] = 0
                 continue
 
             name = name.replace(".", "_")
-            param_count[name] = sum([x.numel() for x in module.parameters()])
+            param_count[name] = sum(x.numel() for x in module.parameters())
 
         # Note that we use param_count[""] as total parameters which does not include the
         # lm_head. We want to exclude the lm_head and the embeddings here because they
         # usually cause high skew which does not lead to equal partitioning of the
         # transformer blocks.
         per_shard_param = (param_count[""] - exclude_param_count) // shard_count
-        original_per_shard_param = per_shard_param
         node_name_to_shard_id: Dict[str, int] = {}
         shard_id = 0
         shard_id_to_param_count = [0 for _ in range(shard_count)]
@@ -123,8 +116,7 @@ def _split_nodes(
 
             if (
                 shard_id_to_param_count[shard_id] > 0
-                and (shard_id_to_param_count[shard_id] + current_param_count)
-                >= per_shard_param
+                and (shard_id_to_param_count[shard_id] + current_param_count) >= per_shard_param
                 and (shard_id + 1) < shard_count
                 and transformer_block_ended
             ):
@@ -158,9 +150,7 @@ def split(self, input_names: List[str]) -> List[nn.Module]:
 
         symbolic_traced_module = symbolic_trace(model, input_names=input_names)
 
-        node_name_to_shard_id, output_from_shard = self._split_nodes(
-            symbolic_traced_module, n_partitions
-        )
+        node_name_to_shard_id, output_from_shard = self._split_nodes(symbolic_traced_module, n_partitions)
 
         nodes_per_shard = defaultdict(dict)
 
@@ -169,43 +159,24 @@ def split(self, input_names: List[str]) -> List[nn.Module]:
         for node in symbolic_traced_module.graph.nodes:
             # If the current node is in the next shard, we insert an output node.
             # A new graph is created and a placeholder is added for the next shard.
-            if (
-                node.name in node_name_to_shard_id
-                and prev_shard_id < node_name_to_shard_id[node.name]
-            ):
+            if node.name in node_name_to_shard_id and prev_shard_id < node_name_to_shard_id[node.name]:
                 assert prev_node, "prev_node cannot be None"
 
                 # generate output node for the past graph/shard
                 with new_graph.inserting_after(prev_node):
                     outputs = output_from_shard[prev_shard_id]
-                    graph_output_names = [prev_node.name] + [
-                        i for i in outputs if i != prev_node.name
-                    ]
-
-                    if isinstance(
-                        nodes_per_shard[prev_shard_id][prev_node.name], tuple
-                    ):
-                        graph_outputs = nodes_per_shard[prev_shard_id][
-                            prev_node.name
-                        ] + tuple(
-                            [
-                                nodes_per_shard[prev_shard_id][i]
-                                for i in outputs
-                                if i != prev_node.name
-                            ]
+                    graph_output_names = [prev_node.name] + [i for i in outputs if i != prev_node.name]
+
+                    if isinstance(nodes_per_shard[prev_shard_id][prev_node.name], tuple):
+                        graph_outputs = nodes_per_shard[prev_shard_id][prev_node.name] + tuple(
+                            nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name
                         )
                     else:
                         graph_outputs = tuple(
                             [nodes_per_shard[prev_shard_id][prev_node.name]]
-                            + [
-                                nodes_per_shard[prev_shard_id][i]
-                                for i in outputs
-                                if i != prev_node.name
-                            ]
+                            + [nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name]
                         )
 
-                    new_graph.create_node(
-                        op="output", target="output", args=(graph_outputs,)
-                    )
+                    new_graph.create_node(op="output", target="output", args=(graph_outputs,))
 
                 # generate new graph/shard and its input nodes (i.e., the output from the previous graph/shard)
                 num_graphs += 1
@@ -213,12 +184,8 @@ def split(self, input_names: List[str]) -> List[nn.Module]:
                 new_graph = torch.fx.Graph()
                 # generate placeholder nodes in the new graph/shard which matches the output nodes of the previous graph/shard
                 for new_graph_input_name in graph_output_names:
-                    graph_input_node = new_graph.create_node(
-                        "placeholder", new_graph_input_name
-                    )
-                    nodes_per_shard[node_name_to_shard_id[node.name]][
-                        new_graph_input_name
-                    ] = graph_input_node
+                    graph_input_node = new_graph.create_node("placeholder", new_graph_input_name)
+                    nodes_per_shard[node_name_to_shard_id[node.name]][new_graph_input_name] = graph_input_node
 
             if node.op in [
                 "placeholder",
                 "get_attr",
                 "call_function",
                 "call_method",
                 "call_module",
             ]:
                 # Copy the nodes from the existing graph to the new graph.
                 current_shard_id = node_name_to_shard_id[node.name]
-                new_node = new_graph.node_copy(
-                    node, lambda x: nodes_per_shard[current_shard_id][x.name]
-                )
+                new_node = new_graph.node_copy(node, lambda x: nodes_per_shard[current_shard_id][x.name])
                 nodes_per_shard[current_shard_id][node.name] = new_node
             elif node.op == "output":
                 # If this is the last node, we should add an output
@@ -245,6 +210,10 @@ def split(self, input_names: List[str]) -> List[nn.Module]:
             prev_node = new_node
             prev_shard_id = node_name_to_shard_id[node.name]
 
+        for shard in module_list:
+            shard.graph.lint()
+            shard.recompile()
+
         return module_list
 
 
@@ -257,9 +226,7 @@ def _get_partitioner(policy: PartitionPolicy) -> BasePartitioner:
     return policy_to_partitioner[policy]
 
 
-def get_model_partition(
-    module: nn.Module, policy: PartitionPolicy, parallel_context: ParallelContext
-) -> nn.Module:
+def get_model_partition(module: nn.Module, policy: PartitionPolicy, parallel_context: ParallelContext) -> nn.Module:
     """Get the corresponding partition of the current process."""
     partitioner = _get_partitioner(policy)
     partitions = partitioner(module, parallel_context).split()
diff --git a/tests/nn/pipeline_parallel/test_partitioner.py b/tests/nn/pipeline_parallel/test_partitioner.py
index 081672e..9d2c22d 100644
--- a/tests/nn/pipeline_parallel/test_partitioner.py
+++ b/tests/nn/pipeline_parallel/test_partitioner.py
@@ -1,33 +1,28 @@
 import pytest
 import torch
 from transformers import (
-    AutoTokenizer,
     AutoModelForCausalLM,
+    AutoTokenizer,
     BloomConfig,
     BloomForCausalLM,
 )
-from pipegoose.nn.pipeline_parallel.partitioner import (  # PartitionPolicy,; get_model_partition,
-    UniformPartitioner,
-)
+
+from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner
 from pipegoose.testing.utils import init_parallel_context, spawn
 
 
 def get_gpt2_and_tokenizer():
-    return AutoModelForCausalLM.from_pretrained("gpt2"), AutoTokenizer.from_pretrained(
-        "gpt2"
-    )
+    return AutoModelForCausalLM.from_pretrained("gpt2"), AutoTokenizer.from_pretrained("gpt2")
 
 
 def get_bloom_560m_and_tokenizer():
-    return AutoModelForCausalLM.from_pretrained(
+    return AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m"), AutoTokenizer.from_pretrained(
         "bigscience/bloom-560m"
-    ), AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+    )
 
 
 def get_bloom_and_tokenizer_with_6_layers():
-    return BloomForCausalLM(BloomConfig(n_layer=6)), AutoTokenizer.from_pretrained(
-        "bigscience/bloom-560m"
-    )
+    return BloomForCausalLM(BloomConfig(n_layer=6)), AutoTokenizer.from_pretrained("bigscience/bloom-560m")
 
 
 # TODO: Also add a function for a generic nn.Transformer model
@@ -81,13 +76,9 @@ def run_model_partitioner(
     partitioned_model_result = inputs["input_ids"]
     for partition_id in range(pipeline_parallel_size):
         if type(partitioned_model_result) in (list, tuple):
-            partitioned_model_result = partitioned_model[partition_id](
-                *partitioned_model_result
-            )
+            partitioned_model_result = partitioned_model[partition_id](*partitioned_model_result)
         else:
-            partitioned_model_result = partitioned_model[partition_id](
-                partitioned_model_result
-            )
+            partitioned_model_result = partitioned_model[partition_id](partitioned_model_result)
 
     assert torch.allclose(gt_logits, partitioned_model_result), "Results are not close"
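
Usage sketch: the snippet below mirrors the test touched by this patch to show how the refactored partitioner is driven end to end in a worker process: trace the model, split it into one shard per pipeline stage, then chain the shards and compare against the full model's logits. It is a minimal sketch, not part of the patch; the exact argument lists of init_parallel_context and spawn from pipegoose.testing.utils, as well as the example prompt string, are assumptions, while UniformPartitioner(model, parallel_context).split(input_names=...) and the shard-chaining loop follow the code shown in the diff.

import torch
from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM

from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner
from pipegoose.testing.utils import init_parallel_context, spawn


def run_uniform_partitioner(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
    # Assumed signature, mirroring how the test suite initializes its context.
    parallel_context = init_parallel_context(
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )

    model = BloomForCausalLM(BloomConfig(n_layer=6))
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    inputs = tokenizer("an example prompt", return_tensors="pt")

    # Reference logits from the unpartitioned model.
    gt_logits = model(input_ids=inputs["input_ids"]).logits

    # One nn.Module per pipeline stage; `input_names` tells transformers'
    # symbolic tracer which forward arguments to expose as placeholders.
    shards = UniformPartitioner(model, parallel_context).split(input_names=["input_ids"])
    assert len(shards) == pipeline_parallel_size

    # Chain the shards: a shard may return a tuple when a later shard
    # consumes more than one of its outputs (see output_from_shard in the diff).
    out = inputs["input_ids"]
    for shard in shards:
        out = shard(*out) if isinstance(out, (list, tuple)) else shard(out)

    assert torch.allclose(gt_logits, out), "partitioned forward pass diverged from the full model"


if __name__ == "__main__":
    # Assumed spawn signature; the test suite launches its workers the same way.
    spawn(
        run_uniform_partitioner,
        world_size=2,
        tensor_parallel_size=1,
        pipeline_parallel_size=2,
        data_parallel_size=1,
    )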