Merge pull request #46 from xrsrke/feature/pp-transformer

[Refactor] Apply pre-commit to model partitioner

xrsrke authored Nov 27, 2023
2 parents 3e5ff02 + ed433d1 commit c58f670
Showing 2 changed files with 40 additions and 82 deletions.
95 changes: 31 additions & 64 deletions pipegoose/nn/pipeline_parallel/partitioner.py
@@ -1,16 +1,15 @@
import re
from abc import ABC, abstractclassmethod
from collections import defaultdict
from enum import Enum, auto
from typing import List
from torch import nn
from typing import Dict, List, Optional

import torch
from typing import Dict
from collections import defaultdict
import re
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from torch import nn
from transformers.utils.fx import symbolic_trace

from typing import Optional
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode


class PartitionPolicy(Enum):
@@ -31,7 +30,7 @@ def __init__(self, model: nn.Module, parallel_context: ParallelContext):
self.parallel_context = parallel_context

def _find_transformer_block_prefix(self):
# Since there is no standarized way to find the name of the transformer block, we have to find the name first
# Since there is no standardized way to find the name of the transformer block, we have to find the name first
transformer_attr_name = self.model.base_model_prefix

transformer_block_name = None
@@ -42,29 +41,24 @@ def _find_transformer_block_prefix(self):
transformer_block_name = attr_name
break

transformer_block_prefix = (
rf"{transformer_attr_name}_{transformer_block_name}_(\d+)"
)
transformer_block_prefix = rf"{transformer_attr_name}_{transformer_block_name}_(\d+)"

print(f"transformer_block_prefix: {transformer_block_prefix}")
return transformer_block_prefix

def _get_transformer_block_id(self, node_name: str) -> Optional[str]:
pattern = self._find_transformer_block_prefix()
match = re.match(pattern, node_name)
return match.group() if match else None
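
For orientation, here is a minimal sketch of what this prefix lookup produces on traced node names. The attribute names ("transformer" and "h", as in a GPT-2-style model) are illustrative assumptions, not values taken from this diff:

import re

# Hypothetical values: base_model_prefix is "transformer" and the
# nn.ModuleList holding the transformer blocks is the attribute "h".
transformer_attr_name = "transformer"
transformer_block_name = "h"
pattern = rf"{transformer_attr_name}_{transformer_block_name}_(\d+)"

# symbolic_trace flattens module paths with underscores, so a node that
# lives inside block 3 could be named like this:
node_name = "transformer_h_3_attn_c_attn"
match = re.match(pattern, node_name)
print(match.group() if match else None)  # -> "transformer_h_3"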

def _split_nodes(
self, traced_graph_module: torch.fx.GraphModule, shard_count: int = 3
) -> Dict:
def _split_nodes(self, traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
"""Utility used to trace a graph and identify shard cutpoints."""

param_count: Dict[str, int] = {}

# Find the total number of params in the model and
# the number of params per shard we are aiming for.

# Note: we need to iterate over named_parameters AND named_modules because
# NOTE: we need to iterate over named_parameters AND named_modules because
# sometimes the parameters of a module are split into weight and bias in the
# traced graph and sometimes not.
# Example:
@@ -83,19 +77,18 @@ def _split_nodes(
# exclude embedding layers
if isinstance(module, nn.Embedding):
name = name.replace(".", "_")
exclude_param_count += sum([x.numel() for x in module.parameters()])
exclude_param_count += sum(x.numel() for x in module.parameters())
param_count[name] = 0
continue

name = name.replace(".", "_")
param_count[name] = sum([x.numel() for x in module.parameters()])
param_count[name] = sum(x.numel() for x in module.parameters())

# Note that we use param_count[""] as total parameters which does not include the
# lm_head. We want to exclude the lm_head and the embeddings here because they
# usually cause high skew which does not lead to equal partitioning of the
# transformer blocks.
per_shard_param = (param_count[""] - exclude_param_count) // shard_count
original_per_shard_param = per_shard_param
node_name_to_shard_id: Dict[str, int] = {}
shard_id = 0
shard_id_to_param_count = [0 for _ in range(shard_count)]
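
As a rough worked example of the per-shard budget computed above (the parameter counts below are invented purely for illustration):

# Hypothetical counts gathered from the traced model:
total_params = 560_000_000      # param_count[""], the whole model
embedding_params = 128_000_000  # excluded nn.Embedding parameters
shard_count = 4

# Each shard then targets roughly this many non-embedding parameters.
per_shard_param = (total_params - embedding_params) // shard_count
print(per_shard_param)  # 108000000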
@@ -123,8 +116,7 @@ def _split_nodes(

if (
shard_id_to_param_count[shard_id] > 0
and (shard_id_to_param_count[shard_id] + current_param_count)
>= per_shard_param
and (shard_id_to_param_count[shard_id] + current_param_count) >= per_shard_param
and (shard_id + 1) < shard_count
and transformer_block_ended
):
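
Condensed into a standalone helper, the cut condition above reads roughly as follows (a sketch; the example numbers are made up):

def should_start_new_shard(shard_id, current_param_count, shard_id_to_param_count,
                           per_shard_param, shard_count, transformer_block_ended):
    # Cut only when the running shard is non-empty, adding this node would meet
    # or exceed the per-shard budget, a later shard still exists to receive the
    # remaining nodes, and a transformer block just ended (so a block is never
    # split across shards).
    return (
        shard_id_to_param_count[shard_id] > 0
        and shard_id_to_param_count[shard_id] + current_param_count >= per_shard_param
        and shard_id + 1 < shard_count
        and transformer_block_ended
    )

# Example: shard 0 holds 100M params, the next node adds 10M, the budget is
# 108M, and a block boundary was just crossed -> move on to shard 1.
print(should_start_new_shard(0, 10_000_000, [100_000_000, 0, 0, 0],
                             108_000_000, 4, True))  # True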
@@ -158,9 +150,7 @@ def split(self, input_names: List[str]) -> List[nn.Module]:

symbolic_traced_module = symbolic_trace(model, input_names=input_names)

node_name_to_shard_id, output_from_shard = self._split_nodes(
symbolic_traced_module, n_partitions
)
node_name_to_shard_id, output_from_shard = self._split_nodes(symbolic_traced_module, n_partitions)

nodes_per_shard = defaultdict(dict)

@@ -169,56 +159,33 @@ def split(self, input_names: List[str]) -> List[nn.Module]:
for node in symbolic_traced_module.graph.nodes:
# If the current node is in the next shard, we insert an output node.
# A new graph is created and a placeholder is added for the next shard.
if (
node.name in node_name_to_shard_id
and prev_shard_id < node_name_to_shard_id[node.name]
):
if node.name in node_name_to_shard_id and prev_shard_id < node_name_to_shard_id[node.name]:
assert prev_node, "prev_node cannot be None"

# generate output node for the past graph/shard
with new_graph.inserting_after(prev_node):
outputs = output_from_shard[prev_shard_id]
graph_output_names = [prev_node.name] + [
i for i in outputs if i != prev_node.name
]

if isinstance(
nodes_per_shard[prev_shard_id][prev_node.name], tuple
):
graph_outputs = nodes_per_shard[prev_shard_id][
prev_node.name
] + tuple(
[
nodes_per_shard[prev_shard_id][i]
for i in outputs
if i != prev_node.name
]
graph_output_names = [prev_node.name] + [i for i in outputs if i != prev_node.name]

if isinstance(nodes_per_shard[prev_shard_id][prev_node.name], tuple):
graph_outputs = nodes_per_shard[prev_shard_id][prev_node.name] + tuple(
nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name
)
else:
graph_outputs = tuple(
[nodes_per_shard[prev_shard_id][prev_node.name]]
+ [
nodes_per_shard[prev_shard_id][i]
for i in outputs
if i != prev_node.name
]
+ [nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name]
)
new_graph.create_node(
op="output", target="output", args=(graph_outputs,)
)
new_graph.create_node(op="output", target="output", args=(graph_outputs,))

# generate new graph/shard and its input nodes (i.e., the output from the previous graph/shard)
num_graphs += 1
module_list.append(torch.fx.GraphModule(model, new_graph))
new_graph = torch.fx.Graph()
# generate placeholder nodes in the new graph/shard which matches the output nodes of the previous graph/shard
for new_graph_input_name in graph_output_names:
graph_input_node = new_graph.create_node(
"placeholder", new_graph_input_name
)
nodes_per_shard[node_name_to_shard_id[node.name]][
new_graph_input_name
] = graph_input_node
graph_input_node = new_graph.create_node("placeholder", new_graph_input_name)
nodes_per_shard[node_name_to_shard_id[node.name]][new_graph_input_name] = graph_input_node

if node.op in [
"placeholder",
@@ -229,9 +196,7 @@ def split(self, input_names: List[str]) -> List[nn.Module]:
]:
# Copy the nodes from the existing graph to the new graph.
current_shard_id = node_name_to_shard_id[node.name]
new_node = new_graph.node_copy(
node, lambda x: nodes_per_shard[current_shard_id][x.name]
)
new_node = new_graph.node_copy(node, lambda x: nodes_per_shard[current_shard_id][x.name])
nodes_per_shard[current_shard_id][node.name] = new_node
elif node.op == "output":
# If this is the last node, we should add an output
@@ -245,6 +210,10 @@ def split(self, input_names: List[str]) -> List[nn.Module]:
prev_node = new_node
prev_shard_id = node_name_to_shard_id[node.name]

for shard in module_list:
shard.graph.lint()
shard.recompile()

return module_list
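
As background on the lint/recompile pair added above: torch.fx.GraphModule caches the Python code generated for forward, so after a graph is built or edited, recompile() regenerates that code and graph.lint() checks that the graph is well-formed. A tiny standalone sketch:

import torch
from torch import nn

gm = torch.fx.symbolic_trace(nn.Linear(2, 2))
gm.graph.lint()   # verifies nodes are topologically ordered and all args are defined
gm.recompile()    # regenerates gm.code / gm.forward from the current graph
print(gm.code)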


@@ -257,9 +226,7 @@ def _get_partitioner(policy: PartitionPolicy) -> BasePartitioner:
return policy_to_partitioner[policy]


def get_model_partition(
module: nn.Module, policy: PartitionPolicy, parallel_context: ParallelContext
) -> nn.Module:
def get_model_partition(module: nn.Module, policy: PartitionPolicy, parallel_context: ParallelContext) -> nn.Module:
"""Get the corresponding partition of the current process."""
partitioner = _get_partitioner(policy)
partitions = partitioner(module, parallel_context).split()
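To make the output-node / placeholder hand-off inside split() concrete, here is a small self-contained torch.fx toy (a two-layer model of our own, not the partitioner's actual code path): the first shard is closed with an output node and the second shard is opened with a matching placeholder, mirroring how the loop above stitches graphs together.

import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = Toy()
traced = torch.fx.symbolic_trace(model)

# Shard 0: copy the input placeholder and fc1, then close the graph with an output node.
g1 = torch.fx.Graph()
env1 = {}
for node in traced.graph.nodes:
    if node.op == "placeholder" or node.target == "fc1":
        env1[node.name] = g1.node_copy(node, lambda n: env1[n.name])
g1.output(env1["fc1"])

# Shard 1: open with a placeholder matching shard 0's output, then copy the remaining nodes.
g2 = torch.fx.Graph()
env2 = {"fc1": g2.placeholder("fc1")}
for node in traced.graph.nodes:
    if node.target == "fc2" or node.op == "output":
        env2[node.name] = g2.node_copy(node, lambda n: env2[n.name])

shard0 = torch.fx.GraphModule(model, g1)
shard1 = torch.fx.GraphModule(model, g2)

x = torch.randn(2, 4)
assert torch.allclose(model(x), shard1(shard0(x)))  # chained shards match the full model
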
27 changes: 9 additions & 18 deletions tests/nn/pipeline_parallel/test_partitioner.py
@@ -1,33 +1,28 @@
import pytest
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoTokenizer,
BloomConfig,
BloomForCausalLM,
)
from pipegoose.nn.pipeline_parallel.partitioner import ( # PartitionPolicy,; get_model_partition,
UniformPartitioner,
)

from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner
from pipegoose.testing.utils import init_parallel_context, spawn


def get_gpt2_and_tokenizer():
return AutoModelForCausalLM.from_pretrained("gpt2"), AutoTokenizer.from_pretrained(
"gpt2"
)
return AutoModelForCausalLM.from_pretrained("gpt2"), AutoTokenizer.from_pretrained("gpt2")


def get_bloom_560m_and_tokenizer():
return AutoModelForCausalLM.from_pretrained(
return AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m"), AutoTokenizer.from_pretrained(
"bigscience/bloom-560m"
), AutoTokenizer.from_pretrained("bigscience/bloom-560m")
)


def get_bloom_and_tokenizer_with_6_layers():
return BloomForCausalLM(BloomConfig(n_layer=6)), AutoTokenizer.from_pretrained(
"bigscience/bloom-560m"
)
return BloomForCausalLM(BloomConfig(n_layer=6)), AutoTokenizer.from_pretrained("bigscience/bloom-560m")


# TODO: Also add a function for a generic nn.Transformer model
@@ -81,13 +76,9 @@ def run_model_partitioner(
partitioned_model_result = inputs["input_ids"]
for partition_id in range(pipeline_parallel_size):
if type(partitioned_model_result) in (list, tuple):
partitioned_model_result = partitioned_model[partition_id](
*partitioned_model_result
)
partitioned_model_result = partitioned_model[partition_id](*partitioned_model_result)
else:
partitioned_model_result = partitioned_model[partition_id](
partitioned_model_result
)
partitioned_model_result = partitioned_model[partition_id](partitioned_model_result)

assert torch.allclose(gt_logits, partitioned_model_result), "Results are not close"

