Skip to content

Commit

Permalink
profile throughput without new threads (#2826)
Browse files Browse the repository at this point in the history
* profile throughput without threads

* optimize main loop

* fix torch.event

* fix python>3.11

* optimize tp

* reduce cudagraph copy

* optimize fill kv cache

* optimize silu and mul

* optimize apply rotary

* remove executor

* remove kernel

* remove num_heads==1
  • Loading branch information
grimoire authored Dec 3, 2024
1 parent 6734c71 commit 8fbfed6
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 344 deletions.
38 changes: 19 additions & 19 deletions benchmark/profile_throughput.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import asyncio
import csv
import json
import os
import random
import time
from queue import Queue
from threading import Thread
from typing import List, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -86,23 +86,23 @@ def __init__(self, model_path: str,
self.csv = csv
self.pbar = None

def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
temperature: float, top_p: float, top_k: int,
stream_output: bool):
async def _inference(self, req_queue: Queue, res_queue: Queue,
session_id: int, temperature: float, top_p: float,
top_k: int, stream_output: bool):
model_inst = self.tm_model.create_instance()
stats = []
# get each generated token's latency
per_token_latency_stats = []
for prompt, input_seqlen, output_seqlen in iter(
req_queue.get, [None, None, None]):
req_queue.get_nowait, [None, None, None]):
_per_token_latency_stats = [0] * (output_seqlen + 1)
prev = time.perf_counter()
n_prev_token = 0

input_ids = self.tokenizer(prompt).input_ids
state = DetokenizeState(len(input_ids))

for outputs in model_inst.stream_infer(
async for outputs in model_inst.async_stream_infer(
session_id,
input_ids=input_ids,
gen_config=GenerationConfig(max_new_tokens=output_seqlen,
Expand All @@ -123,7 +123,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
prev = now
# for pytorch engine to restart a session
if isinstance(model_inst, EngineInstance):
model_inst.end(session_id)
await model_inst.async_end(session_id)
assert output_seqlen <= n_token <= output_seqlen + 1, \
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens.\n' \
Expand All @@ -139,13 +139,12 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
# skip the first token latency
per_token_latency_stats.append(_per_token_latency_stats[1:])
self.pbar.update(1)
res_queue.put((session_id, stats, per_token_latency_stats))
res_queue.put_nowait((session_id, stats, per_token_latency_stats))

def process_request(self, requests, concurrency, temperature, top_p, top_k,
stream_output):
res_queue = Queue()
req_queue = Queue()
threads = []

self.pbar = tqdm(total=len(requests))

Expand All @@ -157,18 +156,20 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,

start = time.time()

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

# start threads
tasks = []
for i in range(concurrency):
t = Thread(target=self._inference,
args=(req_queue, res_queue, i, temperature, top_p,
top_k, stream_output),
daemon=True)
t.start()
threads.append(t)
task = self._inference(req_queue, res_queue, i, temperature, top_p,
top_k, stream_output)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

# wait for finish
for t in threads:
t.join()
event_loop.run_until_complete(_gather_tasks(tasks))

elapsed_time = time.time() - start

Expand Down Expand Up @@ -333,7 +334,6 @@ def main():
block_size=args.cache_block_seq_len,
max_batch_size=args.concurrency,
tp=args.tp,
thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
quant_policy=args.quant_policy,
Expand Down
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/backends/cuda/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def forward(
kv_seqlens = attn_metadata.kv_seqlens
kv_flatten_size = attn_metadata.kv_flatten_size
quant_policy = attn_metadata.quant_policy
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
if attn_metadata.is_decoding:
max_q_seqlen = 1
else:
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
fill_max_q_seqlen = max_q_seqlen
if attn_metadata.fill_seqlens is not None:
fill_seqlens = attn_metadata.fill_seqlens
Expand Down
22 changes: 16 additions & 6 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(self,
self._start_loop()
self._create_buffers()
self.engine_instance = self.create_instance()
self._output_stream = torch.cuda.Stream()

@classmethod
def from_pretrained(cls,
Expand Down Expand Up @@ -673,7 +674,8 @@ async def __long_context_single_forward(inputs):
return ret

def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
logits: torch.Tensor, stopped: torch.Tensor):
logits: torch.Tensor, stopped: torch.Tensor,
event: torch.cuda.Event):
"""make infer output."""

def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence,
Expand All @@ -694,6 +696,11 @@ def __get_q_start_loc():
else:
return seq_length.cumsum(0) - seq_length

with torch.cuda.stream(self._output_stream):
event.wait()
next_token_ids = next_token_ids.cpu()
stopped = stopped.cpu()

running = self._running
is_run = [seq.status == MessageStatus.RUNNING for seq in running]
stopped = stopped.tolist()
Expand Down Expand Up @@ -755,6 +762,8 @@ def __update_inputs(next_token_ids):
logger.debug('<ForwardTask>: '
f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)}')
if self.gpu_count == 1:
inputs = inputs.to_device('cuda')
is_decoding = inputs.is_decoding
if all_ids is not None:
all_ids = all_ids.cuda()
Expand Down Expand Up @@ -785,10 +794,11 @@ def __update_inputs(next_token_ids):
next_token_ids, sampling_inputs.stop_words, num_appendable_ids)

# send output
stopped = stopped.cpu()
finish = stopped.all().item() or (idx == loop_count - 1)
finish = (idx == loop_count - 1)
finish = finish or _check_finish(self.scheduler, idx)
output = (next_token_ids.cpu(), logits, stopped)
event = torch.cuda.Event()
event.record()
output = (next_token_ids, logits, stopped, event)
output_que.put_nowait((finish, output))

if finish:
Expand Down Expand Up @@ -951,9 +961,9 @@ async def __step():
try:
if isinstance(out, Exception):
raise out
next_token_ids, logits, stopped = out
next_token_ids, logits, stopped, event = out
step_outputs = self._make_infer_outputs(
next_token_ids, logits, stopped)
next_token_ids, logits, stopped, event)
__send_resps(step_outputs)
except Exception as e:
raise e
Expand Down
30 changes: 20 additions & 10 deletions lmdeploy/pytorch/engine/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor):

