From 93465219133a5e4f521aaeca1228abd3319cb051 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Oct 2024 13:28:29 +0000 Subject: [PATCH 1/4] WIP --- src/mistral_inference/args.py | 8 +- src/mistral_inference/cache.py | 141 +++++++++++++++++++-------- src/mistral_inference/generate.py | 1 + src/mistral_inference/transformer.py | 11 ++- 4 files changed, 112 insertions(+), 49 deletions(-) diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py index a94a2c6..4be2470 100644 --- a/src/mistral_inference/args.py +++ b/src/mistral_inference/args.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, List from simple_parsing.helpers import Serializable @@ -39,12 +39,18 @@ class TransformerArgs(Serializable): moe: Optional[MoeArgs] = None # If this is set, we will load LoRA linear layers instead of linear layers. lora: Optional[LoraArgs] = None + sliding_window: Optional[int] | Optional[List[int]] = None + _sliding_window: Optional[int] | Optional[List[int]] = None model_type: str = "transformer" vision_encoder: Optional[VisionEncoderArgs] = None def __post_init__(self) -> None: assert self.model_type == "transformer", self.model_type + assert self.sliding_window is None or self._sliding_window is None + + # hack for now so that vLLM is supported correctly + self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window @dataclass diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py index 93cfb1c..3cddb4f 100644 --- a/src/mistral_inference/cache.py +++ b/src/mistral_inference/cache.py @@ -10,13 +10,39 @@ ) +def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[int] | Optional[List[int]]) -> List[int]: + if sliding_window is None: + return n_layers * [max_seq_len] + elif isinstance(sliding_window, int): + return n_layers * [sliding_window] + else: + assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}" + assert n_layers % len(sliding_window) == 0, f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}" + num_repeats = n_layers // len(sliding_window) + return num_repeats * [w if w is not None else max_seq_len for w in sliding_window] + + + @dataclass class CacheInputMetadata: + # # rope absolute positions + # positions: torch.Tensor + # # where tokens should go in the cache + # cache_positions: torch.Tensor + + # # if prefill, use block diagonal causal mask + # # else use causal with padded key mask + # prefill: bool + # mask: AttentionBias + # seqlens: List[int] # rope absolute positions positions: torch.Tensor + # which elements in the sequences need to be cached + to_cache_mask: torch.Tensor + # how many elements are cached per sequence + cached_elements: torch.Tensor # where tokens should go in the cache cache_positions: torch.Tensor - # if prefill, use block diagonal causal mask # else use causal with padded key mask prefill: bool @@ -29,6 +55,17 @@ def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torc return [v for pair in zip(l1, l2) for v in pair] +def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor: + assert cache.ndim == 3 # (W, H, D) + position = seqlen % cache.shape[0] + if seqlen < cache.shape[0]: + return cache[:seqlen] + elif position == 0: + return cache + else: + return torch.cat([cache[position:], cache[:position]], dim=0) + + class CacheView: def __init__( self, @@ -50,8 +87,8 @@ def update(self, xk: 
torch.Tensor, xv: torch.Tensor) -> None: flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim) flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim) - flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk) - flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv) + flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask]) + flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask]) def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -69,15 +106,16 @@ def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tenso xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}" - # Retrieve cache - cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)] - cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)] + # Order elements in cache by position by unrotating + cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)] + cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)] - interleaved_k = interleave_list(cache_k, list(xk)) - interleaved_v = interleave_list(cache_v, list(xv)) + interleaved_k = interleave_list(cache_k, xk) + interleaved_v = interleave_list(cache_v, xv) return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0) + @property def max_seq_len(self) -> int: return self.cache_k.shape[1] @@ -112,13 +150,23 @@ def __init__( max_seq_len: int, n_kv_heads: int, head_dim: int, + sliding_window: Optional[int] | Optional[List[int]] = None ): + print(f"yeeeees {sliding_window}") self.max_seq_len = max_seq_len self.n_kv_heads = n_kv_heads self.head_dim = head_dim + self.n_layers = n_layers + + self.cache_sizes: List[int] = get_cache_sizes(n_layers, max_seq_len, sliding_window) + assert len(self.cache_sizes) == n_layers, f"Expected {n_layers} cache sizes, got {len(self.cache_sizes)}" + + self.cache_k = {} + self.cache_v = {} + for i, cache_size in enumerate(self.cache_sizes): + self.cache_k[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim)) + self.cache_v[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim)) - self.cache_k = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim)) - self.cache_v = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim)) # holds the valid length for each batch element in the cache self.kv_seqlens: Optional[torch.Tensor] = None @@ -134,11 +182,12 @@ def init_kvseqlens(self, batch_size: int) -> None: @property def device(self) -> torch.device: - return self.cache_k.device + return self.cache_k[0].device def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache": - self.cache_k = self.cache_k.to(device=device, dtype=dtype) - self.cache_v = self.cache_v.to(device=device, dtype=dtype) + for i in range(self.n_layers): + self.cache_k[i] = self.cache_k[i].to(device=device, dtype=dtype) + self.cache_v[i] = self.cache_v[i].to(device=device, dtype=dtype) return self @@ -146,55 +195,61 @@ def update_seqlens(self, seqlens: List[int]) -> None: assert self.kv_seqlens is not None self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long) - def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata: + def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]: 
""" - Get metadata about cache positions + input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3 + --> only cache last 3 tokens in each sequence + - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1] + - cached_elements = [3 | 3 | 2] + --> absolute positions are used for rope + - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4] + --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window + - cache_positions = [2 0 1 | 5 3 4 | 6 7] """ + metadata: List[CacheInputMetadata] = [] + if self.kv_seqlens is None: self.init_kvseqlens(len(seqlens)) - - assert isinstance(self.kv_seqlens, torch.Tensor) - assert len(seqlens) == len( - self.kv_seqlens - ), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?" + assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?" seqpos = self.kv_seqlens.tolist() - assert len(seqlens) > 0, seqlens - cached_elements = torch.tensor(seqlens, device=self.device, dtype=torch.long) - - positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to( - device=self.device, dtype=torch.long - ) - - batch_idx = torch.tensor( - sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), - device=self.device, - dtype=torch.long, - ) - cache_positions = positions + batch_idx * self.max_seq_len + for cache_size in self.cache_sizes: + metadata.append(self._get_input_metadata_layer(cache_size, seqlens, seqpos)) + + return metadata + + def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata: + masks = [ + [x >= seqlen - cache_size for x in range(seqlen)] + for seqlen in seqlens + ] + to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool) + cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long) + positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long) + batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long) + cache_positions = positions % cache_size + batch_idx * cache_size first_prefill = seqpos[0] == 0 subsequent_prefill = any(seqlen > 1 for seqlen in seqlens) if first_prefill: - assert all([pos == 0 for pos in seqpos]), seqpos - mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len) + assert all([pos == 0 for pos in seqpos]), (seqpos) + mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size) elif subsequent_prefill: mask = BlockDiagonalMask.from_seqlens( q_seqlen=seqlens, - kv_seqlen=[ - s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens) - ], - ).make_local_attention_from_bottomright(self.max_seq_len) + kv_seqlen=[s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)] + ).make_local_attention_from_bottomright(cache_size) else: mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( q_seqlen=seqlens, - kv_padding=self.max_seq_len, - kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(), + kv_padding=cache_size, + kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist() ) - return CacheInputMetadata( positions=positions, - cache_positions=cache_positions, + to_cache_mask=to_cache_mask, + 
cached_elements=cached_elements, + cache_positions=cache_positions[to_cache_mask], prefill=first_prefill or subsequent_prefill, mask=mask, seqlens=seqlens, diff --git a/src/mistral_inference/generate.py b/src/mistral_inference/generate.py index 1e906b3..bc6112d 100644 --- a/src/mistral_inference/generate.py +++ b/src/mistral_inference/generate.py @@ -72,6 +72,7 @@ def generate( cache_window, model.args.n_kv_heads, model.args.head_dim, + model.args.sliding_window, ) cache.to(device=model.device, dtype=model.dtype) cache.reset() diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py index 9c9aebe..c4aa897 100644 --- a/src/mistral_inference/transformer.py +++ b/src/mistral_inference/transformer.py @@ -150,12 +150,12 @@ def forward_partial( (num_toks,) = input_ids.shape assert sum(seqlens) == num_toks, (sum(seqlens), num_toks) - input_metadata: Union[CacheInputMetadata, SimpleInputMetadata] + input_metadata: List[Union[CacheInputMetadata, SimpleInputMetadata]] if cache is not None: input_metadata = cache.get_input_metadata(seqlens) else: - input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device) + input_metadata = [SimpleInputMetadata.from_seqlens(seqlens, self.device) for _ in range(len(self.layers))] if self.pipeline_rank == 0: assert self.tok_embeddings is not None @@ -167,13 +167,14 @@ def forward_partial( h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype) torch.distributed.recv(h, src=self.pipeline_rank - 1) - freqs_cis = self.freqs_cis[input_metadata.positions] + # freqs_cis is always the same for every layer + freqs_cis = self.freqs_cis[input_metadata[0].positions] for local_layer_id, layer in enumerate(self.layers.values()): if cache is not None: assert input_metadata is not None - assert isinstance(input_metadata, CacheInputMetadata) - cache_view = cache.get_view(local_layer_id, input_metadata) + assert isinstance(input_metadata[local_layer_id], CacheInputMetadata) + cache_view = cache.get_view(local_layer_id, input_metadata[local_layer_id]) else: cache_view = None h = layer(h, freqs_cis, cache_view) From d37fbb54ec97a24dfa7342e7ad09c84c5ed910b3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Oct 2024 15:34:03 +0200 Subject: [PATCH 2/4] Up --- src/mistral_inference/args.py | 2 +- src/mistral_inference/cache.py | 56 ++++++++++++++++------------ src/mistral_inference/transformer.py | 7 ++-- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/mistral_inference/args.py b/src/mistral_inference/args.py index 4be2470..7ca4f43 100644 --- a/src/mistral_inference/args.py +++ b/src/mistral_inference/args.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, List +from typing import List, Optional from simple_parsing.helpers import Serializable diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py index 3cddb4f..22644c4 100644 --- a/src/mistral_inference/cache.py +++ b/src/mistral_inference/cache.py @@ -17,12 +17,13 @@ def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[in return n_layers * [sliding_window] else: assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}" - assert n_layers % len(sliding_window) == 0, f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}" + assert ( + n_layers % len(sliding_window) == 0 + ), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}" num_repeats = n_layers // len(sliding_window) 
return num_repeats * [w if w is not None else max_seq_len for w in sliding_window] - @dataclass class CacheInputMetadata: # # rope absolute positions @@ -110,12 +111,11 @@ def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tenso cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)] cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)] - interleaved_k = interleave_list(cache_k, xk) - interleaved_v = interleave_list(cache_v, xv) + interleaved_k = interleave_list(cache_k, list(xk)) + interleaved_v = interleave_list(cache_v, list(xv)) return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0) - @property def max_seq_len(self) -> int: return self.cache_k.shape[1] @@ -150,7 +150,7 @@ def __init__( max_seq_len: int, n_kv_heads: int, head_dim: int, - sliding_window: Optional[int] | Optional[List[int]] = None + sliding_window: Optional[int] | Optional[List[int]] = None, ): print(f"yeeeees {sliding_window}") self.max_seq_len = max_seq_len @@ -197,20 +197,24 @@ def update_seqlens(self, seqlens: List[int]) -> None: def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]: """ - input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3 - --> only cache last 3 tokens in each sequence - - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1] - - cached_elements = [3 | 3 | 2] - --> absolute positions are used for rope - - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4] - --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window - - cache_positions = [2 0 1 | 5 3 4 | 6 7] + input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3 + --> only cache last 3 tokens in each sequence + - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1] + - cached_elements = [3 | 3 | 2] + --> absolute positions are used for rope + - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4] + --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window + - cache_positions = [2 0 1 | 5 3 4 | 6 7] """ metadata: List[CacheInputMetadata] = [] if self.kv_seqlens is None: self.init_kvseqlens(len(seqlens)) - assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?" + + assert self.kv_seqlens is not None + assert len(seqlens) == len( + self.kv_seqlens + ), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?" 
seqpos = self.kv_seqlens.tolist() assert len(seqlens) > 0, seqlens @@ -220,30 +224,34 @@ def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]: return metadata def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata: - masks = [ - [x >= seqlen - cache_size for x in range(seqlen)] - for seqlen in seqlens - ] + masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens] to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool) cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long) - positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long) - batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long) + positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to( + device=self.device, dtype=torch.long + ) + batch_idx = torch.tensor( + sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long + ) cache_positions = positions % cache_size + batch_idx * cache_size first_prefill = seqpos[0] == 0 subsequent_prefill = any(seqlen > 1 for seqlen in seqlens) if first_prefill: - assert all([pos == 0 for pos in seqpos]), (seqpos) + assert all([pos == 0 for pos in seqpos]), seqpos mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size) elif subsequent_prefill: + assert self.kv_seqlens is not None mask = BlockDiagonalMask.from_seqlens( q_seqlen=seqlens, - kv_seqlen=[s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)] + kv_seqlen=[ + s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens) + ], ).make_local_attention_from_bottomright(cache_size) else: mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( q_seqlen=seqlens, kv_padding=cache_size, - kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist() + kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(), ) return CacheInputMetadata( positions=positions, diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py index c4aa897..a53195f 100644 --- a/src/mistral_inference/transformer.py +++ b/src/mistral_inference/transformer.py @@ -150,7 +150,7 @@ def forward_partial( (num_toks,) = input_ids.shape assert sum(seqlens) == num_toks, (sum(seqlens), num_toks) - input_metadata: List[Union[CacheInputMetadata, SimpleInputMetadata]] + input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata] if cache is not None: input_metadata = cache.get_input_metadata(seqlens) @@ -173,8 +173,9 @@ def forward_partial( for local_layer_id, layer in enumerate(self.layers.values()): if cache is not None: assert input_metadata is not None - assert isinstance(input_metadata[local_layer_id], CacheInputMetadata) - cache_view = cache.get_view(local_layer_id, input_metadata[local_layer_id]) + cache_metadata = input_metadata[local_layer_id] + assert isinstance(cache_metadata, CacheInputMetadata) + cache_view = cache.get_view(local_layer_id, cache_metadata) else: cache_view = None h = layer(h, freqs_cis, cache_view) From 5d4cc685fbaabebd7f7eb3b32d7a2f9f4ef038a3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Oct 2024 13:38:44 +0000 Subject: [PATCH 3/4] WIP --- 
src/mistral_inference/cache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py index 22644c4..6f8aa7d 100644 --- a/src/mistral_inference/cache.py +++ b/src/mistral_inference/cache.py @@ -152,7 +152,6 @@ def __init__( head_dim: int, sliding_window: Optional[int] | Optional[List[int]] = None, ): - print(f"yeeeees {sliding_window}") self.max_seq_len = max_seq_len self.n_kv_heads = n_kv_heads self.head_dim = head_dim From b9524508e98db42d155b98fdf6494ebc70e3934e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Oct 2024 13:58:03 +0000 Subject: [PATCH 4/4] WIP --- src/mistral_inference/transformer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py index a53195f..cb782dd 100644 --- a/src/mistral_inference/transformer.py +++ b/src/mistral_inference/transformer.py @@ -36,6 +36,7 @@ def __init__( args: TransformerArgs, pipeline_rank: int = 0, num_pipeline_ranks: int = 1, + softmax_fp32: bool = True, ): super().__init__() self.args = args @@ -46,6 +47,8 @@ def __init__( assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks) self.pipeline_rank = pipeline_rank self.num_pipeline_ranks = num_pipeline_ranks + self.softmax_fp32 = softmax_fp32 + # Modules specific to some ranks: self.tok_embeddings: Optional[nn.Embedding] = None self.norm: Optional[RMSNorm] = None @@ -207,7 +210,11 @@ def forward( outs = self.output(h) if self.num_pipeline_ranks > 1: torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1) - return outs.float() + + if self.softmax_fp32: + return outs.float() + else: + return outs def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None: state_to_load = {} @@ -259,6 +266,7 @@ def from_folder( num_pipeline_ranks: int = 1, device: Union[torch.device, str] = "cuda", dtype: Optional[torch.dtype] = None, + softmax_fp32: bool = True, ) -> "Transformer": with open(Path(folder) / "params.json", "r") as f: model_args = TransformerArgs.from_dict(json.load(f)) @@ -272,6 +280,7 @@ def from_folder( model_args, pipeline_rank=pipeline_rank, num_pipeline_ranks=num_pipeline_ranks, + softmax_fp32=softmax_fp32, ) pt_model_file = Path(folder) / "consolidated.00.pth"
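
Note on the variable sliding-window cache introduced by this series (illustration only, not part of the patches): the sketch below mirrors, in plain Python lists instead of torch tensors, what `get_cache_sizes` and `BufferCache._get_input_metadata_layer` compute. The `[4096, None]` per-layer window pattern and `max_seq_len=32768` are made-up example values; the second example reproduces the rolling-buffer walkthrough from the `get_input_metadata` docstring (seqlens `[5, 7, 2]`, seqpos `[0, 1, 3]`, window 3).

    from typing import List, Optional, Union


    def get_cache_sizes(
        n_layers: int, max_seq_len: int, sliding_window: Union[None, int, List[Optional[int]]]
    ) -> List[int]:
        # Mirrors the patched helper: one cache size per layer.
        if sliding_window is None:
            return n_layers * [max_seq_len]
        if isinstance(sliding_window, int):
            return n_layers * [sliding_window]
        assert n_layers % len(sliding_window) == 0
        num_repeats = n_layers // len(sliding_window)
        # The per-layer pattern is tiled; None means "no window" (full max_seq_len).
        return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]


    def rolling_cache_metadata(seqlens: List[int], seqpos: List[int], cache_size: int):
        # Only the last `cache_size` tokens of each sequence are written to the rolling buffer.
        masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
        to_cache_mask = [m for mask in masks for m in mask]
        cached_elements = [sum(mask) for mask in masks]
        # Absolute positions (used for RoPE), flattened over the batch.
        positions = [pos + i for pos, seqlen in zip(seqpos, seqlens) for i in range(seqlen)]
        batch_idx = [b for b, seqlen in enumerate(seqlens) for _ in range(seqlen)]
        # Cache slot = position modulo the window, offset per batch element; keep only the
        # masked entries, matching `cache_positions[to_cache_mask]` in the patch.
        cache_positions = [
            p % cache_size + b * cache_size
            for p, b, keep in zip(positions, batch_idx, to_cache_mask)
            if keep
        ]
        return to_cache_mask, cached_elements, positions, cache_positions


    if __name__ == "__main__":
        # Hypothetical config: 8 layers alternating a 4096-token window with full attention.
        print(get_cache_sizes(8, 32768, [4096, None]))
        # -> [4096, 32768, 4096, 32768, 4096, 32768, 4096, 32768]

        # Docstring example: seqlens [5, 7, 2], seqpos [0, 1, 3], window 3.
        _, cached, pos, cache_pos = rolling_cache_metadata([5, 7, 2], [0, 1, 3], 3)
        print(cached)     # [3, 3, 2]
        print(pos)        # [0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 3, 4]
        print(cache_pos)  # [2, 0, 1, 5, 3, 4, 6, 7]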
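
A second standalone sketch (again not from the patches, using plain lists in place of the `(W, H, D)` tensors) of how `unrotate`, called from `CacheView.interleave_kv`, restores chronological order from the rolling buffer before the cached keys/values are interleaved with the new tokens:

    def unrotate(buffer: list, seqlen: int) -> list:
        W = len(buffer)                # window size / number of slots
        position = seqlen % W
        if seqlen < W:                 # buffer not yet full: only the first `seqlen` slots are valid
            return buffer[:seqlen]
        if position == 0:              # wrapped a whole number of times: already in order
            return buffer
        # Otherwise the oldest valid entry lives at `position`; rotate it to the front.
        return buffer[position:] + buffer[:position]


    # Window of 3 slots after writing absolute positions 0..4: position p lands in
    # slot p % 3, so the slots hold [3, 4, 2]; unrotating recovers the last three
    # positions in chronological order.
    print(unrotate([3, 4, 2], seqlen=5))  # -> [2, 3, 4]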