
Commit

Add files via upload
yukang2017 authored Oct 18, 2023
1 parent 8edbeec commit a947cfe
Showing 9 changed files with 966 additions and 0 deletions.
115 changes: 115 additions & 0 deletions run_streaming_llama_longalpaca.py
@@ -0,0 +1,115 @@
import warnings

warnings.filterwarnings("ignore")

import torch
import argparse
import json
import os
import time
import re
import sys

from tqdm import tqdm
from streaming_llm.utils import load, download_url, load_jsonl
from streaming_llm.enable_streaming_llm import enable_streaming_llm


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    # Prefill: run the full prompt through the model once to populate the KV cache.
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    pos = 0
    # Greedy decoding: feed the argmax token back in, one step at a time.
    for _ in range(max_gen_len - 1):
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = (
            tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
                spaces_between_special_tokens=False,
            )
            .strip()
            .split(" ")
        )

        # Stream completed words to stdout; hold back the last, possibly partial, word.
        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now

        if pred_token_idx == tokenizer.eos_token_id:
            break
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values


@torch.no_grad()
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
    past_key_values = None
    for idx, prompt in enumerate(prompts):
        prompt = "USER: " + prompt + "\n\nASSISTANT: "
        print("\n" + prompt, end="")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)
        seq_len = input_ids.shape[1]
        if kv_cache is not None:
            # Make room in the rolling KV cache for the new prompt plus the generation budget.
            space_needed = seq_len + max_gen_len
            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)

        past_key_values = greedy_generate(
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
        )


def main(args):
    model_name_or_path = args.model_name_or_path
    model, tokenizer = load(model_name_or_path)
    print(f"Loading data from {args.test_filepath} ...")

    list_data = json.load(open(args.test_filepath))
    prompts = []
    for sample in list_data:
        prompts += [sample["instruction"]]

    if args.enable_streaming:
        kv_cache = enable_streaming_llm(
            model,
            start_size=args.start_size,
            recent_size=args.recent_size,
            use_flash_attn=args.use_flash_attn,
        )
    else:
        kv_cache = None

    streaming_inference(
        model,
        tokenizer,
        prompts,
        kv_cache,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path", type=str, default="Yukang/LongAlpaca-7B"
    )
    parser.add_argument("--test_filepath", type=str, default="outputs_stream.json")
    parser.add_argument("--enable_streaming", action="store_true")
    parser.add_argument("--start_size", type=int, default=4)
    parser.add_argument("--recent_size", type=int, default=8192)
    # Caveat: argparse's type=bool treats any non-empty string (even "False") as True,
    # so this flag is effectively always on unless passed an empty string.
    parser.add_argument("--use_flash_attn", type=bool, default=True)
    args = parser.parse_args()

    main(args)
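
The script above expects args.test_filepath to be a JSON list of samples, each carrying an "instruction" string (that is all the loading loop in main reads). A minimal sketch of preparing such a file, with a hypothetical prompt, is:

import json

# Hypothetical example prompt; any file following the same {"instruction": ...} schema works.
samples = [{"instruction": "Summarize the key contributions of the paper below.\n\n<paper text>"}]
with open("outputs_stream.json", "w") as f:
    json.dump(samples, f)

A typical run is then `python run_streaming_llama_longalpaca.py --enable_streaming --start_size 4 --recent_size 8192` (all flags shown are defined in the parser above); without --enable_streaming the KV cache simply grows without bound across prompts.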
Empty file added streaming_llm/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions streaming_llm/enable_streaming_llm.py
@@ -0,0 +1,38 @@
from streaming_llm.kv_cache import StartRecentKVCache


def enable_streaming_llm(model, start_size, recent_size, use_flash_attn=True):
    # Patch the model's attention so cached positions are re-indexed after eviction
    # (position shifting), and record the sequence dimension of its KV cache layout.
    if "llama" in model.config.model_type:
        k_seq_dim = v_seq_dim = 2
        from streaming_llm.pos_shift.modify_llama import (
            enable_llama_pos_shift_attention,
        )

        enable_llama_pos_shift_attention(model, use_flash_attn)
    elif "mpt" in model.config.model_type:
        v_seq_dim = 2
        k_seq_dim = 3
    elif "gpt_neox" in model.config.model_type:
        k_seq_dim = v_seq_dim = 2
        from streaming_llm.pos_shift.modify_gpt_neox import (
            enable_gpt_neox_pos_shift_attention,
        )

        enable_gpt_neox_pos_shift_attention(model)
    elif "falcon" in model.config.model_type:
        v_seq_dim = 1
        k_seq_dim = 1
        from streaming_llm.pos_shift.modify_falcon import (
            enable_falcon_pos_shift_attention,
        )

        enable_falcon_pos_shift_attention(model)
    else:
        raise ValueError(f"Unsupported model type: {model.config.model_type}")
    kv_cache = StartRecentKVCache(
        start_size=start_size,
        recent_size=recent_size,
        k_seq_dim=k_seq_dim,
        v_seq_dim=v_seq_dim,
    )
    return kv_cache
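
A minimal sketch of how this helper is meant to be wired up, mirroring the run script earlier in this commit (loading via plain transformers is an assumption here, since the commit's own load utility from streaming_llm.utils is not shown):

# Sketch only: assumes a LLaMA-family checkpoint and the transformers/accelerate libraries.
from transformers import AutoModelForCausalLM, AutoTokenizer
from streaming_llm.enable_streaming_llm import enable_streaming_llm

model = AutoModelForCausalLM.from_pretrained("Yukang/LongAlpaca-7B", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("Yukang/LongAlpaca-7B")

# Patches LLaMA attention for position shifting and returns a StartRecentKVCache
# with 4 "sink" tokens plus an 8192-token recent window, as in the run script.
kv_cache = enable_streaming_llm(model, start_size=4, recent_size=8192, use_flash_attn=True)

The returned kv_cache is applied between generations (evict_for_space in the run script) rather than inside the model; the attention patch is what keeps rotary position indices consistent after tokens are evicted.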
119 changes: 119 additions & 0 deletions streaming_llm/kv_cache.py
@@ -0,0 +1,119 @@
import torch


def slice2d(x, start, end):
    return x[:, :, start:end, ...]


def slice3d(x, start, end):
    return x[:, :, :, start:end, ...]


def slice1d(x, start, end):
    return x[:, start:end, ...]


# Map the index of a tensor's sequence dimension to the matching slicing helper.
DIM_TO_SLICE = {
    1: slice1d,
    2: slice2d,
    3: slice3d,
}


class StartRecentKVCache:
    """Rolling KV cache that keeps the first `start_size` tokens (attention sinks)
    plus the most recent `recent_size` tokens along the sequence dimension."""

    def __init__(
        self,
        start_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        print(f"StartRecentKVCache: {start_size}, {recent_size}")
        self.start_size = start_size
        self.recent_size = recent_size
        self.cache_size = start_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]

    def __call__(self, past_key_values):
        # Trim an over-full cache down to start_size + recent_size tokens.
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        return [
            [
                torch.cat(
                    [
                        self.k_slice(k, 0, self.start_size),
                        self.k_slice(k, seq_len - self.recent_size, seq_len),
                    ],
                    dim=self.k_seq_dim,
                ),
                torch.cat(
                    [
                        self.v_slice(v, 0, self.start_size),
                        self.v_slice(v, seq_len - self.recent_size, seq_len),
                    ],
                    dim=self.v_seq_dim,
                ),
            ]
            for k, v in past_key_values
        ]

    def evict_for_space(self, past_key_values, num_coming):
        # Evict enough recent tokens that num_coming new tokens still fit in the budget.
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values
        return [
            [
                torch.cat(
                    [
                        self.k_slice(k, 0, self.start_size),
                        self.k_slice(
                            k, seq_len - self.recent_size + num_coming, seq_len
                        ),
                    ],
                    dim=self.k_seq_dim,
                ),
                torch.cat(
                    [
                        self.v_slice(v, 0, self.start_size),
                        self.v_slice(
                            v, seq_len - self.recent_size + num_coming, seq_len
                        ),
                    ],
                    dim=self.v_seq_dim,
                ),
            ]
            for k, v in past_key_values
        ]

    def evict_range(self, past_key_values, start, end):
        # Drop the cached tokens in [start, end), keeping everything outside that range.
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        assert start <= end and end <= seq_len
        return [
            [
                torch.cat(
                    [
                        self.k_slice(k, 0, start),
                        self.k_slice(k, end, seq_len),
                    ],
                    dim=self.k_seq_dim,
                ),
                torch.cat(
                    [
                        self.v_slice(v, 0, start),
                        self.v_slice(v, end, seq_len),
                    ],
                    dim=self.v_seq_dim,
                ),
            ]
            for k, v in past_key_values
        ]
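
To make the eviction arithmetic concrete, here is a small self-contained sketch with dummy tensors shaped like a LLaMA-style cache ([batch, heads, seq_len, head_dim], so k_seq_dim = v_seq_dim = 2); the sizes are illustrative only:

import torch
from streaming_llm.kv_cache import StartRecentKVCache

cache = StartRecentKVCache(start_size=4, recent_size=8, k_seq_dim=2, v_seq_dim=2)

# One fake layer: batch=1, 2 heads, 16 cached tokens, head_dim=4.
k = torch.randn(1, 2, 16, 4)
v = torch.randn(1, 2, 16, 4)
past = [(k, v)]

# 16 cached + 4 incoming tokens exceed the 4 + 8 = 12 budget, so the middle is evicted:
# the 4 sink tokens (positions 0-3) and the 4 most recent (positions 12-15) survive,
# leaving exactly 4 slots free for the incoming tokens.
evicted = cache.evict_for_space(past, num_coming=4)
print(evicted[0][0].shape)  # torch.Size([1, 2, 8, 4])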