def _process_bad_words_(scores: torch.Tensor,
bad_words: torch.LongTensor,
mask: torch.BoolTensor,
filter_value: float = -float('inf')):
"""process bad words."""
mask = bad_words >= 0
bad_words = bad_words.where(mask, 0)
filtered_scores = scores.gather(1, bad_words)
filtered_scores[mask] = filter_value
scores.scatter_(1, bad_words, filtered_scores)
Expand Down Expand Up @@ -127,7 +126,9 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor,
class SamplingInputs:
temperature: torch.Tensor = None
bad_words: torch.LongTensor = None
bad_mask: torch.BoolTensor = None
stop_words: torch.LongTensor = None
stop_mask: torch.BoolTensor = None
repetition_penalty: torch.Tensor = None
top_k: torch.LongTensor = None
top_p: torch.Tensor = None
Expand Down Expand Up @@ -200,17 +201,22 @@ def __get_bad_words(bad_words):
"""get bad words."""
max_bw_len = max(len(bw) for bw in bad_words)
if max_bw_len == 0:
return None
return None, None
if all(len(bw) == max_bw_len for bw in bad_words):
return torch.tensor(bad_words)
ret = torch.tensor(bad_words)
mask = torch.ones_like(ret, dtype=bool)
return ret, mask
ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64)
for idx, bw in enumerate(bad_words):
bw_len = len(bw)
if bw_len == 0:
continue
bw = ret.new_tensor(bw)
ret[idx, :bw_len] = bw
return ret

mask = ret >= 0
ret = ret.where(mask, 0)
return ret, mask

__gather_params()

Expand All @@ -221,8 +227,8 @@ def __get_bad_words(bad_words):

temperature = torch.tensor(temperature)

bad_words = __get_bad_words(bad_words)
stop_words = __get_bad_words(stop_words)
bad_words, bad_mask = __get_bad_words(bad_words)
stop_words, stop_mask = __get_bad_words(stop_words)

max_top_k = max(top_k)
if min(top_k) <= 0:
Expand All @@ -243,7 +249,9 @@ def __get_bad_words(bad_words):
sampling_input = cls(
temperature=temperature,
bad_words=bad_words,
bad_mask=bad_mask,
stop_words=stop_words,
stop_mask=stop_mask,
repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
Expand Down Expand Up @@ -326,12 +334,14 @@ def __call__(self, all_ids: torch.LongTensor,

bad_words = sampling_inputs.bad_words
if bad_words is not None:
scores = _process_bad_words_(scores, bad_words)
bad_mask = sampling_inputs.bad_mask
scores = _process_bad_words_(scores, bad_words, bad_mask)

stop_words = sampling_inputs.stop_words
if stop_words is not None:
stop_words = torch.where(self.ignore_eos[:, None], stop_words, -1)
scores = _process_bad_words_(scores, stop_words)
stop_mask = sampling_inputs.stop_mask
stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False)
scores = _process_bad_words_(scores, stop_words, stop_mask)

scores = _guided_sampling(sampling_inputs.response_formats, scores,
guided_input_ids, self.tokenizer)
Expand Down
65 changes: 6 additions & 59 deletions lmdeploy/pytorch/engine/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig):
self.model_config = model_config
self.cache_config = cache_config

def get_block_numel(self):
"""get block nelement."""
raise NotImplementedError('Not implemented')

async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Expand All @@ -177,17 +173,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
"""
raise NotImplementedError('Not implemented.')

def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
raise NotImplementedError('Not implemented.')

def get_logits(self, hidden_states: torch.Tensor):
"""get logits of model output."""
raise NotImplementedError('Not implemented.')
Expand Down Expand Up @@ -255,11 +240,6 @@ def _build_model(self,
device=device)
return patched_model

def get_block_numel(self):
"""get block nelement."""
k_cache = self.cache_engine.local_gpu_cache[0][0]
return k_cache[0].numel()

def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
cache_swapping(self.cache_engine,
Expand All @@ -274,21 +254,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
)
return output

def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
output = self._forward_impl(inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
self.stream.synchronize()
return output

async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Expand All @@ -301,8 +266,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
output = self._forward_impl(inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.get_event_loop().run_in_executor(None,
self.stream.synchronize)
await asyncio.sleep(0)
while not self.stream.query():
await asyncio.sleep(0)
return output

def get_logits(self, hidden_states: torch.Tensor):
Expand Down Expand Up @@ -688,11 +654,6 @@ def _build_model(

return model, cache_engine, cache_config

def get_block_numel(self):
"""get block nelement."""
k_cache = self.cache_engine.local_gpu_cache[0][0]
return k_cache[0].numel()

def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""forward impl."""
Expand All @@ -713,21 +674,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
)
return output

def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
output = self._forward_impl(inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
self.stream.synchronize()
return output

async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Expand All @@ -740,8 +686,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
output = self._forward_impl(inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.get_event_loop().run_in_executor(None,
self.stream.synchronize)
await asyncio.sleep(0)
while not self.stream.query():
await asyncio.sleep(0)
return output

def get_logits(self, hidden_states: torch.Tensor):
Expand Down
Loading

0 comments on commit 8fbfed6

Please sign in to comment.