diff --git a/pipegoose/distributed/parallel_context.py b/pipegoose/distributed/parallel_context.py
index 652d0f3..cd7262c 100644
--- a/pipegoose/distributed/parallel_context.py
+++ b/pipegoose/distributed/parallel_context.py
@@ -19,6 +19,7 @@
 import random
 from typing import Dict, List, Literal
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.distributed.rpc as rpc
@@ -124,17 +125,12 @@ def __init__(
         self.init_global_dist(rank, world_size, backend, host, port)
         self.init_parallel_groups()
 
-        # if torch.cuda.is_available():
-        #     self.set_device()
-
         self.map_rank_to_device()
 
         self.rpc_worker_map = {rank: WORKER_NAME.format(rank) for rank in self.get_ranks_in_group(ParallelMode.GLOBAL)}
         self.init_rpc_workers(host, port)
 
-        # self.set_seed(seed)
-
+        self.set_seed(seed)
         self._set_context()
 
     def _set_context(self):
@@ -253,11 +249,12 @@ def set_device(self):
     def set_seed(self, seed: int):
         """Set seed for reproducibility."""
         random.seed(seed)
+        np.random.seed(seed)
         torch.manual_seed(seed)
 
-        # TODO: set GPU seed
-        # if torch.cuda.is_available():
-        #     pass
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)
 
     def map_rank_to_device(self):
         """Map global rank to device."""
diff --git a/pipegoose/nn/pipeline_parallel/_job/backward.py b/pipegoose/nn/pipeline_parallel/_job/backward.py
index ea05c0a..88d27ca 100644
--- a/pipegoose/nn/pipeline_parallel/_job/backward.py
+++ b/pipegoose/nn/pipeline_parallel/_job/backward.py
@@ -21,7 +21,14 @@ class _SaveGradLossFunction(torch.autograd.Function):
     def forward(ctx: Any, key, metadata, tensor: torch.Tensor):
         ctx.key = key
         ctx.package_metadata = metadata
-        new_tensor = tensor.detach().clone()
+        # NOTE: a hacky workaround for `transformers`, which may pass a tuple of tensors here
+        if isinstance(tensor, torch.Tensor):
+            new_tensor = tensor.detach().clone()
+        elif isinstance(tensor, tuple):
+            new_tensor = tuple(t.detach().clone() for t in tensor)
+        else:
+            raise ValueError(f"tensor must be an instance of torch.Tensor or tuple, got {type(tensor)}")
+
         return new_tensor
 
     @staticmethod
@@ -45,8 +52,6 @@ def backward(ctx: Any, grad_output: torch.Tensor):
 def save_grad_loss(package: Package) -> Package:
     key = (package.metadata.microbatch_idx, package.metadata.partition_idx)
     package.data = _SaveGradLossFunction.apply(key, package.metadata, package.data)
-    if package.metadata.partition_idx == 3:
-        assert 1 == 1
     return package
diff --git a/pipegoose/nn/pipeline_parallel/_job/creator.py b/pipegoose/nn/pipeline_parallel/_job/creator.py
index 85ec14b..3554f09 100644
--- a/pipegoose/nn/pipeline_parallel/_job/creator.py
+++ b/pipegoose/nn/pipeline_parallel/_job/creator.py
@@ -40,17 +40,14 @@ def __init__(self, pipeline_context: PipelineContext):
         self.pipeline_context = pipeline_context
 
     def after_compute(self):
-        from pipegoose.nn.pipeline_parallel.queue import (
-            get_input_activations,
-            get_output_activations,
-        )
+        pass
 
         package = self.job.output
         microbatch_idx = self.job.input.metadata.microbatch_idx
         partition_idx = self.job.input.metadata.partition_idx
 
-        assert isinstance(get_input_activations(microbatch_idx, partition_idx), torch.Tensor)
-        assert isinstance(get_output_activations(microbatch_idx, partition_idx), torch.Tensor)
+        # assert isinstance(get_input_activations(microbatch_idx, partition_idx), torch.Tensor)
+        # assert isinstance(get_output_activations(microbatch_idx, partition_idx), torch.Tensor)
 
         if package.metadata.microbatch_idx == self.pipeline_context.num_microbatches - 1:
             new_package = schedule_backward_execution(package, self.pipeline_context)
@@ -187,7 +184,13 @@ class Function(torch.autograd.Function):
     @staticmethod
     def forward(ctx, metadata: Metadata, input: torch.Tensor) -> torch.Tensor:
         ctx.package_meta = metadata
-        new_input = input.detach().clone()
+        # NOTE: a hacky way to make it work with `transformers`
+        if type(input) in (list, tuple):
+            # NOTE: ignore attention mask, which is a bool tensor
+            new_input = [x.detach().clone() for x in input]
+        else:
+            new_input = input.detach().clone()
+
         return new_input
 
     @staticmethod
diff --git a/pipegoose/nn/pipeline_parallel/_job/forward.py b/pipegoose/nn/pipeline_parallel/_job/forward.py
index a511347..241e45f 100644
--- a/pipegoose/nn/pipeline_parallel/_job/forward.py
+++ b/pipegoose/nn/pipeline_parallel/_job/forward.py
@@ -14,8 +14,18 @@ class ForwardJob(Job):
     def run_compute(self) -> torch.Tensor:
         is_training = self.input.metadata.training.is_training
+
         with torch.set_grad_enabled(is_training):
-            output = self.function(self.input.data)
+            # TODO: a hacky workaround for `transformers` inputs
+            if isinstance(self.input.data, torch.Tensor):
+                output = self.function(self.input.data)
+            elif type(self.input.data) in (list, tuple):
+                output = self.function(*self.input.data)
+            elif "input_ids" in self.input.data:
+                output = self.function(self.input.data["input_ids"])
+            else:
+                output = self.function(self.input.data)
+
         return output
diff --git a/pipegoose/nn/pipeline_parallel/microbatch.py b/pipegoose/nn/pipeline_parallel/microbatch.py
index 644da1d..025910f 100644
--- a/pipegoose/nn/pipeline_parallel/microbatch.py
+++ b/pipegoose/nn/pipeline_parallel/microbatch.py
@@ -10,9 +10,14 @@ class ModelInputs(TypedDict):
 def split(inputs: ModelInputs, n_microbatches: int) -> List[ModelInputs]:
     assert n_microbatches > 0, f"n_microbatches must be greater than 0, got {n_microbatches}"
-
-    input_ids_microbatches = torch.split(inputs["input_ids"], 2)
-    attention_mask_microbatches = torch.split(inputs["attention_mask"], 2)
+    assert "input_ids" in inputs, f"inputs must have 'input_ids' key, got {inputs.keys()}"
+    assert "attention_mask" in inputs, f"inputs must have 'attention_mask' key, got {inputs.keys()}"
+    assert (
+        inputs["input_ids"].size(0) % n_microbatches == 0
+    ), f"The batch size must be divisible by n_microbatches, got {inputs['input_ids'].size(0)} and {n_microbatches}"
+
+    input_ids_microbatches = torch.chunk(inputs["input_ids"], chunks=n_microbatches)
+    attention_mask_microbatches = torch.chunk(inputs["attention_mask"], chunks=n_microbatches)
 
     microbatches = []
     for input_ids, attention_mask in zip(input_ids_microbatches, attention_mask_microbatches):
diff --git a/pipegoose/nn/pipeline_parallel/partitioner.py b/pipegoose/nn/pipeline_parallel/partitioner.py
index 801f81f..cc10abf 100644
--- a/pipegoose/nn/pipeline_parallel/partitioner.py
+++ b/pipegoose/nn/pipeline_parallel/partitioner.py
@@ -1,12 +1,18 @@
+import re
 from abc import ABC, abstractclassmethod
+from collections import defaultdict
 from enum import Enum, auto
-from typing import List
+from typing import Dict, List, Optional
 
+import torch
 from torch import nn
+from transformers.utils.fx import symbolic_trace
 
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
 
+INPUT_NAMES = ["input_ids", "attention_mask"]
+
 
 class PartitionPolicy(Enum):
     UNIFORM = auto()
@@ -21,41 +27,196 @@ def split(self) -> List[nn.Module]:
 
 class UniformPartitioner(BasePartitioner):
-    def __init__(self, module: nn.Module, parallel_context: ParallelContext):
-        self.module = module
+    def __init__(self, model: nn.Module, parallel_context: ParallelContext):
+        self.model = model
         self.parallel_context = parallel_context
 
-    def split(self) -> List[nn.Module]:
-        module = self.module
+    def _find_transformer_block_prefix(self):
+        # Since there is no standardized way to find the name of the transformer block, we have to look it up first
+        transformer_attr_name = self.model.base_model_prefix
+
+        transformer_block_name = None
+        transformer = getattr(self.model, transformer_attr_name)
+        for attr_name in dir(transformer):
+            part = getattr(transformer, attr_name)
+            if isinstance(part, nn.ModuleList):
+                transformer_block_name = attr_name
+                break
+
+        transformer_block_prefix = rf"{transformer_attr_name}_{transformer_block_name}_(\d+)"
+
+        return transformer_block_prefix
+
+    def _get_transformer_block_id(self, node_name: str) -> Optional[str]:
+        pattern = self._find_transformer_block_prefix()
+        match = re.match(pattern, node_name)
+        return match.group() if match else None
+
+    def _split_nodes(self, traced_graph_module: torch.fx.GraphModule, shard_count: int = 3) -> Dict:
+        """Utility used to trace a graph and identify shard cutpoints."""
+
+        param_count: Dict[str, int] = {}
+
+        # Find the total number of params in the model and
+        # the number of params per shard we are aiming for.
+
+        # NOTE: we need to iterate over named_parameters AND named_modules because
+        # sometimes the parameters of a module are split into weight and bias in the
+        # traced graph and sometimes not.
+        # Example:
+        # The embedding in the traced graph is called transformer_wte. As a parameter it is
+        # named transformer.wte.weight, while as a module it is transformer.wte.
+        #
+        # The projection inside the attention layer is split into weight and bias
+        # in the traced graph, while in the module we only see the projection as a whole module.
+
+        for name, param in traced_graph_module.named_parameters():
+            name = name.replace(".", "_")
+            param_count[name] = param.numel()
+
+        exclude_param_count = 0
+        for name, module in traced_graph_module.named_modules():
+            # exclude embedding layers
+            if isinstance(module, nn.Embedding):
+                name = name.replace(".", "_")
+                exclude_param_count += sum(x.numel() for x in module.parameters())
+                param_count[name] = 0
+                continue
+
+            name = name.replace(".", "_")
+            param_count[name] = sum(x.numel() for x in module.parameters())
+
+        # Note that we use param_count[""] as the total number of parameters, which does not include
+        # the lm_head. We want to exclude the lm_head and the embeddings here because they usually
+        # cause high skew, which does not lead to an equal partitioning of the transformer blocks.
+        per_shard_param = (param_count[""] - exclude_param_count) // shard_count
+        node_name_to_shard_id: Dict[str, int] = {}
+        shard_id = 0
+        shard_id_to_param_count = [0 for _ in range(shard_count)]
+        output_from_shard = {}
+        transformer_block_ended = False
+        current_transformer_block = 0
+
+        for node in traced_graph_module.graph.nodes:
+            if node.op == "output":
+                break
+
+            # While splitting the nodes, we have to check whether the node belongs to a transformer block.
+            # If so, we have to force it into the current shard.
+            if node.op in ("call_module", "get_attr"):
+                current_param_count = param_count.get(node.name, 0)
+                new_transformer_block_id = self._get_transformer_block_id(node.name)
+                # Check if a new transformer block has started
+                if new_transformer_block_id != current_transformer_block:
+                    # End the previous block and start a new one
+                    current_transformer_block = new_transformer_block_id
+                    transformer_block_ended = True
+                else:
+                    # We are still in the same transformer block
+                    transformer_block_ended = False
+
+                if (
+                    shard_id_to_param_count[shard_id] > 0
+                    and (shard_id_to_param_count[shard_id] + current_param_count) >= per_shard_param
+                    and (shard_id + 1) < shard_count
+                    and transformer_block_ended
+                ):
+                    shard_id += 1
+                    transformer_block_ended = False
+                shard_id_to_param_count[shard_id] += current_param_count
+
+            # We need to collect the nodes from the previous shards which are needed in the
+            # current shard, so that we can propagate them forward until the current shard.
+            if hasattr(node, "args"):
+                for arg in node.args:
+                    if not hasattr(arg, "name"):
+                        continue
+
+                    arg_shard_id = node_name_to_shard_id.get(arg.name, shard_id)
+                    if arg_shard_id < shard_id:
+                        # propagate the input from arg_shard_id until shard_id
+                        for idx in range(arg_shard_id, shard_id):
+                            # note that we use the dict as an ordered set data structure
+                            output_from_shard.setdefault(idx, dict())[arg.name] = None
+
+            node_name_to_shard_id[node.name] = shard_id
+        return node_name_to_shard_id, output_from_shard
+
+    def split(self, input_names: List[str] = INPUT_NAMES) -> List[nn.Module]:
         n_partitions = self.parallel_context.pipeline_parallel_size
-
-        # NOTE: BLOOM-560
-        # embedding_module = module.transformer.word_embeddings
-        # transformer_blocks = module.transformer.h
-        # lm_head = module.lm_head
-
-        # NOTE: For sshleifer/tiny-gpt2
-        embed_module = module.transformer.wte
-        pos_embed_module = module.transformer.wpe
-        drop_module = module.transformer.drop
-        transformer_blocks = module.transformer.h
-        ln_f = module.transformer.ln_f
-        lm_head = module.lm_head
-
-        # NOTE: Calculate the number of transformer blocks per partition
-        blocks_per_partition = len(transformer_blocks) // n_partitions
-        partitions = []
-
-        for i in range(n_partitions):
-            start = i * blocks_per_partition
-            # NOTE: if it's the last partition, get all remaining blocks
-            end = start + blocks_per_partition if i < n_partitions - 1 else None
-            partitions.append(nn.Sequential(*transformer_blocks[start:end]))
-
-        partitions[0] = nn.Sequential(embed_module, pos_embed_module, drop_module, partitions[0])
-        partitions[-1] = nn.Sequential(ln_f, lm_head, partitions[-1])
-
-        return partitions
+        model = self.model
+        module_list: List[torch.fx.GraphModule] = []
+        num_graphs = 0
+        new_graph = torch.fx.Graph()  # type: ignore
+
+        symbolic_traced_module = symbolic_trace(model, input_names=input_names)
+
+        node_name_to_shard_id, output_from_shard = self._split_nodes(symbolic_traced_module, n_partitions)
+
+        nodes_per_shard = defaultdict(dict)
+
+        prev_shard_id = 1000
+        prev_node = None
+        for node in symbolic_traced_module.graph.nodes:
+            # If the current node is in the next shard, we insert an output node.
+            # A new graph is created and a placeholder is added for the next shard.
+            if node.name in node_name_to_shard_id and prev_shard_id < node_name_to_shard_id[node.name]:
+                assert prev_node, "prev_node cannot be None"
+
+                # generate the output node for the past graph/shard
+                with new_graph.inserting_after(prev_node):
+                    outputs = output_from_shard[prev_shard_id]
+                    graph_output_names = [prev_node.name] + [i for i in outputs if i != prev_node.name]
+
+                    if isinstance(nodes_per_shard[prev_shard_id][prev_node.name], tuple):
+                        graph_outputs = nodes_per_shard[prev_shard_id][prev_node.name] + tuple(
+                            nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name
+                        )
+                    else:
+                        graph_outputs = tuple(
+                            [nodes_per_shard[prev_shard_id][prev_node.name]]
+                            + [nodes_per_shard[prev_shard_id][i] for i in outputs if i != prev_node.name]
+                        )
+                    new_graph.create_node(op="output", target="output", args=(graph_outputs,))
+
+                # generate the new graph/shard and its input nodes (i.e., the outputs of the previous graph/shard)
+                num_graphs += 1
+                module_list.append(torch.fx.GraphModule(model, new_graph))
+                new_graph = torch.fx.Graph()
+                # generate placeholder nodes in the new graph/shard which match the output nodes of the previous graph/shard
+                for new_graph_input_name in graph_output_names:
+                    graph_input_node = new_graph.create_node("placeholder", new_graph_input_name)
+                    nodes_per_shard[node_name_to_shard_id[node.name]][new_graph_input_name] = graph_input_node
+
+            if node.op in [
+                "placeholder",
+                "get_attr",
+                "call_function",
+                "call_method",
+                "call_module",
+            ]:
+                # Copy the nodes from the existing graph to the new graph.
+                current_shard_id = node_name_to_shard_id[node.name]
+                new_node = new_graph.node_copy(node, lambda x: nodes_per_shard[current_shard_id][x.name])
+                nodes_per_shard[current_shard_id][node.name] = new_node
+            elif node.op == "output":
+                # If this is the last node, we should add an output
+                # node and add the last graph to the list.
+                assert prev_node, "prev_node cannot be None"
+
+                with new_graph.inserting_after(prev_node):
+                    new_graph.output(nodes_per_shard[prev_shard_id][prev_node.name])
+                module_list.append(torch.fx.GraphModule(model, new_graph))
+                break
+            prev_node = new_node
+            prev_shard_id = node_name_to_shard_id[node.name]
+
+        for shard in module_list:
+            shard.graph.lint()
+            shard.recompile()
+
+        return module_list
 
 
 def _get_partitioner(policy: PartitionPolicy) -> BasePartitioner:
diff --git a/pipegoose/nn/pipeline_parallel/pipeline_engine.py b/pipegoose/nn/pipeline_parallel/pipeline_engine.py
index 04273ab..bd1ec56 100644
--- a/pipegoose/nn/pipeline_parallel/pipeline_engine.py
+++ b/pipegoose/nn/pipeline_parallel/pipeline_engine.py
@@ -7,6 +7,7 @@
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
+from pipegoose.nn.pipeline_parallel import microbatch
 from pipegoose.nn.pipeline_parallel._comm import RECV_QUEUE
 from pipegoose.nn.pipeline_parallel._job.creator import create_job
 from pipegoose.nn.pipeline_parallel._job.job_type import JobType
@@ -56,13 +57,14 @@ def __init__(
         self.parallel_context = parallel_context
         self.pipeline_context = PipelineContext(scheduler, parallel_context)
 
-    def run(self, inputs: torch.Tensor) -> torch.Tensor:
+    def run(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor) -> torch.Tensor:
         self.worker_manager.spawn()
         self.pipeline_context.forward()
 
         n_microbatches = self.scheduler.n_microbatches
-        # microbatches = microbatch.split(inputs, n_microbatches=n_microbatches)
-        microbatches = torch.chunk(inputs, chunks=n_microbatches, dim=0)
+        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+        microbatches = microbatch.split(inputs, n_microbatches=n_microbatches)
+        # microbatches = torch.chunk(inputs, chunks=n_microbatches, dim=0)
 
         # NOTE: add a callback to the progress tracker
         # that if the clock_idx is increased, then
diff --git a/pipegoose/nn/pipeline_parallel/pipeline_parallel.py b/pipegoose/nn/pipeline_parallel/pipeline_parallel.py
index 4f6ef20..14e8a11 100644
--- a/pipegoose/nn/pipeline_parallel/pipeline_parallel.py
+++ b/pipegoose/nn/pipeline_parallel/pipeline_parallel.py
@@ -1,5 +1,3 @@
-from typing import List
-
 import torch
 from torch import nn
 
@@ -7,6 +5,7 @@
 from pipegoose.nn.parallel import Parallel
 from pipegoose.nn.pipeline_parallel._utils import get_partition_idx
 from pipegoose.nn.pipeline_parallel._worker import WorkerManager
+from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner
 from pipegoose.nn.pipeline_parallel.pipeline_engine import PipelineEngine
 from pipegoose.nn.pipeline_parallel.scheduler import GPipeScheduler
 
@@ -16,11 +15,11 @@ class PipelineParallel(Parallel):
     def __init__(
         self,
-        modules: List[nn.Module],
+        module: nn.Module,
         num_microbatches: int,
         parallel_context: ParallelContext,
     ):
-        self.modules = modules
+        self.module = module
         self.num_microbatches = num_microbatches
         self.parallel_context = parallel_context
 
@@ -28,7 +27,8 @@ def __init__(
     def parallelize(self) -> nn.Module:
         if self.parallel_context.pipeline_parallel_size > 1:
             partition_idx = get_partition_idx(self.parallel_context)
-            module = self.modules[partition_idx]
+            partitions = UniformPartitioner(self.module, self.parallel_context).split(["input_ids"])
+            module = partitions[partition_idx]
 
             n_partitions = self.parallel_context.pipeline_parallel_size
             scheduler = GPipeScheduler(self.num_microbatches, n_partitions)
@@ -47,4 +47,4 @@ def parallelize(self) -> nn.Module:
 
             return module
         else:
-            return self.modules
+            return self.module
diff --git a/pipegoose/nn/pipeline_parallel/queue.py b/pipegoose/nn/pipeline_parallel/queue.py
index 39e18b5..3ad42d5 100644
--- a/pipegoose/nn/pipeline_parallel/queue.py
+++ b/pipegoose/nn/pipeline_parallel/queue.py
@@ -73,7 +73,13 @@ def get_saved_activations(key: ActivationKey) -> torch.Tensor:
     """Get the saved activations for a given key for backward job."""
     # NOTE: because a partition can have multiple microbatches,
     input = _INPUT_ACTIVATIONS[key]
-    return input.requires_grad_(True)
+
+    # return input.requires_grad_(True)
+    # TODO: add support for regular non-transformers models
+    if isinstance(input, torch.Tensor):
+        return input.requires_grad_(True)
+    else:
+        return input
 
 
 def save_activations(key: ActivationKey, data: torch.Tensor):
     """Save forward job's activations for backward job."""
diff --git a/pipegoose/testing/utils.py b/pipegoose/testing/utils.py
index 742c96f..7542497 100644
--- a/pipegoose/testing/utils.py
+++ b/pipegoose/testing/utils.py
@@ -37,6 +37,7 @@ def spawn(func: Callable, world_size: int = 1, **kwargs):
     kwargs.pop("port")
 
     wrapped_func = partial(func, world_size=world_size, port=port, **kwargs)
+
     mp.spawn(wrapped_func, nprocs=world_size)
diff --git a/tests/nn/pipeline_parallel/test_microbatch.py b/tests/nn/pipeline_parallel/test_microbatch.py
index 90cc1ff..cb3e119 100644
--- a/tests/nn/pipeline_parallel/test_microbatch.py
+++ b/tests/nn/pipeline_parallel/test_microbatch.py
@@ -6,31 +6,21 @@
 
 def test_split_a_mini_batch_to_microbatches():
+    BATCH_SIZE = 36
+    N_MICROBATCHES = 6
+
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     tokenizer.pad_token = tokenizer.eos_token
 
-    batch_sentences = [
-        "This is the first sentence.",
-        "Here's the second one.",
-        "This makes three.",
-        "Is this the fourth sentence?",
-        "Five sentences now.",
-        "This is the sixth sentence.",
-        "Sentence seven is here.",
-        "We're up to eight now.",
-        "This should be the ninth sentence.",
-        "And finally, the tenth sentence.",
-    ]
-    BATCH_SIZE = len(batch_sentences)
-    N_MICROBATCHES = 5
+    text = "Persistence is all you need."
+    batch_sentences = [text for _ in range(BATCH_SIZE)]
 
     inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt")
     microbatches = microbatch.split(inputs, n_microbatches=N_MICROBATCHES)
 
     assert isinstance(microbatches, list)
     assert len(microbatches) == N_MICROBATCHES
-    assert "input_ids" in microbatches[0]
-    assert "attention_mask" in microbatches[0]
+    assert all(set(batch.keys()) == set(inputs.keys()) for batch in microbatches) is True
 
     total_sentences = sum(microbatch["input_ids"].size(0) for microbatch in microbatches)
     assert total_sentences == BATCH_SIZE
diff --git a/tests/nn/pipeline_parallel/test_partitioner.py b/tests/nn/pipeline_parallel/test_partitioner.py
index adddfee..fafc38c 100644
--- a/tests/nn/pipeline_parallel/test_partitioner.py
+++ b/tests/nn/pipeline_parallel/test_partitioner.py
@@ -1,52 +1,106 @@
 import pytest
-from torch import nn
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from pipegoose.nn.pipeline_parallel.partitioner import (  # PartitionPolicy,; get_model_partition,
-    UniformPartitioner,
+import torch
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BloomConfig,
+    BloomForCausalLM,
 )
+
+from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner
 from pipegoose.testing.utils import init_parallel_context, spawn
 
-MODEL_NAME = "sshleifer/tiny-gpt2"
 
+def get_gpt2_and_tokenizer():
+    return AutoModelForCausalLM.from_pretrained("gpt2"), AutoTokenizer.from_pretrained("gpt2")
+
+
+def get_bloom_560m_and_tokenizer():
+    return AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m"), AutoTokenizer.from_pretrained(
+        "bigscience/bloom-560m"
+    )
+
+
+def get_bloom_and_tokenizer_with_6_layers():
+    return BloomForCausalLM(BloomConfig(n_layer=6)), AutoTokenizer.from_pretrained("bigscience/bloom-560m")
 
-def run_model_partitioner(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
+
+# TODO: Also add a function for a generic nn.Transformer model
+def run_model_partitioner(
+    rank,
+    world_size,
+    port,
+    tensor_parallel_size,
+    pipeline_parallel_size,
+    data_parallel_size,
+    model_retrieval_func,
+):
     parallel_context = init_parallel_context(
-        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+        rank,
+        world_size,
+        port,
+        tensor_parallel_size,
+        pipeline_parallel_size,
+        data_parallel_size,
     )
-    module = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    tokenizer.pad_token = tokenizer.eos_token
-    text = ["Hello world", "How are you?"]
-    inputs = tokenizer(text, return_tensors="pt", padding=True)
+    torch.manual_seed(0)
+    batch_sentences = ["hello world from pipegoose"]
+    model, tokenizer = model_retrieval_func()
+    model.eval()
+    tokenizer.pad_token = tokenizer.eos_token
+    inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt")
+    gt_logits = model(**inputs).logits
 
-    # policy = PartitionPolicy.UNIFORM
-    partitions = UniformPartitioner(module, parallel_context).split()
-    # partition = get_model_partition(module, policy, parallel_context)
+    partitioned_model = UniformPartitioner(model, parallel_context).split()
+    assert (
+        len(partitioned_model) == pipeline_parallel_size
+    ), f"Received {len(partitioned_model)} partitions instead of {pipeline_parallel_size}"
 
-    assert isinstance(partitions, list)
-    assert len(partitions) == pipeline_parallel_size
+    print("Start printing partitioned model")
+    for i, shard in enumerate(partitioned_model):
+        shard_param_count = 0
+        print("==================")
+        print(f"Shard {i + 1}")
+        for _, module in shard.named_children():
+            # Sum the parameters of each module in the shard
+            shard_param_count += sum(p.numel() for p in module.parameters())
+            print(f"Layer type: {type(module).__name__}")
+            print(module)
+        print(f"Total parameters in Shard {i + 1}: {shard_param_count}")
+        print("==================")
+    print("End printing partitioned model")
 
-    for partition in partitions:
-        assert isinstance(partition, nn.Module)
-        assert partition != module
+    partitioned_model_result = inputs
+    for partition_id in range(pipeline_parallel_size):
+        if type(partitioned_model_result) in (list, tuple):
+            partitioned_model_result = partitioned_model[partition_id](*partitioned_model_result)
+        else:
+            partitioned_model_result = partitioned_model[partition_id](**partitioned_model_result)
 
-    outputs = inputs
-    for partition in partitions:
-        outputs = partition(outputs)
+    assert torch.allclose(gt_logits, partitioned_model_result), "Results are not close"
 
 
-@pytest.mark.skip
-@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
-def test_naive_partitioning(pipeline_parallel_size):
+@pytest.mark.parametrize("pipeline_parallel_size", [2, 3, 4, 5, 6])
+@pytest.mark.parametrize(
+    "model_retrieval_func",
+    [
+        get_gpt2_and_tokenizer,
+        get_bloom_and_tokenizer_with_6_layers,
+        get_bloom_560m_and_tokenizer,
+    ],
+)
+def test_naive_partitioning(pipeline_parallel_size, model_retrieval_func):
     TENSOR_PARALLEL_SIZE = 1
     DATA_PARALLEL_SIZE = 1
-
+    print(
+        f"Running test with pipeline_parallel_size={pipeline_parallel_size}, tensor_parallel_size={TENSOR_PARALLEL_SIZE}, data_parallel_size={DATA_PARALLEL_SIZE}"
+    )
     spawn(
         run_model_partitioner,
         world_size=pipeline_parallel_size,
         tensor_parallel_size=TENSOR_PARALLEL_SIZE,
         pipeline_parallel_size=pipeline_parallel_size,
         data_parallel_size=DATA_PARALLEL_SIZE,
+        model_retrieval_func=model_retrieval_func,
     )
diff --git a/tests/nn/pipeline_parallel/test_pipeline_parallel.py b/tests/nn/pipeline_parallel/test_pipeline_parallel.py
index d08c9d3..3e171b0 100644
--- a/tests/nn/pipeline_parallel/test_pipeline_parallel.py
+++ b/tests/nn/pipeline_parallel/test_pipeline_parallel.py
@@ -1,11 +1,8 @@
-from copy import deepcopy
-from functools import reduce
-
-import torch
+import pytest
 from torch import nn
-from torch.optim import SGD
+from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM
 
-from pipegoose.nn.pipeline_parallel._utils import get_partition_idx, is_last_stage
+from pipegoose.nn.pipeline_parallel._utils import get_partition_idx
 from pipegoose.nn.pipeline_parallel.pipeline_parallel import PipelineParallel
 from pipegoose.testing.utils import count_model_parameters, init_parallel_context, spawn
 
@@ -22,11 +19,11 @@ def run_pipeline_parallel(
     rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, num_microbatches, kwargs
 ):
     MODEL = kwargs["model"]
-    UPDATED_MODEL = kwargs["updated_model"]
+    # UPDATED_MODEL = kwargs["updated_model"]
     INPUTS = kwargs["inputs"]
-    REF_OUTPUTS = kwargs["ref_outputs"]
-    REF_GRADS = kwargs["ref_grads"]
-    LR = kwargs["lr"]
+    # REF_OUTPUTS = kwargs["ref_outputs"]
+    # REF_GRADS = kwargs["ref_grads"]
+    kwargs["lr"]
 
     forward_timeline = []
     backward_timeline = []
@@ -50,7 +47,8 @@ def forward(self, input):
             return self.module(input)
 
     # NOTE: just for recording the forward and backward timeline
-    model = nn.ModuleList([TimelineRegister(partition_idx, module) for partition_idx, module in enumerate(MODEL)])
+    # model = nn.ModuleList([TimelineRegister(partition_idx, module) for partition_idx, module in enumerate(MODEL)])
+    model = MODEL
 
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
@@ -59,67 +57,79 @@ def forward(self, input):
     EXPECTED_FORWARD_TIMELINE, EXPECTED_BACKWARD_TIMELINE = generate_expected_timeline(num_microbatches, partition_idx)
     parallelized_model = PipelineParallel(model, num_microbatches, parallel_context).parallelize()
-    optim = SGD(parallelized_model.parameters(), LR)
 
     assert isinstance(parallelized_model, nn.Module)
     assert count_model_parameters(parallelized_model) < count_model_parameters(model)
     assert count_model_parameters(parallelized_model) == count_model_parameters(model[partition_idx])
 
-    outputs = parallelized_model(INPUTS)
+    parallelized_model(**INPUTS)
 
     assert forward_timeline == EXPECTED_FORWARD_TIMELINE
 
-    if is_last_stage(parallel_context):
-        assert torch.allclose(torch.cat(outputs, dim=0), REF_OUTPUTS)
+    # if is_last_stage(parallel_context):
+    #     assert torch.allclose(torch.cat(outputs, dim=0), REF_OUTPUTS)
 
-    optim.zero_grad()
-    for output in outputs:
-        output.sum().backward(retain_graph=True)
+    # optim.zero_grad()
+    # for output in outputs:
+    #     output.sum().backward(retain_graph=True)
 
-    optim.step()
+    # optim.step()
 
-    assert backward_timeline == EXPECTED_BACKWARD_TIMELINE
-    for p, ref_grad in zip(parallelized_model.parameters(), REF_GRADS[partition_idx]):
-        assert torch.allclose(p.grad, ref_grad)
+    # assert backward_timeline == EXPECTED_BACKWARD_TIMELINE
+    # for p, ref_grad in zip(parallelized_model.parameters(), REF_GRADS[partition_idx]):
+    #     assert torch.allclose(p.grad, ref_grad)
 
-    for p, ref_p in zip(parallelized_model.parameters(), UPDATED_MODEL[partition_idx].parameters()):
-        assert torch.allclose(p, ref_p)
+    # for p, ref_p in zip(parallelized_model.parameters(), UPDATED_MODEL[partition_idx].parameters()):
+    #     assert torch.allclose(p, ref_p)
 
 
-def test_pipeline_parallel():
-    TENSOR_PARALLEL_SIZE, PIPELINE_PARALLEL_SIZE, DATA_PARALLEL_SIZE = 1, 4, 1
-    WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE
+@pytest.mark.parametrize("pipeline_parallel_size", [4])
+def test_pipeline_parallel(pipeline_parallel_size):
+    TENSOR_PARALLEL_SIZE = 1
+    DATA_PARALLEL_SIZE = 1
+    WORLD_SIZE = TENSOR_PARALLEL_SIZE * pipeline_parallel_size * DATA_PARALLEL_SIZE
 
     BATCH_SIZE, NUM_MICROBATCHES = 32, 6
     SEQ_LEN, HIDDEN_DIM = 10, 5
     LR = 0.1
 
-    inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, requires_grad=False)
-    model = nn.ModuleList([nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(PIPELINE_PARALLEL_SIZE)])
-    ORIG_MODEL = deepcopy(model)
-    optim = SGD(model.parameters(), lr=LR)
+    # inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, requires_grad=False)
+    # model = nn.ModuleList([nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(pipeline_parallel_size)])
+    text = "Persistence is all you need."
+    texts = [text for _ in range(BATCH_SIZE)]
+    # model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
+    model = BloomForCausalLM(BloomConfig(n_layer=6))
+    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+    # ORIG_MODEL = deepcopy(model)
+
+    inputs = tokenizer(texts, return_tensors="pt")
+    # optim = SGD(model.parameters(), lr=LR)
 
-    outputs = reduce(lambda inputs, layer: layer(inputs), model, inputs)
+    # outputs = reduce(lambda inputs, layer: layer(inputs), model, inputs)
+    outputs = model(**inputs, labels=inputs["input_ids"])
 
-    optim.zero_grad()
-    outputs.sum().backward()
-    optim.step()
+    # optim.zero_grad()
+    # outputs.loss.sum().backward()
+    # optim.step()
 
-    grads = [[p.grad for p in layer.parameters()] for layer in model]
+    # grads = [[p.grad for p in layer.parameters()] for layer in model]
 
     kwargs = {
         "lr": LR,
-        "model": ORIG_MODEL,
+        "model": model,
+        "tokenizer": tokenizer,
         "updated_model": model,
-        "inputs": inputs.detach(),
-        "ref_outputs": outputs.detach(),
-        "ref_grads": grads,
+        "inputs": inputs,
+        "ref_logits": outputs.logits.detach(),
+        "ref_loss": outputs.loss.detach(),
+        # "ref_grads": grads,
    }
 
     spawn(
         run_pipeline_parallel,
         world_size=WORLD_SIZE,
         tensor_parallel_size=TENSOR_PARALLEL_SIZE,
-        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        pipeline_parallel_size=pipeline_parallel_size,
         data_parallel_size=DATA_PARALLEL_SIZE,
         num_microbatches=NUM_MICROBATCHES,
         kwargs=kwargs,
diff --git a/tests/optim/zero/test_sharding.py b/tests/optim/zero/test_sharding.py
index 1c61151..855303f 100644
--- a/tests/optim/zero/test_sharding.py
+++ b/tests/optim/zero/test_sharding.py
@@ -41,9 +41,6 @@ def calculate_total_sharded_elements(sharded_params):
     assert len(sharded_params) == world_size
 
     for rank, shard in enumerate(sharded_params):
-        if rank == 4:
-            assert 1 == 1
-
         assert isinstance(shard, list)
         for param_group in shard:
             assert len(param_group["params"]) > 0
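
Usage sketch (supplementary, not part of the patch above): a minimal, single-process illustration of how the reworked microbatch.split in pipegoose/nn/pipeline_parallel/microbatch.py is expected to behave after this change. The checkpoint name, batch size, and microbatch count below are illustrative assumptions chosen only for the example.

from transformers import AutoTokenizer

from pipegoose.nn.pipeline_parallel import microbatch

# Tokenize a small, evenly sized batch; "bigscience/bloom-560m" is only an example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
batch = tokenizer(["hello world from pipegoose"] * 8, padding=True, return_tensors="pt")

# split() requires the batch size to be divisible by n_microbatches and returns
# one dict of input_ids/attention_mask per microbatch.
microbatches = microbatch.split(
    {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]},
    n_microbatches=4,
)
assert len(microbatches) == 4
assert all(mb["input_ids"].size(0) == 2 for mb in microbatches)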