From 2a2ec49aa70c95f73b6017624e32cdad6b36b0e1 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Thu, 30 Nov 2023 18:37:47 +0800
Subject: [PATCH] [plugin]fix 3d checkpoint load when booster boost without
 optimizer. (#5135)

* fix 3d checkpoint load when booster boost without optimizer

fix 3d checkpoint load when booster boost without optimizer

* test ci

* revert ci

* fix

fix

---
 colossalai/booster/plugin/hybrid_parallel_plugin.py   | 4 ++--
 tests/test_booster/test_plugin/test_gemini_plugin.py  | 3 +++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index ed3a61dede56..91fcba55a0aa 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -21,7 +21,7 @@
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x
 
 
-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 3c496ff64755..d4205e1f9d73 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -116,6 +116,9 @@ def check_gemini_plugin(
             "transformers_falcon_for_sequence_classification",
             "transformers_falcon_for_token_classification",
             "transformers_falcon_for_question_answering",
+            "transformers_gptj_lm",  # lead to OOM when running in ci
+            "transformers_gptj_for_question_answering",
+            "transformers_gptj_for_sequence_classification",
         ]:
             continue
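
For context, a minimal usage sketch of the scenario this patch addresses: boosting a model with HybridParallelPlugin without passing an optimizer, and then loading a sharded checkpoint. This sketch is not part of the patch; the launch call, parallel sizes, placeholder model, and checkpoint path are illustrative assumptions based on the ColossalAI Booster API around the time of this commit.

# Illustrative sketch (not part of the patch): the failure mode fixed here
# appears when booster.boost() is called without an optimizer and a checkpoint
# is loaded afterwards. Model, parallel sizes, and checkpoint path are
# placeholders; run under torchrun so launch_from_torch() can read the
# distributed environment.
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # example parallel config
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)  # placeholder model

# No optimizer is passed to boost(); before this patch the wrapped
# HybridParallelModule did not inherit AMPModelMixin, and the subsequent
# checkpoint load could fail in this optimizer-less setup.
model, *_ = booster.boost(model)

booster.load_model(model, "path/to/checkpoint")  # placeholder path

Making HybridParallelModule mix in AMPModelMixin appears intended to let the checkpoint I/O path treat it like other AMP-capable model wrappers even when no optimizer is attached; the added GPT-J entries in the Gemini plugin test are skipped only to avoid OOM in CI and are unrelated to that interface change.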