Merge pull request #230 from mistralai/add_layer_wise_rotated_cache
Add per-layer sliding window
patrickvonplaten authored Oct 16, 2024
2 parents db6b422 + b952450 commit 6428ccf
Showing 4 changed files with 116 additions and 36 deletions.
8 changes: 7 additions & 1 deletion src/mistral_inference/args.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

from simple_parsing.helpers import Serializable

@@ -39,12 +39,18 @@ class TransformerArgs(Serializable):
moe: Optional[MoeArgs] = None
# If this is set, we will load LoRA linear layers instead of linear layers.
lora: Optional[LoraArgs] = None
sliding_window: Optional[int] | Optional[List[int]] = None
_sliding_window: Optional[int] | Optional[List[int]] = None
model_type: str = "transformer"

vision_encoder: Optional[VisionEncoderArgs] = None

def __post_init__(self) -> None:
assert self.model_type == "transformer", self.model_type
assert self.sliding_window is None or self._sliding_window is None

# hack for now so that vLLM is supported correctly
self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window


@dataclass
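For context, a rough sketch of how the widened `sliding_window` field is meant to be used. Everything below except `sliding_window`/`_sliding_window` is assumed from the released `mistral_inference` package rather than taken from this diff, so treat the exact constructor fields as assumptions.

```python
from mistral_inference.args import TransformerArgs

# Sketch only: the non-window fields are assumed, not shown in this diff.
args = TransformerArgs(
    dim=4096,
    n_layers=32,
    head_dim=128,
    hidden_dim=14336,
    n_heads=32,
    n_kv_heads=8,
    norm_eps=1e-5,
    vocab_size=32768,
    # New: either a single window shared by all layers (sliding_window=4096) or a
    # per-layer pattern tiled over the 32 layers; None means full attention.
    sliding_window=[4096, None, None, None],
)

# __post_init__ enforces that at most one of sliding_window / _sliding_window is set,
# then mirrors _sliding_window into sliding_window (the vLLM-compatibility hack above).
print(args.sliding_window)  # [4096, None, None, None]
```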
120 changes: 91 additions & 29 deletions src/mistral_inference/cache.py
@@ -10,13 +10,40 @@
)


def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[int] | Optional[List[int]]) -> List[int]:
if sliding_window is None:
return n_layers * [max_seq_len]
elif isinstance(sliding_window, int):
return n_layers * [sliding_window]
else:
assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}"
assert (
n_layers % len(sliding_window) == 0
), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
num_repeats = n_layers // len(sliding_window)
return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]
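A quick illustration of the tiling rule above; a minimal sketch assuming the helper is importable as `mistral_inference.cache.get_cache_sizes`, as in this diff.

```python
from mistral_inference.cache import get_cache_sizes

# A single int applies the same window to every layer.
print(get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=4096))
# [4096, 4096, 4096, 4096]

# A list is tiled across the layers; None entries fall back to max_seq_len,
# i.e. those layers keep a full (non-sliding) cache.
print(get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=[4096, None]))
# [4096, 8192, 4096, 8192]

# No sliding window at all: every layer caches up to max_seq_len tokens.
print(get_cache_sizes(n_layers=4, max_seq_len=8192, sliding_window=None))
# [8192, 8192, 8192, 8192]
```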


@dataclass
class CacheInputMetadata:
# # rope absolute positions
# positions: torch.Tensor
# # where tokens should go in the cache
# cache_positions: torch.Tensor

# # if prefill, use block diagonal causal mask
# # else use causal with padded key mask
# prefill: bool
# mask: AttentionBias
# seqlens: List[int]
# rope absolute positions
positions: torch.Tensor
# which elements in the sequences need to be cached
to_cache_mask: torch.Tensor
# how many elements are cached per sequence
cached_elements: torch.Tensor
# where tokens should go in the cache
cache_positions: torch.Tensor

# if prefill, use block diagonal causal mask
# else use causal with padded key mask
prefill: bool
@@ -29,6 +56,17 @@ def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torch.Tensor]:
return [v for pair in zip(l1, l2) for v in pair]
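A tiny sketch of what `interleave_list` does during a subsequent prefill: per-sequence cached tensors and newly computed tensors end up adjacent, sequence by sequence.

```python
import torch

from mistral_inference.cache import interleave_list

cached = [torch.zeros(3, 1, 1), torch.zeros(5, 1, 1)]  # cached keys for two sequences
fresh = [torch.ones(2, 1, 1), torch.ones(4, 1, 1)]     # newly computed keys
# -> [cached_0, fresh_0, cached_1, fresh_1]
print([t.shape[0] for t in interleave_list(cached, fresh)])  # [3, 2, 5, 4]
```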


def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
assert cache.ndim == 3 # (W, H, D)
position = seqlen % cache.shape[0]
if seqlen < cache.shape[0]:
return cache[:seqlen]
elif position == 0:
return cache
else:
return torch.cat([cache[position:], cache[:position]], dim=0)
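The rotating cache stores the last `W` tokens of a sequence in a ring buffer, so reading it back in chronological order requires the `unrotate` helper above. A small self-contained sketch:

```python
import torch

from mistral_inference.cache import unrotate

W = 4
cache = torch.arange(W).reshape(W, 1, 1)  # slot i of the ring buffer holds value i

# Fewer tokens seen than slots: nothing has wrapped around yet.
print(unrotate(cache, seqlen=2).flatten().tolist())  # [0, 1]

# 6 tokens seen with 4 slots: slots 0 and 1 were overwritten most recently,
# so chronological order starts at slot 2.
print(unrotate(cache, seqlen=6).flatten().tolist())  # [2, 3, 0, 1]

# An exact multiple of the window: the buffer is already in order.
print(unrotate(cache, seqlen=8).flatten().tolist())  # [0, 1, 2, 3]
```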


class CacheView:
def __init__(
self,
@@ -50,8 +88,8 @@ def update(self, xk: torch.Tensor, xv: torch.Tensor) -> None:
flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)
flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)

flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk)
flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv)
flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])
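A compact, standalone sketch of what this masked, rotated write amounts to for one layer: window 3, two sequences of length 5 and 2 on their first prefill. The tensor shapes are illustrative only.

```python
import torch

cache_size, seqlens, seqpos = 3, [5, 2], [0, 0]

# Only the last cache_size tokens of each sequence are kept ...
masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
to_cache_mask = torch.tensor(sum(masks, []), dtype=torch.bool)
# ... and they land at position % cache_size inside each sequence's block of slots.
positions = torch.cat([torch.arange(p, p + s) for p, s in zip(seqpos, seqlens)])
batch_idx = torch.tensor(sum([[i] * s for i, s in enumerate(seqlens)], []))
cache_positions = (positions % cache_size + batch_idx * cache_size)[to_cache_mask]

print(to_cache_mask.int().tolist())  # [0, 0, 1, 1, 1, 1, 1]
print(cache_positions.tolist())      # [2, 0, 1, 3, 4]

# The write itself mirrors CacheView.update: scatter into the flattened layer cache.
xk = torch.randn(sum(seqlens), 8, 128)              # (num_tokens, n_kv_heads, head_dim)
flat_cache_k = torch.zeros(2 * cache_size, 8, 128)  # (batch * cache_size, n_kv_heads, head_dim)
flat_cache_k.index_copy_(0, cache_positions, xk[to_cache_mask])
```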

def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -69,9 +107,9 @@ def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore
assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"

# Retrieve cache
cache_k = [cache_k[:seq_len] for cache_k, seq_len in zip(self.cache_k, self.kv_seqlens)]
cache_v = [cache_v[:seq_len] for cache_v, seq_len in zip(self.cache_v, self.kv_seqlens)]
# Order elements in cache by position by unrotating
cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]

interleaved_k = interleave_list(cache_k, list(xk))
interleaved_v = interleave_list(cache_v, list(xv))
@@ -112,13 +150,22 @@ def __init__(
max_seq_len: int,
n_kv_heads: int,
head_dim: int,
sliding_window: Optional[int] | Optional[List[int]] = None,
):
self.max_seq_len = max_seq_len
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.n_layers = n_layers

self.cache_sizes: List[int] = get_cache_sizes(n_layers, max_seq_len, sliding_window)
assert len(self.cache_sizes) == n_layers, f"Expected {n_layers} cache sizes, got {len(self.cache_sizes)}"

self.cache_k = {}
self.cache_v = {}
for i, cache_size in enumerate(self.cache_sizes):
self.cache_k[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
self.cache_v[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))

self.cache_k = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
self.cache_v = torch.empty((n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim))
# holds the valid length for each batch element in the cache
self.kv_seqlens: Optional[torch.Tensor] = None
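A small sketch of the resulting per-layer allocation. The constructor keywords before `max_seq_len` are inferred from the body above, since the hunk truncates the signature; tiny sizes are used so it runs cheaply.

```python
import torch

from mistral_inference.cache import BufferCache

cache = BufferCache(
    n_layers=4,
    max_batch_size=2,
    max_seq_len=16,
    n_kv_heads=1,
    head_dim=4,
    sliding_window=[8, None],  # tiled: layers 0/2 rotate over 8 slots, layers 1/3 keep all 16
)
# Each layer now owns its own (max_batch_size, cache_size, n_kv_heads, head_dim) buffer.
print([cache.cache_k[i].shape[1] for i in range(4)])  # [8, 16, 8, 16]

# .to() now walks every per-layer buffer instead of moving one big tensor.
cache.to(device=torch.device("cpu"), dtype=torch.bfloat16)
```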

@@ -134,67 +181,82 @@ def init_kvseqlens(self, batch_size: int) -> None:

@property
def device(self) -> torch.device:
return self.cache_k.device
return self.cache_k[0].device

def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache":
self.cache_k = self.cache_k.to(device=device, dtype=dtype)
self.cache_v = self.cache_v.to(device=device, dtype=dtype)
for i in range(self.n_layers):
self.cache_k[i] = self.cache_k[i].to(device=device, dtype=dtype)
self.cache_v[i] = self.cache_v[i].to(device=device, dtype=dtype)

return self

def update_seqlens(self, seqlens: List[int]) -> None:
assert self.kv_seqlens is not None
self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)

def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata:
def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]:
"""
Get metadata about cache positions
input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
--> only cache last 3 tokens in each sequence
- to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
- cached_elements = [3 | 3 | 2]
--> absolute positions are used for rope
- positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
--> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
- cache_positions = [2 0 1 | 5 3 4 | 6 7]
"""
metadata: List[CacheInputMetadata] = []

if self.kv_seqlens is None:
self.init_kvseqlens(len(seqlens))

assert isinstance(self.kv_seqlens, torch.Tensor)
assert self.kv_seqlens is not None
assert len(seqlens) == len(
self.kv_seqlens
), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
seqpos = self.kv_seqlens.tolist()

assert len(seqlens) > 0, seqlens
cached_elements = torch.tensor(seqlens, device=self.device, dtype=torch.long)

for cache_size in self.cache_sizes:
metadata.append(self._get_input_metadata_layer(cache_size, seqlens, seqpos))

return metadata

def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata:
masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
device=self.device, dtype=torch.long
)

batch_idx = torch.tensor(
sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []),
device=self.device,
dtype=torch.long,
sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long
)
cache_positions = positions + batch_idx * self.max_seq_len

cache_positions = positions % cache_size + batch_idx * cache_size
first_prefill = seqpos[0] == 0
subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
if first_prefill:
assert all([pos == 0 for pos in seqpos]), seqpos
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len)
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size)
elif subsequent_prefill:
assert self.kv_seqlens is not None
mask = BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens,
kv_seqlen=[
s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
],
).make_local_attention_from_bottomright(self.max_seq_len)
).make_local_attention_from_bottomright(cache_size)
else:
mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=seqlens,
kv_padding=self.max_seq_len,
kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(),
kv_padding=cache_size,
kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(),
)

return CacheInputMetadata(
positions=positions,
cache_positions=cache_positions,
to_cache_mask=to_cache_mask,
cached_elements=cached_elements,
cache_positions=cache_positions[to_cache_mask],
prefill=first_prefill or subsequent_prefill,
mask=mask,
seqlens=seqlens,
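A sketch that exercises the per-layer metadata end to end, in the spirit of the docstring example above. It assumes this commit's `mistral_inference` and its xformers dependency are installed; the cache starts empty here, so all sequence offsets are 0 (unlike the docstring's seqpos [0, 1, 3]).

```python
from mistral_inference.cache import BufferCache

cache = BufferCache(
    n_layers=2,
    max_batch_size=3,
    max_seq_len=64,
    n_kv_heads=1,
    head_dim=4,
    sliding_window=[3, None],  # layer 0: tiny window of 3, layer 1: full cache
)
metadata = cache.get_input_metadata(seqlens=[5, 7, 2])  # one CacheInputMetadata per layer

print(metadata[0].cached_elements.tolist())  # [3, 3, 2] -- window-3 layer keeps the last 3 tokens
print(metadata[0].cache_positions.tolist())  # [2, 0, 1, 4, 5, 3, 6, 7]
print(metadata[1].cached_elements.tolist())  # [5, 7, 2] -- full-cache layer keeps everything
```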
1 change: 1 addition & 0 deletions src/mistral_inference/generate.py
@@ -72,6 +72,7 @@ def generate(
cache_window,
model.args.n_kv_heads,
model.args.head_dim,
model.args.sliding_window,
)
cache.to(device=model.device, dtype=model.dtype)
cache.reset()
23 changes: 17 additions & 6 deletions src/mistral_inference/transformer.py
@@ -36,6 +36,7 @@ def __init__(
args: TransformerArgs,
pipeline_rank: int = 0,
num_pipeline_ranks: int = 1,
softmax_fp32: bool = True,
):
super().__init__()
self.args = args
@@ -46,6 +47,8 @@ def __init__(
assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
self.pipeline_rank = pipeline_rank
self.num_pipeline_ranks = num_pipeline_ranks
self.softmax_fp32 = softmax_fp32

# Modules specific to some ranks:
self.tok_embeddings: Optional[nn.Embedding] = None
self.norm: Optional[RMSNorm] = None
@@ -150,12 +153,12 @@ def forward_partial(
(num_toks,) = input_ids.shape
assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)

input_metadata: Union[CacheInputMetadata, SimpleInputMetadata]
input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata]

if cache is not None:
input_metadata = cache.get_input_metadata(seqlens)
else:
input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
input_metadata = [SimpleInputMetadata.from_seqlens(seqlens, self.device) for _ in range(len(self.layers))]

if self.pipeline_rank == 0:
assert self.tok_embeddings is not None
@@ -167,13 +170,15 @@ def forward_partial(
h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype)
torch.distributed.recv(h, src=self.pipeline_rank - 1)

freqs_cis = self.freqs_cis[input_metadata.positions]
# freqs_cis is always the same for every layer
freqs_cis = self.freqs_cis[input_metadata[0].positions]

for local_layer_id, layer in enumerate(self.layers.values()):
if cache is not None:
assert input_metadata is not None
assert isinstance(input_metadata, CacheInputMetadata)
cache_view = cache.get_view(local_layer_id, input_metadata)
cache_metadata = input_metadata[local_layer_id]
assert isinstance(cache_metadata, CacheInputMetadata)
cache_view = cache.get_view(local_layer_id, cache_metadata)
else:
cache_view = None
h = layer(h, freqs_cis, cache_view)
@@ -205,7 +210,11 @@ def forward(
outs = self.output(h)
if self.num_pipeline_ranks > 1:
torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
return outs.float()

if self.softmax_fp32:
return outs.float()
else:
return outs

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
state_to_load = {}
@@ -257,6 +266,7 @@ def from_folder(
num_pipeline_ranks: int = 1,
device: Union[torch.device, str] = "cuda",
dtype: Optional[torch.dtype] = None,
softmax_fp32: bool = True,
) -> "Transformer":
with open(Path(folder) / "params.json", "r") as f:
model_args = TransformerArgs.from_dict(json.load(f))
@@ -270,6 +280,7 @@ def from_folder(
model_args,
pipeline_rank=pipeline_rank,
num_pipeline_ranks=num_pipeline_ranks,
softmax_fp32=softmax_fp32,
)

pt_model_file = Path(folder) / "consolidated.00.pth"
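A short sketch of the new flag from the loading side; "/path/to/model" is a placeholder for a downloaded checkpoint folder. With `softmax_fp32=False`, `forward` returns logits in the model dtype instead of upcasting them to float32.

```python
import torch

from mistral_inference.transformer import Transformer

model = Transformer.from_folder(
    "/path/to/model",      # placeholder checkpoint folder
    device="cuda",
    dtype=torch.bfloat16,
    softmax_fp32=False,    # keep the output logits in bfloat16
)
```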
