diff --git a/pipegoose/nn/pipeline_parallel/partitioner.py b/pipegoose/nn/pipeline_parallel/partitioner.py
index 391c0f2..cc10abf 100644
--- a/pipegoose/nn/pipeline_parallel/partitioner.py
+++ b/pipegoose/nn/pipeline_parallel/partitioner.py
@@ -11,6 +11,8 @@
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
 
+INPUT_NAMES = ["input_ids", "attention_mask"]
+
 
 class PartitionPolicy(Enum):
     UNIFORM = auto()
@@ -141,7 +143,7 @@ def _split_nodes(self, traced_graph_module: torch.fx.GraphModule, shard_count: i
                 node_name_to_shard_id[node.name] = shard_id
         return node_name_to_shard_id, output_from_shard
 
-    def split(self, input_names: List[str]) -> List[nn.Module]:
+    def split(self, input_names: List[str] = INPUT_NAMES) -> List[nn.Module]:
         n_partitions = self.parallel_context.pipeline_parallel_size
         model = self.model
         module_list: List[torch.fx.GraphModule] = []
diff --git a/tests/nn/pipeline_parallel/test_partitioner.py b/tests/nn/pipeline_parallel/test_partitioner.py
index 9d2c22d..fafc38c 100644
--- a/tests/nn/pipeline_parallel/test_partitioner.py
+++ b/tests/nn/pipeline_parallel/test_partitioner.py
@@ -50,9 +50,9 @@ def run_model_partitioner(
     model.eval()
     tokenizer.pad_token = tokenizer.eos_token
     inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt")
-    gt_logits = model(input_ids=inputs["input_ids"]).logits
+    gt_logits = model(**inputs).logits
 
-    partitioned_model = UniformPartitioner(model, parallel_context).split(["input_ids"])
+    partitioned_model = UniformPartitioner(model, parallel_context).split()
     assert (
         len(partitioned_model) == pipeline_parallel_size
     ), f"Received model with {len(partitioned_model)} instead of {pipeline_parallel_size}"
@@ -71,14 +71,12 @@ def run_model_partitioner(
         print("==================")
     print("End printing partitioned model")
 
-    inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt")
-
-    partitioned_model_result = inputs["input_ids"]
+    partitioned_model_result = inputs
     for partition_id in range(pipeline_parallel_size):
         if type(partitioned_model_result) in (list, tuple):
             partitioned_model_result = partitioned_model[partition_id](*partitioned_model_result)
         else:
-            partitioned_model_result = partitioned_model[partition_id](partitioned_model_result)
+            partitioned_model_result = partitioned_model[partition_id](**partitioned_model_result)
 
     assert torch.allclose(gt_logits, partitioned_model_result), "Results are not close"
 
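
Usage note: a minimal sketch of what call sites look like after this patch, assuming a Hugging Face causal LM and an already-initialized `ParallelContext`. The "gpt2" checkpoint, the sample batch, and the `parallel_context` variable below are illustrative placeholders, not part of the diff; only the argument-free `split()` (defaulting to `INPUT_NAMES`) and the `**`-unpacked forwarding come from the change itself.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from pipegoose.nn.pipeline_parallel.partitioner import UniformPartitioner

    # Placeholder setup: any causal LM traceable by torch.fx should do.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(["a sample sentence"], padding=True, return_tensors="pt")

    # `parallel_context` is assumed to be constructed elsewhere (the test
    # fixture builds a ParallelContext per rank); it is not shown here.
    shards = UniformPartitioner(model, parallel_context).split()  # uses INPUT_NAMES

    # The first shard consumes the tokenizer output as keyword arguments
    # (input_ids, attention_mask); each later shard takes the previous
    # shard's output, unpacked positionally when it is a list/tuple.
    out = inputs
    for shard in shards:
        out = shard(*out) if isinstance(out, (list, tuple)) else shard(**out)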