diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index 77f9a05f1..3c7e749a1 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -137,6 +137,9 @@ def __init__(self, config: LlamaConfig):
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.num_key_value_heads = config.num_key_value_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
         self.pre_attn_graph = None
         self.post_attn_graph = None
 
@@ -283,7 +286,8 @@ def _reorder_cache_from_bloom_to_llama(
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
         key_states = key_states.view(
-            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+            batch_size, self.num_key_value_heads // 2, seq_length, self.head_dim  # FIXME: "// 2" assumes 2-way tensor parallelism
+            # batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
         value_states = value_states.view(*key_states.shape)
         return (key_states, value_states)
@@ -291,6 +295,7 @@
     def _reorder_cache_from_llama_to_bloom(
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
+        raise NotImplementedError  # llama -> bloom cache reordering is currently unsupported here
         key_states, value_states = key_value
         value_states = value_states.view(
             batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
diff --git a/src/petals/models/llama/slicing.py b/src/petals/models/llama/slicing.py
new file mode 100644
index 000000000..66406ab86
--- /dev/null
+++ b/src/petals/models/llama/slicing.py
@@ -0,0 +1,97 @@
+"""
+Optimized configs for selected models. These configs are not necessary, but they can improve performance in some
+cases, e.g. training with very small batches or inference with long sequences.
+
+NB: some of these configs get fairly complicated in order to squeeze a bit of extra performance. When developing your
+    own config, you can get most of the performance benefits by using auto config -- and maybe splitting MLP layers.
+"""
+from functools import partial
+from itertools import chain
+from typing import Callable, Dict, Sequence
+
+import torch
+from transformers import PretrainedConfig, LlamaConfig
+
+from tensor_parallel.communications import CollectiveOperation
+from tensor_parallel.slicer_wrapper import Config
+from tensor_parallel.tensor_parallel import PerDeviceTensors
+
+ConfigGetter = Callable[[PretrainedConfig, Sequence[torch.device]], Config]
+
+
+def get_llama_config(model_config: LlamaConfig, devices: Sequence[torch.device]) -> Config:
+    assert model_config.model_type == "llama", f"Trying to pass {model_config.model_type} as llama config"
+
+    world_size = len(devices)
+    head_dim = model_config.hidden_size // model_config.num_attention_heads
+    num_kv = model_config.num_key_value_heads
+    q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+
+    gather_kv_across_ranks = CollectiveOperation(
+        world_size=world_size,
+        func=lambda *kvs: [PerDeviceTensors(*chain(*(x or [None] for x in kvs)))] * world_size,
+    )
+
+    select_kv_for_rank = lambda kvs, rank: (kvs[2 * rank], kvs[2 * rank + 1]) if kvs else None
+
+    config = Config(
+        state_rules={
+            # LlamaAttention
+            r".*self_attn\.q_proj\.weight$": partial(split_heads, dim=0, head_dim=q_per_kv * head_dim, world_size=world_size),
+            r".*self_attn\.k_proj\.weight$": partial(split_heads, dim=0, head_dim=head_dim, world_size=world_size),
+            r".*self_attn\.v_proj\.weight$": partial(split_heads, dim=0, head_dim=head_dim, world_size=world_size),
+            r".*self_attn\.o_proj\.weight$": partial(split_heads, dim=1, head_dim=q_per_kv * head_dim, world_size=world_size),
+            # LlamaFeedForward
+            r".*mlp\.gate_proj\.weight$": "split 0",
+            r".*mlp\.down_proj\.weight$": "split 1",
+            r".*mlp\.up_proj\.weight$": "split 0",
+            # LlamaModel
+            # r".*embed_tokens.weight$": "split 1",
+            # r".*lm_head\.weight$": "split 0",
+        },
+        input_rules={
+            r".*self_attn$": {"past_key_value": select_kv_for_rank},
+        },
+        output_rules={
+            r".*self_attn$": {0: "sum", 2: gather_kv_across_ranks},
+            r".*mlp$": {0: "sum"},
+            r".*embed_tokens$": {0: "gather -1"},
+            r".*lm_head$": {0: "gather -1"},
+        },
+        attr_rules={
+            r".*self_attn$": {
+                "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
+                "num_heads": partial(split_num_heads, world_size=world_size),
+                "num_key_value_heads": partial(split_num_heads, world_size=world_size),
+            }
+        },
+        # attr_rules={
+        #     r".*self_attn$": {
+        #         "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
+        #         "num_heads": lambda n, rank: q_per_kv * split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
+        #     }
+        # },
+    )
+
+    return config
+
+
+def split_heads(tensor: torch.Tensor, *, dim: int, head_dim: int, rank: int, world_size: int, optional: bool = False):
+    """Split a tensor along dim such that each part size is divisible by head_dim"""
+    if tensor is None and optional:
+        return None
+    assert tensor.shape[dim] % head_dim == 0, tensor.shape
+    if dim < 0:
+        dim = (tensor.ndim + dim) % tensor.ndim
+    shape = list(tensor.shape)
+    shape[dim] //= head_dim
+    shape.insert(dim + 1, head_dim)
+    tensor_part = tensor.reshape(shape).tensor_split(world_size, dim=dim)[rank].flatten(dim, dim + 1)
+    return tensor_part
+
+
+def split_num_heads(num_heads: int, *, rank: int, world_size: int):
+    return torch.empty(num_heads, device="meta").tensor_split(world_size)[rank].numel()
+
+
+def split_inner_dim(inner_dim: int, *, rank: int, num_heads: int, world_size: int):
+    return split_num_heads(num_heads=num_heads, rank=rank, world_size=world_size) * (inner_dim // num_heads)
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index 94d3e29f3..2dc4f8bb2 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -121,6 +121,9 @@ def make_tensor_parallel(
     if model_config.model_type == "bloom":
        tp_config = get_bloom_config(model_config, devices)
         del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    elif model_config.model_type == "llama":
+        from petals.models.llama.slicing import get_llama_config
+        tp_config = get_llama_config(model_config, devices)
     else:
         if len(devices) > 1:
             logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
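Below is a small standalone sanity check, not part of the patch above, for the head-aligned splitting that the new slicing.py introduces. It exercises split_heads, split_num_heads, and split_inner_dim on made-up LLaMA-like dimensions with grouped-query attention, and it assumes the new petals.models.llama.slicing module is importable (tensor_parallel is already a Petals dependency). The property being checked: query rows are sliced in blocks of q_per_kv heads, so each rank keeps whole GQA groups that line up with its slice of the KV heads.

import torch

from petals.models.llama.slicing import split_heads, split_inner_dim, split_num_heads

hidden_size, num_heads, num_kv, world_size = 64, 8, 4, 2   # toy numbers, not a real model
head_dim = hidden_size // num_heads                        # 8
q_per_kv = num_heads // num_kv                             # 2

q_proj = torch.randn(num_heads * head_dim, hidden_size)    # (64, 64)
k_proj = torch.randn(num_kv * head_dim, hidden_size)       # (32, 64)

# q_proj is split in blocks of q_per_kv * head_dim rows, k_proj in blocks of head_dim rows,
# so rank 0 ends up with 4 query heads and the 2 matching KV heads.
q_shard = split_heads(q_proj, dim=0, head_dim=q_per_kv * head_dim, rank=0, world_size=world_size)
k_shard = split_heads(k_proj, dim=0, head_dim=head_dim, rank=0, world_size=world_size)
assert q_shard.shape == (32, 64)
assert k_shard.shape == (16, 64)

# attr_rules produce the matching per-rank attention attributes.
assert split_num_heads(num_kv, rank=0, world_size=world_size) == 2
assert split_inner_dim(hidden_size, rank=0, num_heads=num_kv, world_size=world_size) == 32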
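On the cache side, get_llama_config flattens each rank's (key, value) pair into one tuple when gathering and reads it back with 2 * rank indexing in select_kv_for_rank. A rough illustration of just that index arithmetic follows; the PerDeviceTensors wrapper is left out and the string placeholders stand in for cached key/value tensors.

from itertools import chain

# Stand-ins for the (key, value) cache pairs produced by ranks 0 and 1.
kv_rank0 = ("k0", "v0")
kv_rank1 = ("k1", "v1")

# Same flattening as the gather_kv_across_ranks lambda, minus the PerDeviceTensors wrapper.
flat = tuple(chain(*(x or [None] for x in (kv_rank0, kv_rank1))))
assert flat == ("k0", "v0", "k1", "v1")

# select_kv_for_rank hands each rank back its own pair via 2 * rank indexing.
select_kv_for_rank = lambda kvs, rank: (kvs[2 * rank], kvs[2 * rank + 1]) if kvs else None
assert select_kv_for_rank(flat, 0) == ("k0", "v0")
assert select_kv_for_rank(flat, 1) == ("k1", "v1")
assert select_kv_for_rank(None, 0) is None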