From aed599b4422b1cdf7397abb05a58c3726523a333 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Tue, 27 Feb 2024 05:34:09 -0800 Subject: [PATCH] Fix assertion to run pipeline engine with a compiled module (#5197) This PR fixes assertion in the pipeline engine to compute a compile module with pipeline parallelism. --- deepspeed/runtime/pipe/engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 05029e44d0e8..ef1c98a95c7b 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -66,7 +66,9 @@ class PipelineEngine(DeepSpeedEngine): def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): super().__init__(*super_args, **super_kwargs) - assert isinstance(self.module, PipelineModule), "model must base PipelineModule" + assert isinstance(self.module, PipelineModule) \ + or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \ + "model must base PipelineModule" assert self.zero_optimization_stage( ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"