Skip to content

Commit

Permalink
[Refactor] Remove sample input in model partitioning for
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Nov 27, 2023
1 parent ed433d1 commit 369263a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 3 additions & 1 deletion pipegoose/nn/pipeline_parallel/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

INPUT_NAMES = ["input_ids", "attention_mask"]


class PartitionPolicy(Enum):
UNIFORM = auto()
Expand Down Expand Up @@ -141,7 +143,7 @@ def _split_nodes(self, traced_graph_module: torch.fx.GraphModule, shard_count: i
node_name_to_shard_id[node.name] = shard_id
return node_name_to_shard_id, output_from_shard

def split(self, input_names: List[str]) -> List[nn.Module]:
def split(self, input_names: List[str] = INPUT_NAMES) -> List[nn.Module]:
n_partitions = self.parallel_context.pipeline_parallel_size
model = self.model
module_list: List[torch.fx.GraphModule] = []
Expand Down
10 changes: 4 additions & 6 deletions tests/nn/pipeline_parallel/test_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def run_model_partitioner(
model.eval()
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt")
gt_logits = model(input_ids=inputs["input_ids"]).logits
gt_logits = model(**inputs).logits

partitioned_model = UniformPartitioner(model, parallel_context).split(["input_ids"])
partitioned_model = UniformPartitioner(model, parallel_context).split()
assert (
len(partitioned_model) == pipeline_parallel_size
), f"Received model with {len(partitioned_model)} instead of {pipeline_parallel_size}"
Expand All @@ -71,14 +71,12 @@ def run_model_partitioner(
print("==================")
print("End printing partitioned model")

inputs = tokenizer(batch_sentences, padding=True, return_tensors="pt")

partitioned_model_result = inputs["input_ids"]
partitioned_model_result = inputs
for partition_id in range(pipeline_parallel_size):
if type(partitioned_model_result) in (list, tuple):
partitioned_model_result = partitioned_model[partition_id](*partitioned_model_result)
else:
partitioned_model_result = partitioned_model[partition_id](partitioned_model_result)
partitioned_model_result = partitioned_model[partition_id](**partitioned_model_result)

assert torch.allclose(gt_logits, partitioned_model_result), "Results are not close"

Expand Down

0 comments on commit 369263a

Please sign in to comment.