diff --git a/colossalai/shardformer/modeling/gemma2.py b/colossalai/shardformer/modeling/gemma2.py
new file mode 100644
index 000000000000..75e46c41c912
--- /dev/null
+++ b/colossalai/shardformer/modeling/gemma2.py
@@ -0,0 +1,299 @@
+from typing import List, Optional
+
+import torch
+import torch.distributed
+import torch.utils.checkpoint
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import gather_sp_output
+from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
+from colossalai.shardformer.shard import ShardConfig
+
+from ..layer import RingAttention, dist_cross_entropy
+
+_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
+
+
+class Gemma2PipelineForwards:
+    """
+    This class serves as a micro library for forward function substitution of Gemma2 models
+    under the pipeline setting.
+    """
+
+    @staticmethod
+    def gemma2_model_forward(
+        self: Gemma2Model,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
+        force_sp_gather: bool = True,  # Set to False only when computing cross entropy
+    ):
+        logger = logging.get_logger(__name__)
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with pipeline parallelism. Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        disable_pp = stage_manager is None
+        # retrieve input_ids and inputs_embeds
+        if disable_pp or stage_manager.is_first_stage():
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape[:2]
+            elif inputs_embeds is not None:
+                batch_size, seq_length = inputs_embeds.shape[:2]
+            else:
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
+            if inputs_embeds is None:
+                inputs_embeds = self.embed_tokens(input_ids)
+            hidden_states = inputs_embeds
+            device = hidden_states.device
+        else:
+            input_shape = hidden_states.shape[:-1]
+            batch_size, seq_length = input_shape
+            device = hidden_states.device
+
+        # Support SP + PP
+        sp_mode = shard_config.sequence_parallelism_mode
+        sp_size = shard_config.sequence_parallel_size
+        # Generate full position ids for modes that gather the sequence before attn
+        if stage_manager and (sp_mode != "ring_attn" and not stage_manager.is_first_stage()):
+            seq_length *= sp_size
+
+        past_seen_tokens = 0
+        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
+
+        if output_attentions:
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+            output_hidden_states = False
+        if use_cache:
+            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
+            use_cache = False
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        attn_kwargs: torch.Tensor = self._update_causal_mask(
+            attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+        )
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+        start_idx, end_idx = (0, len(self.layers)) if disable_pp else (stage_index[0], stage_index[1])
+
+        num_ckpt_layers = 0
+        if self.gradient_checkpointing and self.training:
+            num_ckpt_layers = end_idx - start_idx
+            # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+            if shard_config.gradient_checkpoint_config is not None:
+                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+                    stage=stage_manager.stage,
+                    num_stages=stage_manager.num_stages,
+                    num_layers=end_idx - start_idx,
+                    model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+                    num_model_chunks=stage_manager.num_model_chunks,
+                )
+            assert num_ckpt_layers <= end_idx - start_idx
+
+        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if idx - start_idx < num_ckpt_layers:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attn_kwargs,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attn_kwargs,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        if disable_pp or stage_manager.is_last_stage():
+            hidden_states = self.norm(hidden_states)
+            if (not shard_config.parallel_output) or force_sp_gather or is_share_sp_tp(sp_mode):  # noqa
+                hidden_states = gather_sp_output(hidden_states, shard_config)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        next_cache = next_decoder_cache if use_cache else None
+        if disable_pp or stage_manager.is_last_stage():
+            if not return_dict:
+                return tuple(
+                    v
+                    for v in [
+                        hidden_states,
+                        next_cache,
+                        all_hidden_states,
+                        all_self_attns,
+                    ]
+                    if v is not None
+                )
+            return BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=next_cache,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+            )
+        # always return dict for intermediate stage
+        return {"hidden_states": hidden_states}
+
+    @staticmethod
+    def gemma2_for_causal_lm_forward(
+        self: Gemma2ForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        shard_config: ShardConfig = None,
+        **kwargs,
+    ):
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
+
+        >>> model = Gemma2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        logger = logging.get_logger(__name__)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # TODO(jianghai): recorded kv-value tensors are left as () or None; this feature may be added in the future.
+        if output_attentions:
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+            output_hidden_states = False
+
+        if shard_config.sequence_parallelism_mode == "ring_attn" and shard_config.parallel_output:
+            # Split labels in a zigzag fashion too
+            sp_group = shard_config.sequence_parallel_process_group
+            if attention_mask.bool().all():
+                labels = split_batch_zigzag(labels, sp_group, seq_dim=1, is_label=True)
+            else:
+                # [B, max_seqlen // sp_size]
+                labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True)
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = Gemma2PipelineForwards.gemma2_model_forward(
+            self.model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+            shard_config=shard_config,
+            force_sp_gather=False,
+        )
+        past_key_values = None
+
+        disable_pp = stage_manager is None
+        if disable_pp or stage_manager.is_last_stage():
+            hidden_states = outputs[0]
+            logits = self.lm_head(hidden_states)
+            loss = None
+            if labels is not None:
+                loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
+
+            if not return_dict:
+                output = (logits,) + outputs[1:]
+                return (loss,) + output if loss is not None else output
+
+            return CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            hidden_states = outputs.get("hidden_states")
+            return {"hidden_states": hidden_states}
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 9a0da82f55a5..5309bcd6df7a 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -141,7 +141,9 @@ def llama_model_forward(
                 invert=(sp_mode != "ring_attn"),
             )
         else:
-            attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
+            attn_kwargs: torch.Tensor = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+            )
 
         # Support SP + PP. Later stages have already received the split input.
         split_input = disable_pp or stage_manager.is_first_stage()
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 7b9c759a66c2..3994611f7054 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -220,6 +220,10 @@ class PolicyLocation:
     "transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
         file_name="command", class_name="CommandForCausalLMPolicy"
     ),
+    # gemma2
+    "transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM": PolicyLocation(
+        file_name="gemma2", class_name="Gemma2ForCausalLMPolicy"
+    ),
 }
diff --git a/colossalai/shardformer/policies/gemma2.py b/colossalai/shardformer/policies/gemma2.py
new file mode 100644
index 000000000000..3e8815751acb
--- /dev/null
+++ b/colossalai/shardformer/policies/gemma2.py
@@ -0,0 +1,153 @@
+from functools import partial
+from typing import Dict, Union
+
+import torch.nn as nn
+
+from colossalai.shardformer.layer import (
+    Linear1D_Col,
+    Linear1D_Row,
+    PaddingEmbedding,
+    PaddingLMHead,
+    RMSNorm,
+    VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
+)
+
+from ..modeling.gemma2 import Gemma2PipelineForwards
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ["Gemma2Policy", "Gemma2ForCausalLMPolicy"]
+
+
+class Gemma2Policy(Policy):
+    def config_sanity_check(self):
+        pass
+
+    def preprocess(self):
+        self.tie_weight = self.tie_weight_check()
+        return self.model
+
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2Model
+
+        policy = {}
+
+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        norm_cls = RMSNorm
+
+        if self.shard_config.enable_tensor_parallelism:
+            tp_size = self.shard_config.tensor_parallel_size
+            num_q_heads = self.model.config.num_attention_heads // tp_size
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // tp_size,
+                "self_attn.num_heads": num_q_heads,
+            }
+            num_kv_heads = self.model.config.num_key_value_heads // tp_size
+            decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
+            policy[Gemma2DecoderLayer] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
+                    SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
+                    SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=Linear1D_Row,
+                    ),
+                ],
+            )
+
+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=Gemma2Model,
+            )
+
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(suffix="input_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="pre_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_feedforward_layernorm", target_module=norm_cls),
+                SubModuleReplacementDescription(suffix="post_attention_layernorm", target_module=norm_cls),
+            ],
+            policy=policy,
+            target_key=Gemma2DecoderLayer,
+        )
+
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=Gemma2Model,
+        )
+        return policy
+
+    def postprocess(self):
+        return self.model
+
+
+class Gemma2ForCausalLMPolicy(Gemma2Policy):
+    def module_policy(self):
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=VocabParallelLMHead1D,
+                    kwargs=dict(
+                        gather_output=not self.shard_config.parallel_output,
+                        make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                    ),
+                ),
+                policy=policy,
+                target_key=Gemma2ForCausalLM,
+            )
+            if self.shard_config.parallel_output:
+                method_replacement = {
+                    "forward": partial(
+                        Gemma2PipelineForwards.gemma2_for_causal_lm_forward, shard_config=self.shard_config
+                    )
+                }
+                self.append_or_create_method_replacement(
+                    description=method_replacement, policy=policy, target_key=Gemma2ForCausalLM
+                )
+        else:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=PaddingLMHead,
+                    kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
+                ),
+                policy=policy,
+                target_key=Gemma2ForCausalLM,
+            )
+
+        return policy
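
Usage sketch (not part of the diff): with the `auto_policy` registration above, a Gemma2 checkpoint should be shardable through the usual `ShardFormer` entry point, with `Gemma2ForCausalLMPolicy` resolved automatically. The snippet below is a minimal illustration assuming a 2-GPU launch via `torchrun`; the checkpoint name and parallel settings are placeholders, and the exact `ShardConfig` arguments may differ across ColossalAI versions.

```python
# Minimal sketch, assuming: torchrun --nproc_per_node 2 shard_gemma2.py
# The checkpoint name and parallel settings are illustrative placeholders.
import colossalai
import torch.distributed as dist
from transformers import Gemma2ForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch()

model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-2b")  # any Gemma2 checkpoint

# Tensor-parallel sharding; the Gemma2ForCausalLMPolicy added in this diff is
# picked up automatically through the auto_policy registration.
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,
    enable_tensor_parallelism=True,
)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)
```

This sketch only exercises the tensor-parallel path; pipeline or hybrid setups would normally go through the booster API (e.g. `HybridParallelPlugin`) rather than constructing `ShardConfig` by hand.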