Completely refactor and test YaRN finetuning (#78)
* add yarn finetuning

* fixing type hints

* yarn refactor

* fix bugs

* bug and format

* bug fix

* in-place RoPE; major fix for YaRN

* format

* clean up; remove redundant  because we stick to NeoX style.

* dropping too long error because we do YaRN

* update config

* fix yarn

* fix yarn; change some defaults

* revamp kv cache

* fixing sample function; removing max length constraint (dynamic yarn can go a little further)

* format

* minor bug

* fixing batch size

* fix cache device

* fix attention mask

* bug fix

* bug fix

* revamp sampling code; refactor kv cache

* format

* fix mask

* fix bugs

* format

* fix typo on mscale default and dynamic scaling

* format
honglu2875 authored Dec 13, 2023
1 parent 28960f2 commit 8205d85
Showing 12 changed files with 725 additions and 456 deletions.
67 changes: 67 additions & 0 deletions aria/model/cache.py
@@ -0,0 +1,67 @@
from typing import Optional

import torch


class KVCache(torch.nn.Module):
    def __init__(
        self, max_batch_size, n_head, d_head, dtype=torch.float16, max_size=8192
    ):
        super().__init__()
        self.shape = (max_batch_size, max_size, n_head, d_head)
        self.register_buffer(
            "k_cache", torch.empty(self.shape, dtype=dtype), persistent=False
        )
        self.register_buffer(
            "v_cache", torch.empty(self.shape, dtype=dtype), persistent=False
        )
        self.next_pos = 0

    def update(
        self,
        k,
        v,
        pos: Optional[torch.Tensor] = None,
        start_pos: int = 0,
        max_pos: Optional[int] = None,
    ):
        """
        Update the KV cache and return the new k, v sequences of vectors.

        Args:
            k: keys to insert. Shape: (batch_size, num_positions, n_head, d_head)
            v: values to insert. Shape: (batch_size, num_positions, n_head, d_head)
            pos: positions to update. Shape: (num_positions,).
                Examples: None appends to the end of the cache;
                [0, 1, 2, 3, 4] updates the first 5 positions;
                [5] updates only the 6th position.
            start_pos: the starting position of the cache. Defaults to 0.
            max_pos: the maximum position to update. Defaults to None and is
                only used when pos is *not* None. It could be inferred from
                pos.max(), but that operation forces a host-device sync with
                massive overhead due to the dynamic shape.
        """
        if pos is None:
            self.k_cache[
                : k.size(0), self.next_pos : self.next_pos + k.size(1)
            ] = k
            self.v_cache[
                : v.size(0), self.next_pos : self.next_pos + v.size(1)
            ] = v
            self.next_pos += k.size(1)
        else:
            assert pos.size(0) == k.size(1)
            assert max_pos is not None, (
                "Need to pass in `pos.max()` explicitly. "
                "Doing `pos.max()` creates massive overhead."
            )
            self.k_cache[: k.size(0), pos] = k
            self.v_cache[: v.size(0), pos] = v
            # Update next_pos using the max entry.
            # Note: `self.next_pos = pos.max() + 1` would also work, but it
            # makes the shape dynamic and creates a massive overhead.
            self.next_pos = max_pos + 1
        return (
            self.k_cache[: k.size(0), start_pos : self.next_pos],
            self.v_cache[: v.size(0), start_pos : self.next_pos],
        )
165 changes: 0 additions & 165 deletions aria/model/dynamic_yarn.py

This file was deleted.
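The deleted module held the previous dynamic-YaRN implementation. For context only, here is a minimal sketch of YaRN-style "NTK-by-parts" RoPE frequency scaling as described in the YaRN paper (Peng et al., 2023); the function names, defaults, and ramp bounds are illustrative, not taken from this repository:

import math

import torch


def yarn_inv_freq(
    d_head: int,
    scale: float,
    base: float = 10000.0,
    orig_ctx: int = 2048,
    alpha: float = 1.0,
    beta: float = 32.0,
) -> torch.Tensor:
    """Blend interpolated and original RoPE frequencies per dimension."""
    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    # r = number of full rotations each dimension makes over the original
    # context; low-r (long-wavelength) dims get interpolated by 1/scale,
    # high-r dims are left untouched, with a linear ramp in between.
    r = orig_ctx * inv_freq / (2 * math.pi)
    ramp = torch.clamp((r - alpha) / (beta - alpha), 0.0, 1.0)
    return inv_freq * ramp + (inv_freq / scale) * (1.0 - ramp)


def yarn_mscale(scale: float) -> float:
    """Attention temperature from the YaRN paper: 0.1 * ln(s) + 1."""
    return 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0

The "dynamic" variant recomputes `scale` at inference time, e.g. as `max(1.0, seq_len / orig_ctx)`, which is consistent with the commit message above: once scaling adapts to the current sequence length, the hard max-length check can be dropped because dynamic YaRN "can go a little further".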

