Merge branch 'main' into feature/moe

xrsrke authored Nov 27, 2023
2 parents 7a28f7a + db8ae11 commit bc2664a

Showing 14 changed files with 395 additions and 154 deletions.
15 changes: 6 additions & 9 deletions pipegoose/distributed/parallel_context.py
@@ -19,6 +19,7 @@
import random
from typing import Dict, List, Literal

import numpy as np
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
@@ -124,17 +125,12 @@ def __init__(

self.init_global_dist(rank, world_size, backend, host, port)
self.init_parallel_groups()

# if torch.cuda.is_available():
# self.set_device()

self.map_rank_to_device()

self.rpc_worker_map = {rank: WORKER_NAME.format(rank) for rank in self.get_ranks_in_group(ParallelMode.GLOBAL)}
self.init_rpc_workers(host, port)

# self.set_seed(seed)

self.set_seed(seed)
self._set_context()

def _set_context(self):
@@ -253,11 +249,12 @@ def set_device(self):
def set_seed(self, seed: int):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# TODO: set GPU seed
# if torch.cuda.is_available():
# pass
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

def map_rank_to_device(self):
"""Map global rank to device."""
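For reference, a minimal standalone sketch of the seeding pattern this hunk settles on (the helper below is illustrative, not part of the diff): seeding Python's random module, NumPy, CPU torch, and, when available, every CUDA device makes subsequent random draws reproducible.

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Same ordering as the diff: Python RNG, NumPy, CPU torch, then CUDA.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


set_seed(42)
first_draw = torch.randn(2, 2)
set_seed(42)
assert torch.equal(first_draw, torch.randn(2, 2))  # identical draws after re-seeding
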
11 changes: 8 additions & 3 deletions pipegoose/nn/pipeline_parallel/_job/backward.py
@@ -21,7 +21,14 @@ class _SaveGradLossFunction(torch.autograd.Function):
def forward(ctx: Any, key, metadata, tensor: torch.Tensor):
ctx.key = key
ctx.package_metadata = metadata
new_tensor = tensor.detach().clone()
# NOTE: a hacky way to work around `transformers`
if isinstance(tensor, torch.Tensor):
new_tensor = tensor.detach().clone()
elif isinstance(tensor, tuple):
new_tensor = tuple(t.detach().clone() for t in tensor)
else:
raise ValueError(f"tensor must be an instance of torch.Tensor or tuple, got {type(tensor)}")

return new_tensor

@staticmethod
@@ -45,8 +52,6 @@ def backward(ctx: Any, grad_output: torch.Tensor):
def save_grad_loss(package: Package) -> Package:
key = (package.metadata.microbatch_idx, package.metadata.partition_idx)
package.data = _SaveGradLossFunction.apply(key, package.metadata, package.data)
if package.metadata.partition_idx == 3:
assert 1 == 1
return package


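As context for the branch added above, a tiny self-contained sketch (hypothetical helper name) of the tensor-or-tuple clone pattern: `transformers` stages can hand a tuple such as (hidden_states, presents) to the pipeline, so the detach-and-clone step has to accept both shapes.

import torch


def detach_clone(data):
    # Mirrors the branch added in _SaveGradLossFunction.forward above.
    if isinstance(data, torch.Tensor):
        return data.detach().clone()
    if isinstance(data, tuple):
        return tuple(t.detach().clone() for t in data)
    raise ValueError(f"tensor must be an instance of torch.Tensor or tuple, got {type(data)}")


hidden_states = torch.randn(2, 4, requires_grad=True)
cloned = detach_clone((hidden_states, hidden_states))
assert all(not t.requires_grad for t in cloned)  # clones are detached from the autograd graph
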
17 changes: 10 additions & 7 deletions pipegoose/nn/pipeline_parallel/_job/creator.py
@@ -40,17 +40,14 @@ def __init__(self, pipeline_context: PipelineContext):
self.pipeline_context = pipeline_context

def after_compute(self):
from pipegoose.nn.pipeline_parallel.queue import (
get_input_activations,
get_output_activations,
)
pass

package = self.job.output
microbatch_idx = self.job.input.metadata.microbatch_idx
partition_idx = self.job.input.metadata.partition_idx

assert isinstance(get_input_activations(microbatch_idx, partition_idx), torch.Tensor)
assert isinstance(get_output_activations(microbatch_idx, partition_idx), torch.Tensor)
# assert isinstance(get_input_activations(microbatch_idx, partition_idx), torch.Tensor)
# assert isinstance(get_output_activations(microbatch_idx, partition_idx), torch.Tensor)

if package.metadata.microbatch_idx == self.pipeline_context.num_microbatches - 1:
new_package = schedule_backward_execution(package, self.pipeline_context)
@@ -187,7 +184,13 @@ class Function(torch.autograd.Function):
@staticmethod
def forward(ctx, metadata: Metadata, input: torch.Tensor) -> torch.Tensor:
ctx.package_meta = metadata
new_input = input.detach().clone()
# NOTE: a hacky way to make it work with `transformers`
if type(input) in (list, tuple):
# NOTE: ignore attention mask, which is a bool tensor
new_input = [x.detach().clone() for x in input]
else:
new_input = input.detach().clone()

return new_input

@staticmethod
12 changes: 11 additions & 1 deletion pipegoose/nn/pipeline_parallel/_job/forward.py
@@ -14,8 +14,18 @@
class ForwardJob(Job):
def run_compute(self) -> torch.Tensor:
is_training = self.input.metadata.training.is_training

with torch.set_grad_enabled(is_training):
output = self.function(self.input.data)
# TODO: a hacky way to work around `transformers`
if isinstance(self.input.data, torch.Tensor):
output = self.function(self.input.data)
elif type(self.input.data) in (list, tuple):
output = self.function(*self.input.data)
elif "input_ids" in self.input.data:
output = self.function(self.input.data["input_ids"])
else:
output = self.function(self.input.data)

return output


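A rough illustration of the three input shapes the dispatch above accepts, with a plain nn.Linear standing in for self.function (the module and shapes are made up for the example):

import torch
from torch import nn

partition = nn.Linear(4, 4)  # stand-in for a pipeline partition

tensor_input = torch.randn(2, 4)               # plain tensor handed over from a previous stage
tuple_input = (torch.randn(2, 4),)             # list/tuple case: unpacked as positional arguments
dict_input = {"input_ids": torch.randn(2, 4)}  # transformers-style batch dict

partition(tensor_input)
partition(*tuple_input)
partition(dict_input["input_ids"])
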
11 changes: 8 additions & 3 deletions pipegoose/nn/pipeline_parallel/microbatch.py
@@ -10,9 +10,14 @@ class ModelInputs(TypedDict):

def split(inputs: ModelInputs, n_microbatches: int) -> List[ModelInputs]:
assert n_microbatches > 0, f"n_microbatches must be greater than 0, got {n_microbatches}"

input_ids_microbatches = torch.split(inputs["input_ids"], 2)
attention_mask_microbatches = torch.split(inputs["attention_mask"], 2)
assert "input_ids" in inputs, f"inputs must have 'input_ids' key, got {inputs.keys()}"
assert "attention_mask" in inputs, f"inputs must have 'attention_mask' key, got {inputs.keys()}"
assert (
inputs["input_ids"].size(0) % n_microbatches == 0
), f"The batch size must be divisible by n_microbatches, got {inputs['input_ids'].size(0)} and {n_microbatches}"

input_ids_microbatches = torch.split(inputs["input_ids"], n_microbatches)
attention_mask_microbatches = torch.split(inputs["attention_mask"], n_microbatches)

microbatches = []
for input_ids, attention_mask in zip(input_ids_microbatches, attention_mask_microbatches):
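A quick usage sketch of the split contract above (the shapes are hypothetical). Note that torch.split takes a chunk size, so passing n_microbatches as its second argument yields chunks of n_microbatches rows each.

import torch

batch = {
    "input_ids": torch.arange(24).reshape(8, 3),
    "attention_mask": torch.ones(8, 3, dtype=torch.long),
}
for microbatch in split(batch, n_microbatches=4):  # 8 % 4 == 0 satisfies the new assert
    print(microbatch["input_ids"].shape, microbatch["attention_mask"].shape)
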
227 changes: 194 additions & 33 deletions pipegoose/nn/pipeline_parallel/partitioner.py
@@ -1,12 +1,18 @@
import re
from abc import ABC, abstractclassmethod
from collections import defaultdict
from enum import Enum, auto
from typing import List
from typing import Dict, List, Optional

import torch
from torch import nn
from transformers.utils.fx import symbolic_trace

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

INPUT_NAMES = ["input_ids", "attention_mask"]


class PartitionPolicy(Enum):
UNIFORM = auto()
@@ -21,41 +27,196 @@ def split(self) -> List[nn.Module]:


class UniformPartitioner(BasePartitioner):
def __init__(self, module: nn.Module, parallel_context: ParallelContext):
self.module = module
def __init__(self, model: nn.Module, parallel_context: ParallelContext):
self.model = model
self.parallel_context = parallel_context

def split(self) -> List[nn.Module]:
module = self.module
def _find_transformer_block_prefix(self):
# Since there is no standardized way to find the name of the transformer block, we have to find the name first
transformer_attr_name = self.model.base_model_prefix

transformer_block_name = None
transformer = getattr(self.model, transformer_attr_name)
for attr_name in dir(transformer):
part = getattr(transformer, attr_name)
if isinstance(part, nn.ModuleList):
transformer_block_name = attr_name
break

transformer_block_prefix = rf"{transformer_attr_name}_{transformer_block_name}_(\d+)"

return transformer_block_prefix

def _get_transformer_block_id(self, node_name: str) -> Optional[str]:
pattern = self._find_transformer_block_prefix()
match = re.match(pattern, node_name)
return match.group() if match else None

def _split_nodes(self, traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
"""Utility used to trace a graph and identify shard cutpoints."""

param_count: Dict[str, int] = {}

# Find the total number of params in the model and
# the number of params per shard we are aiming for.

# NOTE: we need to iterate over named_parameters AND named_modules because
# sometimes the parameters of a module are split into weight and bias in the
# traced graph and sometimes not.
# Example:
# The embedding in traced graph is called transformer_wte. The naming as parameter
# is transformer.wte.weight while as a module it is transformer.wte
#
# The projection inside the attention layer is split into weight and bias
# in the traced graph while in the module we only see the projection as a whole module.

for name, param in traced_graph_module.named_parameters():
name = name.replace(".", "_")
param_count[name] = param.numel()

exclude_param_count = 0
for name, module in traced_graph_module.named_modules():
# exclude embedding layers
if isinstance(module, nn.Embedding):
name = name.replace(".", "_")
exclude_param_count += sum(x.numel() for x in module.parameters())
param_count[name] = 0
continue

name = name.replace(".", "_")
param_count[name] = sum(x.numel() for x in module.parameters())

# Note that we use param_count[""] as the total parameter count, which does not
# include the lm_head. We exclude the lm_head and the embeddings here because
# they usually cause high skew, which prevents an equal partitioning of the
# transformer blocks.
per_shard_param = (param_count[""] - exclude_param_count) // shard_count
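# For intuition (made-up numbers, not from the diff): with param_count[""] of
# 1_000_000 total parameters, exclude_param_count of 100_000 in embeddings, and
# shard_count of 3, per_shard_param = (1_000_000 - 100_000) // 3 = 300_000, and
# a new shard is opened only once the running count reaches that budget AND the
# current transformer block has ended.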
node_name_to_shard_id: Dict[str, int] = {}
shard_id = 0
shard_id_to_param_count = [0 for _ in range(shard_count)]
output_from_shard = {}
transformer_block_ended = False
current_transformer_block = 0

for node in traced_graph_module.graph.nodes:
if node.op == "output":
break

# While splitting the nodes, we have to check if a node of the transformer block is detected
# If so, we have to force it in the current shard
if node.op in ("call_module", "get_attr"):
current_param_count = param_count.get(node.name, 0)
new_transformer_block_id = self._get_transformer_block_id(node.name)
# Check if a new transformer block has started
if new_transformer_block_id != current_transformer_block:
# End the previous block and start a new one
current_transformer_block = new_transformer_block_id
transformer_block_ended = True
else:
# We are still in the same transformer block
transformer_block_ended = False

if (
shard_id_to_param_count[shard_id] > 0
and (shard_id_to_param_count[shard_id] + current_param_count) >= per_shard_param
and (shard_id + 1) < shard_count
and transformer_block_ended
):
shard_id += 1
transformer_block_ended = False
shard_id_to_param_count[shard_id] += current_param_count

# Collect the nodes from previous shards that the current shard needs, so their
# outputs can be propagated forward shard by shard until they reach it.
if hasattr(node, "args"):
for arg in node.args:
if not hasattr(arg, "name"):
continue

arg_shard_id = node_name_to_shard_id.get(arg.name, shard_id)
if arg_shard_id < shard_id:
# propagate the input from arg_shard_id until shard_id
for idx in range(arg_shard_id, shard_id):
# note that we use the dict as an ordered set data structure
output_from_shard.setdefault(idx, dict())[arg.name] = None

node_name_to_shard_id[node.name] = shard_id
return node_name_to_shard_id, output_from_shard

def split(self, input_names: List[str] = INPUT_NAMES) -> List[nn.Module]:
n_partitions = self.parallel_context.pipeline_parallel_size

# NOTE: BLOOM-560
# embedding_module = module.transformer.word_embeddings
# transformer_blocks = module.transformer.h
# lm_head = module.lm_head

# NOTE: For sshleifer/tiny-gpt2
embed_module = module.transformer.wte
pos_embed_module = module.transformer.wpe
drop_module = module.transformer.drop
transformer_blocks = module.transformer.h
ln_f = module.transformer.ln_f
lm_head = module.lm_head

# NOTE: Calculate the number of transformer blocks per partition
blocks_per_partition = len(transformer_blocks) // n_partitions
partitions = []

for i in range(n_partitions):
start = i * blocks_per_partition
# NOTE: if it's the last partition, get all remaining blocks
end = start + blocks_per_partition if i < n_partitions - 1 else None
partitions.append(nn.Sequential(*transformer_blocks[start:end]))

partitions[0] = nn.Sequential(embed_module, pos_embed_module, drop_module, partitions[0])
partitions[-1] = nn.Sequential(ln_f, lm_head, partitions[-1])

return partitions
model = self.model
module_list: List[torch.fx.GraphModule] = []
num_graphs = 0
new_graph = torch.fx.Graph() # type: ignore

symbolic_traced_module = symbolic_trace(model, input_names=input_names)

node_name_to_shard_id, output_from_shard = self._split_nodes(symbolic_traced_module, n_partitions)

nodes_per_shard = defaultdict(dict)

prev_shard_id = 1000
prev_node = None
for node in symbolic_traced_module.graph.nodes:
# If the current node is in the next shard, we insert an output node.
# A new graph is created and a placeholder is added for the next shard.
if node.name in node_name_to_shard_id and prev_shard_id < node_name_to_shard_id[node.name]:
assert prev_node, "prev_node cannot be None"

# generate output node for the past graph/shard
with new_graph.inserting_after(prev_node):
outputs = output_from_shard[prev_shard_id]
graph_output_names = [prev_node.name] + [i for i in outputs if i != prev_node.name]

if isinstance(nodes_per_shard[prev_shard_id][prev_node.name], tuple):
graph_outputs = nodes_per_shard[prev_shard_id][prev_node.name] + tuple(
nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name
)
else:
graph_outputs = tuple(
[nodes_per_shard[prev_shard_id][prev_node.name]]
+ [nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name]
)
new_graph.create_node(op="output", target="output", args=(graph_outputs,))

# generate new graph/shard and its input nodes (i.e., the output from the previous graph/shard)
num_graphs += 1
module_list.append(torch.fx.GraphModule(model, new_graph))
new_graph = torch.fx.Graph()
# generate placeholder nodes in the new graph/shard which matches the output nodes of the previous graph/shard
for new_graph_input_name in graph_output_names:
graph_input_node = new_graph.create_node("placeholder", new_graph_input_name)
nodes_per_shard[node_name_to_shard_id[node.name]][new_graph_input_name] = graph_input_node

if node.op in [
"placeholder",
"get_attr",
"call_function",
"call_method",
"call_module",
]:
# Copy the nodes from the existing graph to the new graph.
current_shard_id = node_name_to_shard_id[node.name]
new_node = new_graph.node_copy(node, lambda x: nodes_per_shard[current_shard_id][x.name])
nodes_per_shard[current_shard_id][node.name] = new_node
elif node.op == "output":
# If this is the last node, we should add an output
# node and add the last graph to the list.
assert prev_node, "prev_node cannot be None"

with new_graph.inserting_after(prev_node):
new_graph.output(nodes_per_shard[prev_shard_id][prev_node.name])
module_list.append(torch.fx.GraphModule(model, new_graph))
break
prev_node = new_node
prev_shard_id = node_name_to_shard_id[node.name]

for shard in module_list:
shard.graph.lint()
shard.recompile()

return module_list


def _get_partitioner(policy: PartitionPolicy) -> BasePartitioner:
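As a rough companion to _find_transformer_block_prefix and _split_nodes above, the sketch below (assuming a GPT-2-style checkpoint such as sshleifer/tiny-gpt2, the model the removed code path targeted) shows the naming convention the partitioner relies on: traced-graph node names are module paths with "." replaced by "_", and transformer-block nodes match rf"{base_model_prefix}_{block_list_name}_(\d+)".

import re

from torch import nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

prefix = model.base_model_prefix  # "transformer" for GPT-2
transformer = getattr(model, prefix)
block_list_name = next(
    name for name in dir(transformer) if isinstance(getattr(transformer, name), nn.ModuleList)
)  # "h" for GPT-2, found the same way _find_transformer_block_prefix does
block_pattern = rf"{prefix}_{block_list_name}_(\d+)"

for name, _ in model.named_parameters():
    node_name = name.replace(".", "_")  # e.g. transformer.h.0.ln_1.weight -> transformer_h_0_ln_1_weight
    match = re.match(block_pattern, node_name)
    if match:
        print(node_name, "-> block", match.group(1))
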