diff --git a/README.md b/README.md index ea39fd7..2cb8df6 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,16 @@ To run logits equivalence through chunking and sliding window, launch python -m test_generate ``` +### Running large models + +When running models that are too large to fit in a single GPU's memory, use pipeline parallelism (PP) and `torchrun`. This is needed to run `Mixtral-8x7B`. The command below runs 2-way PP. + +``` +torchrun --nproc-per-node 2 -m main demo /path/to/mixtral-8x7B-v0.1/ --num_pipeline_ranks=2 +``` + +> [!NOTE] +> PP is not supported when running in interactive mode. # Sliding window attention @@ -112,6 +122,17 @@ For this we can choose as chunk size the window size. For each chunk, we thus ne ![Chunking](assets/chunking.png) +# Sparse Mixture of Experts (SMoE) + +Sparse Mixture of Experts allows one to decouple throughput from memory costs by only activating subsets of the overall model for each token. In this approach, each token is assigned to one or more "experts" -- a separate set of weights -- and only processed by such experts. This division happens at the feedforward layers of the model. The expert models specialize in different aspects of the data, allowing them to capture complex patterns and make more accurate predictions. + +![SMoE](assets/smoe.png) + +## Pipeline Parallelism + +Pipeline parallelism is a set of techniques for partitioning models, enabling the distribution of a large model across multiple GPUs. We provide a simple implementation of pipeline parallelism, which allows our larger models to be executed within the memory constraints of modern GPUs. Note that this implementation favours simplicity over throughput efficiency, and most notably does not include microbatching. + + ## Integrations and related projects diff --git a/assets/smoe.png b/assets/smoe.png new file mode 100644 index 0000000..8f0ebf6 Binary files /dev/null and b/assets/smoe.png differ diff --git a/main.py b/main.py index bffda3f..047a4fa 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ from mistral.cache import RotatingBufferCache +import logging import torch import fire from typing import List @@ -31,7 +32,7 @@ def sample(logits: torch.Tensor, temperature: float, top_p: float): @torch.inference_mode() -def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, max_tokens: int, chunk_size: int = None, temperature: float = 0.7): +def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, max_tokens: int, temperature: float, chunk_size: int = None): model = model.eval() B, V = len(prompts), model.args.vocab_size @@ -40,8 +41,16 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma seqlens = [len(x) for x in encoded_prompts] # Cache - cache_window = min(model.args.sliding_window, max(seqlens) + max_tokens) - cache = RotatingBufferCache(model.args.n_layers, model.args.max_batch_size, cache_window, model.args.n_kv_heads, model.args.head_dim) + cache_window = max(seqlens) + max_tokens + if model.args.sliding_window is not None and cache_window > model.args.sliding_window: + cache_window = model.args.sliding_window + cache = RotatingBufferCache( + model.n_local_layers, + model.args.max_batch_size, + cache_window, + model.args.n_kv_heads, + model.args.head_dim, + ) cache.to(device=model.device, dtype=model.dtype) cache.reset() @@ -81,6 +90,7 @@ def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, ma # decode generated_tokens = [] + assert last_token_prelogits is not None for i_token in
range(max_tokens): next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8) @@ -117,14 +127,25 @@ def interactive(model_path: str, max_tokens: int = 35, temperature: float = 0.7) print(res[0]) print("=====================") -def demo(model_path: str, max_tokens: int = 35, temperature: float = 0): + +def demo( + model_path: str, max_tokens: int = 35, temperature: float = 0, num_pipeline_ranks=1 +): + if num_pipeline_ranks > 1: + torch.distributed.init_process_group() + torch.cuda.set_device(torch.distributed.get_rank()) + should_print = torch.distributed.get_rank() == 0 + else: + should_print = True tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model")) - transformer = Transformer.from_folder(Path(model_path), max_batch_size=3) + transformer = Transformer.from_folder( + Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks + ) res, _logprobs = generate( [ "This is a test", - "This is another test", + "This is another great test", "This is a third test, mistral AI is very good at testing. ", ], transformer, @@ -132,11 +153,14 @@ def demo(model_path: str, max_tokens: int = 35, temperature: float = 0): max_tokens=max_tokens, temperature=temperature, ) - for x in res: - print(x) - print("=====================") + if should_print: + for x, l in zip(res, _logprobs): + print(x) + logging.debug('Logprobs: %s', l) + print("=====================") if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) fire.Fire({ "interactive": interactive, "demo": demo, diff --git a/mistral/model.py b/mistral/model.py index cc8203d..553bdb6 100644 --- a/mistral/model.py +++ b/mistral/model.py @@ -1,32 +1,41 @@ -import torch -from torch import nn +import json +import logging +import math from dataclasses import dataclass from pathlib import Path -import json from typing import List, Optional +import torch +from torch import nn +from simple_parsing.helpers import Serializable + from mistral.rope import precompute_freqs_cis, apply_rotary_emb from mistral.cache import CacheView, RotatingBufferCache +from mistral.moe import MoeArgs, MoeLayer -from xformers.ops.fmha import ( - memory_efficient_attention, -) +from xformers.ops.fmha import memory_efficient_attention @dataclass -class ModelArgs: +class ModelArgs(Serializable): dim: int n_layers: int head_dim: int hidden_dim: int n_heads: int n_kv_heads: int - sliding_window: int norm_eps: float vocab_size: int max_batch_size: int = 0 + # For rotary embeddings. If not set, will be inferred from sliding window. + rope_theta: Optional[float] = None + # If this is set, use sliding-window attention with a rotating cache. + sliding_window: Optional[int] = None + # If this is set, we will use MoE layers instead of dense layers.
+ moe: Optional[MoeArgs] = None + @dataclass class SimpleInputMetadata: @@ -36,9 +45,9 @@ class SimpleInputMetadata: @staticmethod def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata": return SimpleInputMetadata( - positions = torch.cat( - [torch.arange(0, seqlen) for seqlen in seqlens] - ).to(device=device, dtype=torch.long) + positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to( + device=device, dtype=torch.long + ) ) @@ -54,46 +63,30 @@ def __init__(self, args: ModelArgs): self.args = args self.n_heads: int = args.n_heads + self.head_dim: int = args.head_dim self.n_kv_heads: int = args.n_kv_heads - + self.repeats = self.n_heads // self.n_kv_heads - self.sliding_window = self.args.sliding_window self.scale = self.args.head_dim**-0.5 - self.wq = nn.Linear( - args.dim, - args.n_heads * args.head_dim, - bias=False - ) - self.wk = nn.Linear( - args.dim, - args.n_kv_heads * args.head_dim, - bias=False - ) - self.wv = nn.Linear( - args.dim, - args.n_kv_heads * args.head_dim, - bias=False - ) - self.wo = nn.Linear( - args.n_heads * args.head_dim, - args.dim, - bias=False - ) - + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) def forward( - self, x: torch.Tensor, + self, + x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView], ) -> torch.Tensor: seqlen_sum, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(seqlen_sum, self.n_heads, self.args.head_dim) - xk = xk.view(seqlen_sum, self.n_kv_heads, self.args.head_dim) - xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim) + xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) + xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) + xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) if cache is None: @@ -101,41 +94,35 @@ def forward( elif cache.prefill: key, val = cache.interleave_kv(xk, xv) cache.update(xk, xv) - else: + else: cache.update(xk, xv) key, val = cache.key, cache.value - key = key.view(seqlen_sum * cache.sliding_window, self.n_kv_heads, self.args.head_dim) - val = val.view(seqlen_sum * cache.sliding_window, self.n_kv_heads, self.args.head_dim) + key = key.view( + seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim + ) + val = val.view( + seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim + ) # Repeat keys and values to match number of query heads key, val = repeat_kv(key, val, self.repeats, dim=1) # xformers requires (B=1, S, H, D) xq, key, val = xq[None, ...], key[None, ...], val[None, ...] 
- output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask) + output = memory_efficient_attention( + xq, key, val, None if cache is None else cache.mask + ) - return self.wo(output.view_as(x)) + return self.wo(output.view(seqlen_sum, self.n_heads * self.head_dim)) class FeedForward(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.w1 = nn.Linear( - args.dim, - args.hidden_dim, - bias=False - ) - self.w2 = nn.Linear( - args.hidden_dim, - args.dim, - bias=False - ) - self.w3 = nn.Linear( - args.dim, - args.hidden_dim, - bias=False - ) + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) def forward(self, x) -> torch.Tensor: return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) @@ -161,11 +148,20 @@ def __init__(self, args: ModelArgs): self.n_heads = args.n_heads self.dim = args.dim self.attention = Attention(args) - self.feed_forward = FeedForward(args=args) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) self.args = args + self.feed_forward: nn.Module + if args.moe is not None: + self.feed_forward = MoeLayer( + experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)], + gate=nn.Linear(args.dim, args.moe.num_experts, bias=False), + moe_args=args.moe, + ) + else: + self.feed_forward = FeedForward(args=args) + def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView] ) -> torch.Tensor: @@ -177,77 +173,197 @@ def forward( class Transformer(nn.Module): - def __init__(self, args: ModelArgs): + def __init__( + self, + args: ModelArgs, + pipeline_rank: int = 0, + num_pipeline_ranks: int = 1, + ): super().__init__() self.args = args self.vocab_size = args.vocab_size self.n_layers = args.n_layers + self._precomputed_freqs_cis: Optional[torch.Tensor] = None assert self.vocab_size > 0 - - self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) - - self.layers = torch.nn.ModuleList( - [TransformerBlock(args=args) for _ in range(args.n_layers)] - ) - - self.norm = RMSNorm(args.dim, eps=args.norm_eps) - - self.output = nn.Linear( - args.dim, - args.vocab_size, - bias=False - ) - - self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000).to("cuda") + assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks) + self.pipeline_rank = pipeline_rank + self.num_pipeline_ranks = num_pipeline_ranks + # Modules specific to some ranks: + self.tok_embeddings: Optional[nn.Embedding] = None + self.norm: Optional[RMSNorm] = None + self.output: Optional[nn.Linear] = None + if pipeline_rank == 0: + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + if pipeline_rank == num_pipeline_ranks - 1: + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + # Initialize all layers but slice off those not of this rank. 
+ layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks) + offset = self.pipeline_rank * num_layers_per_rank + end = min(self.n_layers, offset + num_layers_per_rank) + self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)}) + self.n_local_layers = len(self.layers) @property def dtype(self) -> torch.dtype: - return self.tok_embeddings.weight.dtype + return next(self.parameters()).dtype @property def device(self) -> torch.device: - return self.tok_embeddings.weight.device + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + # We cache freqs_cis but need to take care that it is on the right device + # and has the right dtype (complex64). The fact that the dtype is different + # from the module's dtype means we cannot register it as a buffer + if self._precomputed_freqs_cis is None: + # If no sliding window, assume a larger seqlen + theta = self.args.rope_theta + if theta is None: + theta = 1000000.0 if self.args.sliding_window is None else 10000.0 + # theta = 10000. + self._precomputed_freqs_cis = precompute_freqs_cis( + self.args.head_dim, 128_000, theta + ) + if self._precomputed_freqs_cis.device != self.device: + self._precomputed_freqs_cis = self._precomputed_freqs_cis.to( + device=self.device + ) + return self._precomputed_freqs_cis def forward_partial( self, input_ids: torch.Tensor, seqlens: List[int], - cache: Optional[RotatingBufferCache]=None, + cache: Optional[RotatingBufferCache] = None, ) -> torch.Tensor: - assert len(seqlens) <= self.args.max_batch_size, f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}" - assert sum(seqlens) == input_ids.shape[0], (sum(seqlens), input_ids.shape[0]) + """Local forward pass. + + If doing pipeline parallelism, this will return the activations of the last layer of this stage. + For the last stage, this will return the normalized final embeddings. + """ + assert ( + len(seqlens) <= self.args.max_batch_size + ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}" + (num_toks,) = input_ids.shape + assert sum(seqlens) == num_toks, (sum(seqlens), num_toks) if cache is not None: input_metadata = cache.get_input_metadata(seqlens) else: input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device) - h = self.tok_embeddings(input_ids) + + if self.pipeline_rank == 0: + assert self.tok_embeddings is not None + h = self.tok_embeddings(input_ids) + else: + h = torch.empty( + num_toks, self.args.dim, device=self.device, dtype=self.dtype + ) + torch.distributed.recv(h, src=self.pipeline_rank - 1) + freqs_cis = self.freqs_cis[input_metadata.positions] - for layer_id, layer in enumerate(self.layers): - cache_view = None if cache is None else cache.get_view(layer_id, input_metadata) + for local_layer_id, layer in enumerate(self.layers.values()): + if cache is not None: + assert input_metadata is not None + cache_view = cache.get_view(local_layer_id, input_metadata) + else: + cache_view = None h = layer(h, freqs_cis, cache_view) - + if cache is not None: cache.update_seqlens(seqlens) - - return self.norm(h) + if self.pipeline_rank < self.num_pipeline_ranks - 1: + torch.distributed.send(h, dst=self.pipeline_rank + 1) + return h + else: + # Last rank has a final normalization step. 
+ assert self.norm is not None + return self.norm(h) def forward( self, input_ids: torch.Tensor, seqlens: List[int], - cache: Optional[RotatingBufferCache]=None, + cache: Optional[RotatingBufferCache] = None, ) -> torch.Tensor: - return self.output(self.forward_partial( - input_ids, seqlens, cache=cache - )).float() + h = self.forward_partial(input_ids, seqlens, cache=cache) + if self.pipeline_rank < self.num_pipeline_ranks - 1: + # ignore the intermediate activations as we'll get the final output from + # the last stage + outs = torch.empty( + h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype + ) + else: + assert self.output is not None + outs = self.output(h) + if self.num_pipeline_ranks > 1: + torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1) + return outs.float() + + def load_state_dict(self, state_dict, *args, **kwargs): + state_to_load = {} + skipped = set([]) + for k, v in state_dict.items(): + if k.startswith("tok_embeddings"): + if self.pipeline_rank == 0: + state_to_load[k] = v + else: + logging.debug( + "Skipping parameter %s at pipeline rank %d", + k, + self.pipeline_rank, + ) + skipped.add(k) + elif k.startswith("norm") or k.startswith("output"): + if self.pipeline_rank == self.num_pipeline_ranks - 1: + state_to_load[k] = v + else: + logging.debug( + "Skipping parameter %s at pipeline rank %d", + k, + self.pipeline_rank, + ) + skipped.add(k) + elif k.startswith("layers"): + layer_id = k.split(".")[1] + if layer_id in self.layers: + state_to_load[k] = v + else: + logging.debug( + "Skipping parameter %s at pipeline rank %d", + k, + self.pipeline_rank, + ) + skipped.add(k) + else: + raise ValueError(f"Unexpected key {k}") + assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys())) + super().load_state_dict(state_to_load, *args, **kwargs) @staticmethod - def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16) -> "Transformer": - with open(folder / 'params.json', 'r') as f: - model_args = ModelArgs(**json.loads(f.read())) + def from_folder( + folder: Path, + max_batch_size: int = 1, + num_pipeline_ranks: int = 1, + device="cuda", + dtype=torch.float16, + ) -> "Transformer": + with open(folder / "params.json", "r") as f: + model_args = ModelArgs.from_dict(json.load(f)) model_args.max_batch_size = max_batch_size - model = Transformer(model_args).to(device=device, dtype=dtype) - loaded = torch.load(folder / 'consolidated.00.pth') - model.load_state_dict(loaded) - return model + if num_pipeline_ranks > 1: + pipeline_rank = torch.distributed.get_rank() + else: + pipeline_rank = 0 + with torch.device("meta"): + model = Transformer( + model_args, + pipeline_rank=pipeline_rank, + num_pipeline_ranks=num_pipeline_ranks, + ) + loaded = torch.load(str(folder / "consolidated.00.pth"), mmap=True) + model.load_state_dict(loaded, assign=True) + return model.to(device=device, dtype=dtype) diff --git a/mistral/moe.py b/mistral/moe.py new file mode 100644 index 0000000..edee500 --- /dev/null +++ b/mistral/moe.py @@ -0,0 +1,34 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + + +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + + +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.args = 
moe_args + + def forward(self, inputs: torch.Tensor): + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + batch_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx] += weights[batch_idx, nth_expert, None] * expert( + inputs[batch_idx] + ) + return results diff --git a/mistral/rope.py b/mistral/rope.py index 4e9eea8..2c023d3 100644 --- a/mistral/rope.py +++ b/mistral/rope.py @@ -2,7 +2,7 @@ from typing import Tuple -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: +def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore diff --git a/moe_one_file_ref.py b/moe_one_file_ref.py new file mode 100644 index 0000000..4def8e4 --- /dev/null +++ b/moe_one_file_ref.py @@ -0,0 +1,541 @@ +import json +import logging +import math +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple + +import fire +import torch +from sentencepiece import SentencePieceProcessor +from simple_parsing.helpers import Serializable +from torch import nn + + +@dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + + +@dataclass +class ModelArgs(Serializable): + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + moe: MoeArgs + + max_batch_size: int = 0 + max_seq_len: int = 0 + + +def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int): + keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) + values = torch.repeat_interleave(values, repeats=repeats, dim=2) + return keys, values + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + self.repeats = self.n_heads // self.n_kv_heads + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) + self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) + self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, 
bias=False) + self._cache_k: Optional[torch.Tensor] = None + self._cache_v: Optional[torch.Tensor] = None + + def get_caches(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + dtype, device = x.dtype, x.device + assert x.shape[0] <= self.args.max_batch_size + assert x.shape[1] <= self.args.max_seq_len + if self._cache_k is None: + self._cache_k = torch.empty( + ( + self.args.max_batch_size, + self.args.max_seq_len, + self.n_kv_heads, + self.args.head_dim, + ), + dtype=dtype, + device=device, + ) + if self._cache_v is None: + self._cache_v = torch.empty( + ( + self.args.max_batch_size, + self.args.max_seq_len, + self.n_kv_heads, + self.args.head_dim, + ), + dtype=dtype, + device=device, + ) + return self._cache_k, self._cache_v + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + mask: Optional[torch.Tensor], + ) -> torch.Tensor: + bsz, seqlen, _ = x.shape + + cache_k, cache_v = self.get_caches(x) + + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # The cache is a rotating buffer + scatter_pos = (positions % self.args.max_seq_len)[None, :, None, None] + scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim) + cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk) + cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv) + + if positions.shape[0] > 1: + # prefill + key, value = repeat_kv(xk, xv, self.repeats) + else: + assert mask is None + cur_pos = int(positions[-1].item() + 1) + key, value = repeat_kv( + cache_k[:bsz, :cur_pos, ...], + cache_v[:bsz, :cur_pos, ...], + self.repeats, + ) + + query = xq.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + # scores : [bsz, n_heads, seqlen | 1, seqlen] + scores = torch.matmul(query, key.transpose(2, 3)) * self.scale + + if mask is not None: + scores += mask[None, None, ...] 
+ + scores = scores.float() + scores = nn.functional.softmax(scores, dim=-1).type_as(query) + output = torch.matmul(scores, value) # (bs, n_local_heads, slen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) + self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + + def forward(self, x) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.args = moe_args + + def forward(self, inputs: torch.Tensor): + inputs_squashed = inputs.view(-1, inputs.shape[-1]) + gate_logits = self.gate(inputs_squashed) + weights, selected_experts = torch.topk( + gate_logits, self.args.num_experts_per_tok + ) + weights = nn.functional.softmax( + weights, + dim=1, + dtype=torch.float, + ).type_as(inputs) + results = torch.zeros_like(inputs_squashed) + for i, expert in enumerate(self.experts): + batch_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx] += weights[batch_idx, nth_expert, None] * expert( + inputs_squashed[batch_idx] + ) + return results.view_as(inputs) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = MoeLayer( + experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)], + gate=nn.Linear(args.dim, args.moe.num_experts, bias=False), + moe_args=args.moe, + ) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + positions: torch.Tensor, + mask: Optional[torch.Tensor], + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), freqs_cis, positions, mask) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +def precompute_freqs_cis(dim: int, end: int) -> torch.Tensor: + theta = 1000000.0 + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +class Transformer(nn.Module): + def __init__( + self, + args: ModelArgs, + pipeline_rank: int = 0, + num_pipeline_ranks: int = 1, + ): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + assert self.vocab_size > 0 + assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks) + self.pipeline_rank = pipeline_rank + self.num_pipeline_ranks = num_pipeline_ranks + self._precomputed_freqs_cis: Optional[torch.Tensor] = None + + # Modules specific to 
some ranks: + self.tok_embeddings: Optional[nn.Embedding] = None + self.norm: Optional[RMSNorm] = None + self.output: Optional[nn.Linear] = None + if pipeline_rank == 0: + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + if pipeline_rank == num_pipeline_ranks - 1: + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + # Initialize all layers but slice off those not of this rank. + layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + num_layers_per_rank = math.ceil(args.n_layers / self.num_pipeline_ranks) + offset = self.pipeline_rank * num_layers_per_rank + end = min(args.n_layers, offset + num_layers_per_rank) + self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)}) + self.n_local_layers = len(self.layers) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def freqs_cis(self) -> torch.Tensor: + # We cache freqs_cis but need to take care that it is on the right device + # and has the right dtype (complex64). The fact that the dtype is different + # from the module's dtype means we cannot register it as a buffer + if self._precomputed_freqs_cis is None: + self._precomputed_freqs_cis = precompute_freqs_cis( + self.args.head_dim, 128_000 + ) + if self._precomputed_freqs_cis.device != self.device: + self._precomputed_freqs_cis = self._precomputed_freqs_cis.to( + device=self.device + ) + return self._precomputed_freqs_cis + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ): + freqs_cis = self.freqs_cis[positions] + + (bsz, seqlen) = input_ids.shape + num_toks = bsz * seqlen + + if self.pipeline_rank == 0: + assert self.tok_embeddings is not None + h = self.tok_embeddings(input_ids) + assert h.shape == (bsz, seqlen, self.args.dim) + assert h.dtype == self.dtype + else: + h = torch.empty( + bsz, seqlen, self.args.dim, device=self.device, dtype=self.dtype + ) + torch.distributed.recv(h, src=self.pipeline_rank - 1) + + mask: Optional[torch.Tensor] = None + if input_ids.shape[1] > 1: + tensor = torch.full( + (seqlen, seqlen), + dtype=h.dtype, + fill_value=1, + device=h.device, + ) + mask = torch.log(torch.tril(tensor, diagonal=0)).to(h.dtype) + + for layer in self.layers.values(): + h = layer(h, freqs_cis, positions, mask) + + if self.pipeline_rank < self.num_pipeline_ranks - 1: + torch.distributed.send(h, dst=self.pipeline_rank + 1) + outs = torch.empty( + *h.shape[:-1], self.vocab_size, device=h.device, dtype=h.dtype + ) + else: + assert self.output is not None + assert self.norm is not None + outs = self.output(self.norm(h)) + if self.num_pipeline_ranks > 1: + torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1) + return outs.float() + + def load_state_dict(self, state_dict, *args, **kwargs): + state_to_load = {} + skipped = set([]) + for k, v in state_dict.items(): + if k.startswith("tok_embeddings"): + if self.pipeline_rank == 0: + state_to_load[k] = v + else: + logging.debug( + "Skipping parameter %s at pipeline rank %d", + k, + self.pipeline_rank, + ) + skipped.add(k) + elif k.startswith("norm") or k.startswith("output"): + if self.pipeline_rank == self.num_pipeline_ranks - 1: + state_to_load[k] = v + else: + logging.debug( + "Skipping parameter %s at pipeline rank %d", + k, + self.pipeline_rank, + ) + skipped.add(k) + elif k.startswith("layers"): + layer_id = k.split(".")[1] + if layer_id 
in self.layers: + state_to_load[k] = v + else: + logging.debug( + "Skipping parameter %s at pipeline rank %d", + k, + self.pipeline_rank, + ) + skipped.add(k) + else: + raise ValueError(f"Unexpected key {k}") + assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys())) + super().load_state_dict(state_to_load, *args, **kwargs) + + @staticmethod + def from_folder( + folder: Path, + max_batch_size: int, + max_seq_len: int, + num_pipeline_ranks: int = 1, + device="cuda", + dtype=torch.float16, + ) -> "Transformer": + with open(folder / "params.json", "r") as f: + model_args = ModelArgs.from_dict(json.load(f)) + model_args.max_batch_size = max_batch_size + model_args.max_seq_len = max_seq_len + if num_pipeline_ranks > 1: + pipeline_rank = torch.distributed.get_rank() + else: + pipeline_rank = 0 + with torch.device("meta"): + model = Transformer( + model_args, + pipeline_rank=pipeline_rank, + num_pipeline_ranks=num_pipeline_ranks, + ) + loaded = torch.load(str(folder / "consolidated.00.pth"), mmap=True) + model.load_state_dict(loaded, assign=True) + return model.to(device=device, dtype=dtype) + + +class Tokenizer: + def __init__(self, model_path: str): + assert Path(model_path).exists(), model_path + self._model = SentencePieceProcessor(model_file=model_path) + assert self._model.vocab_size() == self._model.get_piece_size() + + @property + def eos_id(self) -> int: + return self._model.eos_id() + + @property + def pad_id(self) -> int: + return self._model.pad_id() + + def encode(self, s: str) -> List[int]: + return [self._model.bos_id(), *self._model.encode(s)] + + def decode(self, t: List[int]) -> str: + return self._model.decode(t) + + +@torch.no_grad() +def generate( + prompts: List[str], model: Transformer, tokenizer: Tokenizer, max_tokens: int +): + encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts] + prompt_lens = [len(x) for x in encoded_prompts] + min_prompt_len = min(prompt_lens) + max_prompt_len = max(prompt_lens) + + input_tokens = torch.full( + (len(prompts), max_prompt_len), + tokenizer.pad_id, + dtype=torch.long, + device="cuda", + ) + for i, encoded in enumerate(encoded_prompts): + input_tokens[i, : len(encoded)] = torch.tensor(encoded).to(input_tokens) + input_mask = input_tokens != tokenizer.pad_id + + # pre-fill + positions = torch.arange(0, min_prompt_len).to("cuda") + logits = model.forward(input_tokens[:, :min_prompt_len], positions) + logprobs = nn.functional.log_softmax(logits, dim=-1) + + # decode + generated = [] + all_logprobs = [ + logprobs[:, :-1, :] + .gather(2, input_tokens[:, 1:min_prompt_len, None]) + .squeeze(-1), + ] + for cur_pos in range(min_prompt_len, max_tokens): + next_token = torch.argmax(logprobs[:, -1, :], dim=-1) + if cur_pos < input_mask.shape[1]: + next_token = torch.where( + input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token + ) + all_logprobs.append( + logprobs[:, -1, :].gather(1, next_token[:, None]), + ) + generated.append(next_token[:, None]) + logits = model.forward( + next_token[:, None], torch.LongTensor([cur_pos]).to(next_token) + ) + logprobs = nn.functional.log_softmax(logits, dim=-1) + + all_logprobs_merged = torch.cat(all_logprobs, 1) + res = [] + if max_tokens > 0: + generated = torch.cat(generated, 1) + for i, x in enumerate(encoded_prompts): + res.append(tokenizer.decode(x[:min_prompt_len] + generated[i].tolist())) + return res, all_logprobs_merged + + +def demo(model_path: str, max_tokens: int = 30, num_pipeline_ranks=2): + if num_pipeline_ranks > 1: + torch.distributed.init_process_group() + 
torch.cuda.set_device(torch.distributed.get_rank()) + should_print = torch.distributed.get_rank() == 0 + else: + should_print = True + + tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model")) + transformer = Transformer.from_folder( + Path(model_path), + max_batch_size=3, + max_seq_len=max_tokens, + num_pipeline_ranks=num_pipeline_ranks, + ) + + res, logprobs = generate( + [ + "This is a test", + "This is another great test", + "This is a third test, mistral AI is very good at testing. ", + ], + transformer, + tokenizer, + max_tokens=max_tokens, + ) + if should_print: + for x, l in zip(res, logprobs): + print(x) + logging.debug("logprobs: %s", l) + print("=====================") + + +if __name__ == "__main__": + fire.Fire(demo) diff --git a/requirements.txt b/requirements.txt index a76f2e5..a72226f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ fire sentencepiece -torch -xformers \ No newline at end of file +torch>=2.1.0 +xformers +simple-parsing \ No newline at end of file
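
The top-k expert routing implemented by `MoeLayer` in `mistral/moe.py` can be illustrated with a minimal, self-contained sketch. The tiny dimensions below are hypothetical, and plain `nn.Linear` modules stand in for the real `FeedForward` experts.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical toy sizes; the real model uses FeedForward experts and larger dims.
dim, num_experts, num_experts_per_tok, num_tokens = 8, 4, 2, 5
gate = nn.Linear(dim, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(num_experts)])

x = torch.randn(num_tokens, dim)               # tokens already flattened to (tokens, dim)
gate_logits = gate(x)                          # (tokens, num_experts)
weights, selected = torch.topk(gate_logits, num_experts_per_tok)  # top-2 experts per token
weights = F.softmax(weights, dim=1, dtype=torch.float).to(x.dtype)

out = torch.zeros_like(x)
for i, expert in enumerate(experts):
    # tokens routed to expert i, and whether it was their 1st or 2nd choice
    token_idx, kth = torch.where(selected == i)
    out[token_idx] += weights[token_idx, kth, None] * expert(x[token_idx])

print(out.shape)  # torch.Size([5, 8])
```

Each token only pays for `num_experts_per_tok` expert forward passes, which is what keeps per-token compute low even though the total parameter count grows with the number of experts.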
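The pipeline-parallel `Transformer.__init__` (in both `mistral/model.py` and `moe_one_file_ref.py`) gives each rank a contiguous block of layers via a ceiling split. A small sketch of that partitioning logic, with hypothetical layer and rank counts:

```python
import math
from typing import List

def local_layer_ids(n_layers: int, pipeline_rank: int, num_pipeline_ranks: int) -> List[int]:
    """Layer indices owned by one pipeline rank (ceil split, mirroring Transformer.__init__)."""
    num_layers_per_rank = math.ceil(n_layers / num_pipeline_ranks)
    offset = pipeline_rank * num_layers_per_rank
    end = min(n_layers, offset + num_layers_per_rank)
    return list(range(offset, end))

# Hypothetical 32-layer model split over 2 ranks:
print(local_layer_ids(32, 0, 2))  # layers 0..15 live on rank 0
print(local_layer_ids(32, 1, 2))  # layers 16..31 live on rank 1
```

Only the first rank keeps the token embeddings and only the last keeps the final norm and output projection, which is why `load_state_dict` skips the parameters that do not belong to the current rank.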
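The updated `generate()` in `main.py` sizes the rotating KV cache to the full prompt-plus-generation length, capped at the sliding window when the model configuration defines one. A sketch of that sizing rule with made-up lengths:

```python
from typing import List, Optional

def cache_window(seqlens: List[int], max_tokens: int, sliding_window: Optional[int]) -> int:
    # Longest prompt plus everything we will generate...
    window = max(seqlens) + max_tokens
    # ...capped at the sliding window when the model defines one.
    if sliding_window is not None and window > sliding_window:
        window = sliding_window
    return window

print(cache_window([12, 30], max_tokens=35, sliding_window=4096))   # 65: whole sequence fits
print(cache_window([12, 30], max_tokens=35, sliding_window=None))   # 65: no sliding window set
print(cache_window([5000, 30], max_tokens=35, sliding_window=4096)) # 4096: capped at the window
```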