diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index ca3d1465ad27..a9b915d10485 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -65,7 +64,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: norm_cls = FusedLayerNorm else: norm_cls = LayerNorm - + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 81de08e29730..36491b4b5522 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union