diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 58786d9c80..4f06fad4f9 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import asyncio import csv import json import os import random import time from queue import Queue -from threading import Thread from typing import List, Tuple, Union import numpy as np @@ -86,15 +86,15 @@ def __init__(self, model_path: str, self.csv = csv self.pbar = None - def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, - temperature: float, top_p: float, top_k: int, - stream_output: bool): + async def _inference(self, req_queue: Queue, res_queue: Queue, + session_id: int, temperature: float, top_p: float, + top_k: int, stream_output: bool): model_inst = self.tm_model.create_instance() stats = [] # get each generated token's latency per_token_latency_stats = [] for prompt, input_seqlen, output_seqlen in iter( - req_queue.get, [None, None, None]): + req_queue.get_nowait, [None, None, None]): _per_token_latency_stats = [0] * (output_seqlen + 1) prev = time.perf_counter() n_prev_token = 0 @@ -102,7 +102,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, input_ids = self.tokenizer(prompt).input_ids state = DetokenizeState(len(input_ids)) - for outputs in model_inst.stream_infer( + async for outputs in model_inst.async_stream_infer( session_id, input_ids=input_ids, gen_config=GenerationConfig(max_new_tokens=output_seqlen, @@ -123,7 +123,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, prev = now # for pytorch engine to restart a session if isinstance(model_inst, EngineInstance): - model_inst.end(session_id) + await model_inst.async_end(session_id) assert output_seqlen <= n_token <= output_seqlen + 1, \ f'Error. 
session_id({session_id}) request {output_seqlen} ' \ f'tokens, but generate {n_token} tokens.\n' \ @@ -139,13 +139,12 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, # skip the first token latency per_token_latency_stats.append(_per_token_latency_stats[1:]) self.pbar.update(1) - res_queue.put((session_id, stats, per_token_latency_stats)) + res_queue.put_nowait((session_id, stats, per_token_latency_stats)) def process_request(self, requests, concurrency, temperature, top_p, top_k, stream_output): res_queue = Queue() req_queue = Queue() - threads = [] self.pbar = tqdm(total=len(requests)) @@ -157,18 +156,20 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k, start = time.time() + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + # start threads + tasks = [] for i in range(concurrency): - t = Thread(target=self._inference, - args=(req_queue, res_queue, i, temperature, top_p, - top_k, stream_output), - daemon=True) - t.start() - threads.append(t) + task = self._inference(req_queue, res_queue, i, temperature, top_p, + top_k, stream_output) + tasks.append(task) + + async def _gather_tasks(tasks): + return await asyncio.gather(*tasks) - # wait for finish - for t in threads: - t.join() + event_loop.run_until_complete(_gather_tasks(tasks)) elapsed_time = time.time() - start @@ -333,7 +334,6 @@ def main(): block_size=args.cache_block_seq_len, max_batch_size=args.concurrency, tp=args.tp, - thread_safe=True, eager_mode=args.eager_mode, enable_prefix_caching=args.enable_prefix_caching, quant_policy=args.quant_policy, diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index d01d6fe9b4..8261b869f0 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -94,7 +94,10 @@ def forward( kv_seqlens = attn_metadata.kv_seqlens kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy - max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + if attn_metadata.is_decoding: + max_q_seqlen = 1 + else: + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) fill_max_q_seqlen = max_q_seqlen if attn_metadata.fill_seqlens is not None: fill_seqlens = attn_metadata.fill_seqlens diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 26b507e9d4..b7a803a7a7 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -172,6 +172,7 @@ def __init__(self, self._start_loop() self._create_buffers() self.engine_instance = self.create_instance() + self._output_stream = torch.cuda.Stream() @classmethod def from_pretrained(cls, @@ -673,7 +674,8 @@ async def __long_context_single_forward(inputs): return ret def _make_infer_outputs(self, next_token_ids: torch.LongTensor, - logits: torch.Tensor, stopped: torch.Tensor): + logits: torch.Tensor, stopped: torch.Tensor, + event: torch.cuda.Event): """make infer output.""" def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, @@ -694,6 +696,11 @@ def __get_q_start_loc(): else: return seq_length.cumsum(0) - seq_length + with torch.cuda.stream(self._output_stream): + event.wait() + next_token_ids = next_token_ids.cpu() + stopped = stopped.cpu() + running = self._running is_run = [seq.status == MessageStatus.RUNNING for seq in running] stopped = stopped.tolist() @@ -755,6 +762,8 @@ def __update_inputs(next_token_ids): logger.debug(': ' f'batch_size={inputs.seq_length.size(0)} ' 
f'num_tokens={inputs.input_ids.size(-1)}') + if self.gpu_count == 1: + inputs = inputs.to_device('cuda') is_decoding = inputs.is_decoding if all_ids is not None: all_ids = all_ids.cuda() @@ -785,10 +794,11 @@ def __update_inputs(next_token_ids): next_token_ids, sampling_inputs.stop_words, num_appendable_ids) # send output - stopped = stopped.cpu() - finish = stopped.all().item() or (idx == loop_count - 1) + finish = (idx == loop_count - 1) finish = finish or _check_finish(self.scheduler, idx) - output = (next_token_ids.cpu(), logits, stopped) + event = torch.cuda.Event() + event.record() + output = (next_token_ids, logits, stopped, event) output_que.put_nowait((finish, output)) if finish: @@ -951,9 +961,9 @@ async def __step(): try: if isinstance(out, Exception): raise out - next_token_ids, logits, stopped = out + next_token_ids, logits, stopped, event = out step_outputs = self._make_infer_outputs( - next_token_ids, logits, stopped) + next_token_ids, logits, stopped, event) __send_resps(step_outputs) except Exception as e: raise e diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 54740a4fb3..24cb336d71 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -21,10 +21,9 @@ def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor): def _process_bad_words_(scores: torch.Tensor, bad_words: torch.LongTensor, + mask: torch.BoolTensor, filter_value: float = -float('inf')): """process bad words.""" - mask = bad_words >= 0 - bad_words = bad_words.where(mask, 0) filtered_scores = scores.gather(1, bad_words) filtered_scores[mask] = filter_value scores.scatter_(1, bad_words, filtered_scores) @@ -127,7 +126,9 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, class SamplingInputs: temperature: torch.Tensor = None bad_words: torch.LongTensor = None + bad_mask: torch.BoolTensor = None stop_words: torch.LongTensor = None + stop_mask: torch.BoolTensor = None repetition_penalty: torch.Tensor = None top_k: torch.LongTensor = None top_p: torch.Tensor = None @@ -200,9 +201,11 @@ def __get_bad_words(bad_words): """get bad words.""" max_bw_len = max(len(bw) for bw in bad_words) if max_bw_len == 0: - return None + return None, None if all(len(bw) == max_bw_len for bw in bad_words): - return torch.tensor(bad_words) + ret = torch.tensor(bad_words) + mask = torch.ones_like(ret, dtype=bool) + return ret, mask ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64) for idx, bw in enumerate(bad_words): bw_len = len(bw) @@ -210,7 +213,10 @@ def __get_bad_words(bad_words): continue bw = ret.new_tensor(bw) ret[idx, :bw_len] = bw - return ret + + mask = ret >= 0 + ret = ret.where(mask, 0) + return ret, mask __gather_params() @@ -221,8 +227,8 @@ def __get_bad_words(bad_words): temperature = torch.tensor(temperature) - bad_words = __get_bad_words(bad_words) - stop_words = __get_bad_words(stop_words) + bad_words, bad_mask = __get_bad_words(bad_words) + stop_words, stop_mask = __get_bad_words(stop_words) max_top_k = max(top_k) if min(top_k) <= 0: @@ -243,7 +249,9 @@ def __get_bad_words(bad_words): sampling_input = cls( temperature=temperature, bad_words=bad_words, + bad_mask=bad_mask, stop_words=stop_words, + stop_mask=stop_mask, repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, @@ -326,12 +334,14 @@ def __call__(self, all_ids: torch.LongTensor, bad_words = sampling_inputs.bad_words if bad_words is not None: - scores = _process_bad_words_(scores, 
bad_words) + bad_mask = sampling_inputs.bad_mask + scores = _process_bad_words_(scores, bad_words, bad_mask) stop_words = sampling_inputs.stop_words if stop_words is not None: - stop_words = torch.where(self.ignore_eos[:, None], stop_words, -1) - scores = _process_bad_words_(scores, stop_words) + stop_mask = sampling_inputs.stop_mask + stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False) + scores = _process_bad_words_(scores, stop_words, stop_mask) scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 2877f59375..59d77f264a 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -162,10 +162,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): self.model_config = model_config self.cache_config = cache_config - def get_block_numel(self): - """get block nelement.""" - raise NotImplementedError('Not implemented') - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -177,17 +173,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, """ raise NotImplementedError('Not implemented.') - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - raise NotImplementedError('Not implemented.') - def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" raise NotImplementedError('Not implemented.') @@ -255,11 +240,6 @@ def _build_model(self, device=device) return patched_model - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): cache_swapping(self.cache_engine, @@ -274,21 +254,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, ) return output - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - output = self._forward_impl(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) - self.stream.synchronize() - return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. 
@@ -301,8 +266,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) - await asyncio.get_event_loop().run_in_executor(None, - self.stream.synchronize) + await asyncio.sleep(0) + while not self.stream.query(): + await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): @@ -688,11 +654,6 @@ def _build_model( return model, cache_engine, cache_config - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """forward impl.""" @@ -713,21 +674,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, ) return output - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - output = self._forward_impl(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) - self.stream.synchronize() - return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -740,8 +686,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) - await asyncio.get_event_loop().run_in_executor(None, - self.stream.synchronize) + await asyncio.sleep(0) + while not self.stream.query(): + await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): diff --git a/lmdeploy/pytorch/kernels/cuda/activation.py b/lmdeploy/pytorch/kernels/cuda/activation.py index 2533840a95..9a00e7354f 100644 --- a/lmdeploy/pytorch/kernels/cuda/activation.py +++ b/lmdeploy/pytorch/kernels/cuda/activation.py @@ -7,10 +7,8 @@ TRITON_VERSION = version.parse(triton.__version__) if TRITON_VERSION >= version.parse('3.0.0'): - fast_expf = tl.math.exp else: - tanh = tl.math.tanh fast_expf = tl.math.fast_expf @@ -26,63 +24,29 @@ def _silu_and_mul_kernel( BLOCK_SIZE_N: tl.constexpr, ): """silu and mul kernel.""" - m_id = tl.program_id(0) + n_block_id = tl.program_id(0) + m_id = tl.program_id(1) up_ptr = gateup_ptr + N * stride_gun - offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - for _ in range(0, N, BLOCK_SIZE_N): - gate = tl.load(gate_ptrs).to(tl.float32) - up = tl.load(up_ptrs).to(tl.float32) - - gate = gate / (1 + fast_expf(-gate)) - out = gate * up - - tl.store(out_ptrs, out) - - gate_ptrs += BLOCK_SIZE_N * stride_gun - up_ptrs += BLOCK_SIZE_N * stride_gun - out_ptrs += BLOCK_SIZE_N * stride_on - - -@triton.jit -def _silu_and_mul_no_align_kernel( - gateup_ptr, - out_ptr, - N: tl.constexpr, - stride_gum: tl.constexpr, - stride_gun: tl.constexpr, - stride_om: tl.constexpr, - stride_on: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - """silu and mul kernel.""" - m_id = tl.program_id(0) - - up_ptr = gateup_ptr + N * stride_gun - - offs_n = tl.arange(0, BLOCK_SIZE_N) - gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun - up_ptrs = up_ptr + m_id * stride_gum + 
offs_n * stride_gun - out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - - for n in range(0, N, BLOCK_SIZE_N): - mask = n + offs_n < N - gate = tl.load(gate_ptrs, mask=mask).to(tl.float32) - up = tl.load(up_ptrs, mask=mask).to(tl.float32) - - gate = gate / (1 + fast_expf(-gate)) - out = gate * up + if N % BLOCK_SIZE_N == 0: + mask = None + else: + mask = offs_n < N + gate = tl.load(gate_ptrs, mask=mask) + up = tl.load(up_ptrs, mask=mask) + gate = gate.to(tl.float32) + up = up.to(tl.float32) - tl.store(out_ptrs, out, mask=mask) + gate = gate / (1 + fast_expf(-gate)) + out = gate * up - gate_ptrs += BLOCK_SIZE_N * stride_gun - up_ptrs += BLOCK_SIZE_N * stride_gun - out_ptrs += BLOCK_SIZE_N * stride_on + tl.store(out_ptrs, out, mask=mask) def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): @@ -96,31 +60,22 @@ def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): out = gate_up.new_empty(out_shape) BLOCK_SIZE_N = triton.next_power_of_2(N) - BLOCK_SIZE_N = min(BLOCK_SIZE_N, 1024) + BLOCK_SIZE_N = min(BLOCK_SIZE_N, 512) num_warps = 4 - num_stages = 2 - grid = (M, ) - if N % BLOCK_SIZE_N == 0: - _silu_and_mul_kernel[grid](gate_up, - out, - N, - stride_gum=gate_up.stride(0), - stride_gun=gate_up.stride(1), - stride_om=out.stride(0), - stride_on=out.stride(1), - BLOCK_SIZE_N=BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages) - else: - _silu_and_mul_no_align_kernel[grid](gate_up, - out, - N, - stride_gum=gate_up.stride(0), - stride_gun=gate_up.stride(1), - stride_om=out.stride(0), - stride_on=out.stride(1), - BLOCK_SIZE_N=BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages) + num_stages = 1 + grid = ( + triton.cdiv(N, BLOCK_SIZE_N), + M, + ) + _silu_and_mul_kernel[grid](gate_up, + out, + N, + stride_gum=gate_up.stride(0), + stride_gun=gate_up.stride(1), + stride_om=out.stride(0), + stride_on=out.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + num_warps=num_warps, + num_stages=num_stages) return out diff --git a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py index 9e14dc6a0c..f9d5f2f171 100644 --- a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py @@ -4,35 +4,9 @@ import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func - - -@wrap_jit_func(type_hint=dict( - Q=Tensor, - K=Tensor, - COS=Tensor, - SIN=Tensor, - POS=Tensor, - Q_EMB=Tensor, - K_EMB=Tensor, - seq_len=int, - stride_qs=int, - stride_qh=int, - stride_qd=int, - stride_ks=int, - stride_kh=int, - stride_kd=int, - stride_qes=int, - stride_qeh=int, - stride_qed=int, - stride_kes=int, - stride_keh=int, - stride_ked=int, - half_size=torch.int32, - BLOCK=torch.int32, - BLOCK_QH=torch.int32, - BLOCK_N=torch.int32, -)) +from .triton_utils import get_kernel_meta + + @triton.jit(do_not_specialize=('seq_len', )) def apply_rotary_pos_emb_qk_kernel( Q, @@ -60,8 +34,8 @@ def apply_rotary_pos_emb_qk_kernel( BLOCK_N: tl.constexpr, ): """apply rotary on key AND query kernel.""" - seq_block_id = tl.program_id(0) - head_id = tl.program_id(1) + seq_block_id = tl.program_id(1) + head_id = tl.program_id(0) pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK) pos_mask = pos_offset < seq_len @@ -158,10 +132,13 @@ def apply_rotary_pos_emb(q: Tensor, num_heads_q = q.size(-2) num_heads_k = k.size(-2) num_warps = 4 - num_stages = 4 + num_stages = 1 kernel_meta = get_kernel_meta(q) - grid = [triton.cdiv(seq_len, BLOCK), num_heads_q + 
num_heads_k] + grid = [ + num_heads_q + num_heads_k, + triton.cdiv(seq_len, BLOCK), + ] apply_rotary_pos_emb_qk_kernel[grid](q, k, cos, diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index 9ef614fadd..93bd89f488 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -1,12 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Literal -import torch import triton import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func +from .triton_utils import get_kernel_meta @triton.jit @@ -38,37 +37,6 @@ def _quant_int4(val1, val2): return q_val, scales, zeros -@wrap_jit_func(type_hint=dict( - KStates=Tensor, - VStates=Tensor, - KCaches=Tensor, - VCaches=Tensor, - QStartLoc=Tensor, - QSeqLens=Tensor, - KVSeqLens=Tensor, - BlockOffsets=Tensor, - num_heads=torch.int32, - head_dim=torch.int32, - stride_kss=int, - stride_ksh=int, - stride_ksd=int, - stride_vss=int, - stride_vsh=int, - stride_vsd=int, - stride_kcn=int, - stride_kcb=int, - stride_kch=int, - stride_kcd=int, - stride_vcn=int, - stride_vcb=int, - stride_vch=int, - stride_vcd=int, - stride_boff=int, - BLOCK=torch.int32, - BLOCK_D=torch.int32, - BLOCK_DV=torch.int32, - BLOCK_H=torch.int32, -)) @triton.jit def _fill_kv_cache_kernel( KStates, @@ -79,7 +47,7 @@ def _fill_kv_cache_kernel( QSeqLens, KVSeqLens, BlockOffsets, - num_heads: tl.constexpr, + is_decoding: tl.constexpr, head_dim: tl.constexpr, head_dim_v: tl.constexpr, stride_kss, @@ -100,108 +68,70 @@ def _fill_kv_cache_kernel( BLOCK: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_DV: tl.constexpr, - BLOCK_H: tl.constexpr, ): """fill kv cache kernel.""" - batch_id = tl.program_id(0) + batch_id = tl.program_id(2) + head_id = tl.program_id(0) block_id = tl.program_id(1) - # initialize - h_off = tl.arange(0, BLOCK_H) - d_off = tl.arange(0, BLOCK_D) - q_startloc = tl.load(QStartLoc + batch_id) q_seqlen = tl.load(QSeqLens + batch_id) kv_seqlen = tl.load(KVSeqLens + batch_id) history_seqlen = kv_seqlen - q_seqlen - block0_first_tokenloc = history_seqlen % BLOCK - - state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc, - 0) - kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id - kv_block_id = min(kv_block_id, stride_boff - 1) - block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) + kv_block_id = history_seqlen // BLOCK + block_id - cur_startloc = q_startloc + state_token_offset - ks_ptr = KStates + cur_startloc * stride_kss - vs_ptr = VStates + cur_startloc * stride_vss + if kv_seqlen <= 0: + return - kc_ptr = KCaches + block_off * stride_kcn - vc_ptr = VCaches + block_off * stride_vcn + if kv_block_id * BLOCK >= kv_seqlen: + return - c_first_tokenloc = block0_first_tokenloc - if block_id != 0: - c_first_tokenloc *= 0 - c_last_tokenloc = tl.minimum( - BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK) + if is_decoding: + page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32) + kv_mask = tl.full((1, ), 1, dtype=tl.int1) + q_offs = tl.full((1, ), q_startloc, dtype=tl.int32) + else: + page_offs = tl.arange(0, BLOCK) + kv_offs = kv_block_id * BLOCK + page_offs + kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen) + token_off = q_startloc + kv_block_id * BLOCK - history_seqlen + q_offs = token_off + page_offs - for bidx in range(c_first_tokenloc, c_last_tokenloc): - sidx = bidx - c_first_tokenloc - mask = (h_off[:, None] < 
num_heads) & (d_off[None, :] < head_dim) - k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + - d_off[None, :] * stride_ksd, - mask=mask) - tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch + - d_off[None, :] * stride_kcd, - k, - mask=mask) + block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) - if BLOCK_DV > 0: - dv_off = tl.arange(0, BLOCK_DV) - maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < - head_dim_v) - v = tl.load(vs_ptr + sidx * stride_vss + - h_off[:, None] * stride_vsh + - dv_off[None, :] * stride_vsd, - mask=maskv) - tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch + - dv_off[None, :] * stride_vcd, - v, - mask=maskv) + d_off = tl.arange(0, BLOCK_D) + mask_ks = kv_mask[:, None] + mask_kc = mask_ks & (d_off[None, :] < head_dim) + d_off = d_off % head_dim + + ks_ptr = KStates + head_id * stride_ksh + ks_ptrs = ks_ptr + q_offs[:, + None] * stride_kss + d_off[None, :] * stride_ksd + kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch + kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[ + None, :] * stride_kcd + + if BLOCK_DV > 0: + dv_off = tl.arange(0, BLOCK_DV) + mask_vs = kv_mask[:, None] + mask_vc = mask_vs & (dv_off[None, :] < head_dim_v) + dv_off = dv_off % head_dim_v + vs_ptr = VStates + head_id * stride_vsh + vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[ + None, :] * stride_vsd + vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch + vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[ + None, :] * stride_vcd + + k = tl.load(ks_ptrs, mask=mask_ks) + if BLOCK_DV > 0: + v = tl.load(vs_ptrs, mask=mask_vs) + tl.store(kc_ptrs, k, mask=mask_kc) + if BLOCK_DV > 0: + tl.store(vc_ptrs, v, mask=mask_vc) -@wrap_jit_func(type_hint=dict( - KStates=Tensor, - VStates=Tensor, - KCaches=Tensor, - VCaches=Tensor, - KScalesZeros=Tensor, - VScalesZeros=Tensor, - QStartLoc=Tensor, - QSeqLens=Tensor, - KVSeqLens=Tensor, - BlockOffsets=Tensor, - num_heads=torch.int32, - head_dim=torch.int32, - stride_kss=int, - stride_ksh=int, - stride_ksd=int, - stride_vss=int, - stride_vsh=int, - stride_vsd=int, - stride_kcn=int, - stride_kcb=int, - stride_kch=int, - stride_kcd=int, - stride_vcn=int, - stride_vcb=int, - stride_vch=int, - stride_vcd=int, - stride_kszn=int, - stride_kszb=int, - stride_kszh=int, - stride_kszd=int, - stride_vszn=int, - stride_vszb=int, - stride_vszh=int, - stride_vszd=int, - stride_boff=int, - BLOCK=torch.int32, - BLOCK_D=torch.int32, - BLOCK_DV=torch.int32, - BLOCK_H=torch.int32, -)) @triton.jit def _fill_kv_cache_quant_kernel( KStates, @@ -394,15 +324,19 @@ def fill_kv_cache(k_states: Tensor, num_heads = k_caches.size(h_dim) head_dim = k_caches.size(d_dim) head_dim_v = v_states.size(-1) - max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1 + if max_q_seq_length == 1: + max_num_blocks = 1 + else: + max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1 BLOCK = block_size BLOCK_H = triton.next_power_of_2(num_heads) BLOCK_D = triton.next_power_of_2(head_dim) BLOCK_DV = triton.next_power_of_2(head_dim_v) - grid = [batch_size, max_num_blocks] kernel_meta = get_kernel_meta(k_states) if quant_policy == 0: + grid = [num_heads, max_num_blocks, batch_size] + is_decoding = max_num_blocks == 1 _fill_kv_cache_kernel[grid]( k_states, v_states, @@ -412,7 +346,7 @@ def fill_kv_cache(k_states: Tensor, q_seq_length, kv_seq_length, block_offsets, - num_heads=num_heads, + is_decoding=is_decoding, head_dim=head_dim, head_dim_v=head_dim_v, 
stride_kss=k_states.stride(-3), @@ -433,12 +367,12 @@ def fill_kv_cache(k_states: Tensor, BLOCK=BLOCK, BLOCK_D=BLOCK_D, BLOCK_DV=BLOCK_DV, - BLOCK_H=BLOCK_H, num_warps=4, num_stages=3, **kernel_meta, ) else: + grid = [batch_size, max_num_blocks] _fill_kv_cache_quant_kernel[grid]( k_states, v_states, diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 149376e4be..74d090a9a3 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -70,15 +70,14 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int64, device=device) - input_buffers['q_start_loc'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) - input_buffers['q_seqlens'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) - input_buffers['kv_seqlens'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) + + input_buffers['qkv_lens'] = torch.zeros(3, + max_batches, + dtype=torch.int64, + device=device) + input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0] + input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1] + input_buffers['kv_seqlens'] = input_buffers['qkv_lens'][2] input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int64, device=device) @@ -111,13 +110,10 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_buffers['position_ids'][:, :num_tokens] = position_ids input_buffers[ 'block_offsets'][:batch_size, :num_blocks] = block_offsets - if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr(): - input_buffers['q_seqlens'].zero_() - input_buffers['q_seqlens'][:batch_size] = q_seqlens - if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr(): - input_buffers['kv_seqlens'].zero_() - input_buffers['kv_seqlens'][:batch_size] = kv_seqlens - input_buffers['q_start_loc'][:batch_size] = q_start_loc + + qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens)) + input_buffers['qkv_lens'].zero_() + input_buffers['qkv_lens'][:, :batch_size] = qkv if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) if 'inputs_embeds' not in input_buffers: diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py index 5c5fdbdc18..69c8315411 100644 --- a/tests/pytorch/engine/test_logits_process.py +++ b/tests/pytorch/engine/test_logits_process.py @@ -35,8 +35,9 @@ def test_process_bad_words(): [4, 4], [-1, -1], ]) + mask = bad_words >= 0 - out_scores = _process_bad_words_(scores, bad_words) + out_scores = _process_bad_words_(scores, bad_words.where(mask, 0), mask) for score, bw in zip(out_scores, bad_words): bw = bw.tolist()
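
Note on the benchmark change: profile_throughput.py now drives the engine with one coroutine per concurrency slot on a single event loop instead of one thread per slot. Below is a minimal, self-contained sketch of that pattern; the `worker`/`run_benchmark` names and the `asyncio.sleep(0)` stand-in for the real `async_stream_infer` call are illustrative only, not part of the patch.

import asyncio
import time
from queue import Empty, Queue


async def worker(req_queue: Queue, res_queue: Queue, worker_id: int):
    """Drain the shared request queue cooperatively."""
    stats = []
    while True:
        try:
            prompt = req_queue.get_nowait()
        except Empty:
            break
        start = time.perf_counter()
        await asyncio.sleep(0)  # stand-in for the real async generation call
        stats.append((prompt, time.perf_counter() - start))
    res_queue.put_nowait((worker_id, stats))


def run_benchmark(prompts, concurrency: int = 4):
    req_queue, res_queue = Queue(), Queue()
    for prompt in prompts:
        req_queue.put_nowait(prompt)

    # mirror the patch: build one coroutine per slot, then gather them on a
    # fresh event loop so all requests share a single thread
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    tasks = [worker(req_queue, res_queue, i) for i in range(concurrency)]

    async def _gather_tasks(tasks):
        return await asyncio.gather(*tasks)

    event_loop.run_until_complete(_gather_tasks(tasks))
    return [res_queue.get_nowait() for _ in range(concurrency)]


if __name__ == '__main__':
    print(run_benchmark([f'prompt-{i}' for i in range(10)]))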
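
Note on the engine.py change: sampling results are now returned together with a CUDA event, and `_make_infer_outputs` performs the host copies of `next_token_ids`/`stopped` on a dedicated `_output_stream` after waiting on that event, so the compute stream is not forced to synchronize between steps. The sketch below shows the same event-plus-side-stream pattern in isolation, assuming a CUDA device is available; the tensor names and the matmul/argmax stand-in for sampling are made up for illustration.

import torch


def copy_off_main_stream():
    output_stream = torch.cuda.Stream()

    x = torch.randn(4096, 4096, device='cuda')
    # work enqueued on the default (compute) stream; argmax stands in for sampling
    next_token_ids = (x @ x).argmax(dim=-1)
    event = torch.cuda.Event()
    event.record()  # recorded on the current (compute) stream

    with torch.cuda.stream(output_stream):
        # the side stream waits only for work recorded up to the event,
        # not for anything enqueued on the compute stream afterwards
        event.wait()
        host_ids = next_token_ids.to('cpu', non_blocking=True)

    # ... the compute stream is free to receive the next step's kernels here

    output_stream.synchronize()  # make host_ids safe to read on the CPU
    return host_ids


if __name__ == '__main__':
    if torch.cuda.is_available():
        print(copy_off_main_stream().shape)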
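
Note on the logits_process.py change: the padding mask for bad/stop words is now built once in `SamplingInputs` (`bad_mask`/`stop_mask`) and passed into `_process_bad_words_`, instead of being recomputed from the `-1` padding on every decode step. The snippet below reproduces the patched helper on toy data to show how the mask leaves padded slots (token id 0) untouched; the example tensors are made up for illustration.

import torch


def process_bad_words(scores: torch.Tensor,
                      bad_words: torch.LongTensor,
                      mask: torch.BoolTensor,
                      filter_value: float = -float('inf')):
    """Set the scores of banned token ids to -inf, skipping padded slots."""
    filtered_scores = scores.gather(1, bad_words)
    filtered_scores[mask] = filter_value
    scores.scatter_(1, bad_words, filtered_scores)
    return scores


scores = torch.randn(2, 10)
bad_words = torch.tensor([[3, 5], [7, 0]])  # second row: one real ban, one padded slot
mask = torch.tensor([[True, True], [True, False]])
process_bad_words(scores, bad_words, mask)
assert torch.isinf(scores[0, 3]) and torch.isinf(scores[0, 5])
assert torch.isinf(scores[1, 7]) and not torch.isinf(scores[1, 0])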