diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index 03527265d2..ff20bec5bd 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -12,7 +12,7 @@
 import numpy as np
 from tqdm import tqdm
 
-from lmdeploy.tokenizer import Tokenizer
+from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 from lmdeploy.turbomind import TurboMind
 
 
@@ -80,7 +80,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
         for prompt, input_seqlen, output_seqlen in iter(
                 req_queue.get, [None, None, None]):
             _per_token_latency_stats = [0] * (output_seqlen + 1)
-            offset = 0
+            state = DetokenizeState()
             prev = time.perf_counter()
             n_prev_token = 0
 
@@ -96,8 +96,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
                     ignore_eos=True,
                     stream_output=stream_output):
                 _, res, n_token = outputs
-                self.tokenizer.decode(res, offset)
-                offset = n_token
+                _, state = self.tokenizer.detokenize_incrementally(res, state)
                 now = time.perf_counter()
                 if n_prev_token != n_token:
                     _per_token_latency_stats[n_prev_token] = np.round(
diff --git a/benchmark/profile_torch_throughput.py b/benchmark/profile_torch_throughput.py
index e82f657593..9c5f6d3b82 100644
--- a/benchmark/profile_torch_throughput.py
+++ b/benchmark/profile_torch_throughput.py
@@ -14,7 +14,7 @@
 
 from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
 from lmdeploy.pytorch.engine import Engine as LMEngine
-from lmdeploy.tokenizer import Tokenizer
+from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 
 
 def sample_requests(
@@ -83,7 +83,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
         for prompt, input_seqlen, output_seqlen in iter(
                 req_queue.get, [None, None, None]):
             _per_token_latency_stats = [0] * (output_seqlen + 1)
-            offset = 0
+            state = DetokenizeState()
             prev = time.perf_counter()
             n_prev_token = 0
 
@@ -98,8 +98,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
                                                    input_ids=input_ids,
                                                    gen_config=gen_config):
                 res, n_token = outputs[-2:]
-                self.tokenizer.decode(res, offset)
-                offset = n_token
+                _, state = self.tokenizer.detokenize_incrementally(res, state)
                 now = time.perf_counter()
                 if n_prev_token != n_token:
                     _per_token_latency_stats[n_prev_token] = np.round(
diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py
index a0f41eb998..9876dbf98f 100644
--- a/lmdeploy/pytorch/chat.py
+++ b/lmdeploy/pytorch/chat.py
@@ -6,7 +6,7 @@
 
 from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
 from lmdeploy.model import MODELS, best_match_model
-from lmdeploy.tokenizer import Tokenizer
+from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
 
@@ -107,7 +107,7 @@ def run_chat(model_path: str,
                 continue
 
             print(f'{prompt} ', end='', flush=True)
-            response_size = 0
+            state = DetokenizeState()
             gen_config.random_seed = seed
             gen_config.stop_words = stop_words
             for outputs in generator.stream_infer(session_id=session_id,
@@ -116,15 +116,10 @@ def run_chat(model_path: str,
                                                   adapter_name=adapter_name):
                 status, res, tokens = outputs
                 # decode res
-                response = tokenizer.decode(res, offset=response_size)
-                # utf-8 char at the end means it's a potential unfinished
-                # byte sequence, continue to concate it with the next
-                # sequence and decode them together
-                if response.endswith('�'):
-                    continue
+                response, state = tokenizer.detokenize_incrementally(
+                    res, state)
                 response = valid_str(response)
                 print(f'{response}', end='', flush=True)
-                response_size = tokens
 
             # update step
             step += len(input_ids) + tokens
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 5d2b318731..bd81372c9e 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -12,6 +12,7 @@
                                PytorchEngineConfig, Response,
                                TurbomindEngineConfig)
 from lmdeploy.model import ChatTemplateConfig, best_match_model
+from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.utils import _stop_words, get_logger
 
 
@@ -452,7 +453,7 @@ async def generate(
         else:
             generator = await self.get_generator(False, session_id)
             with self.safe_run(session_id):
-                response_size = 0
+                state = DetokenizeState()
                 async for outputs in generator.async_stream_infer(
                         session_id=session_id,
                         input_ids=input_ids,
@@ -461,25 +462,20 @@ async def generate(
                         sequence_start=(sequence_start),
                         sequence_end=sequence_end,
                         step=self.id2step[str(session_id)]):
-                    status, res, tokens = outputs
+                    _, res, tokens = outputs
                     # decode res
-                    response = self.tokenizer.decode(res, offset=response_size)
-                    # utf-8 char at the end means it's a potential unfinished
-                    # byte sequence, continue to concate it with the next
-                    # sequence and decode them together
-                    if response.endswith('�'):
-                        continue
+                    response, state = self.tokenizer.detokenize_incrementally(
+                        res, state)
                     # response, history token len,
                     # input token len, gen token len
                     yield GenOut(response, self.id2step[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
-                    response_size = tokens
 
                 finish_reason = 'length' \
                     if tokens >= gen_config.max_new_tokens else 'stop'
-                # `response_size` might be note updated since
-                # ` if response.endswith('�')`
-                if response_size == tokens:
+                # utf-8 char at the end means it's a potential unfinished
+                # byte sequence
+                if not response.endswith('�'):
                     response = ''  # avaid returning the last response twice
                 yield GenOut(response, self.id2step[str(session_id)],
                              len(input_ids), tokens, finish_reason)
diff --git a/lmdeploy/tokenizer.py b/lmdeploy/tokenizer.py
index 28e8092c63..b37a5a2483 100644
--- a/lmdeploy/tokenizer.py
+++ b/lmdeploy/tokenizer.py
@@ -2,7 +2,8 @@
 import json
 import os.path as osp
 from collections import deque
-from typing import List, Optional, Sequence, Union
+from dataclasses import dataclass
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 
@@ -12,6 +13,31 @@
 # importing are starting from the package root lmdeploy
 
 
+@dataclass
+class DetokenizeState:
+    """A state collection of incrementally detekenization.
+
+    Args:
+        ids_offset (int): offset to all input ids. In LMDeploy, the output
+            ids length is not one by one. It could be random by random.
+        prev_tokens (List[str] | None): for incrementally decoding.
+            Default to None, which means the first round.
+        prefix_offset (int): the start index of tokens to be converted to
+            string (prev + new tokens). Default to 0 for the first round.
+        read_offset (int): the end index of tokens to be converted to
+            string (prev token). Default to 0 for the first round.
+    """
+    ids_offset: int = 0
+    prev_tokens: Optional[List[str]] = None
+    prefix_offset: int = 0
+    read_offset: int = 0
+
+    def as_tuple(self) -> Tuple:
+        """Return a tuple of states."""
+        return (self.ids_offset, self.prev_tokens, self.prefix_offset,
+                self.read_offset)
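+
+    # A minimal usage sketch (illustrative only): `tok` stands for any
+    # tokenizer exposing `detokenize_incrementally`, and `all_ids` is the
+    # cumulative list of generated token ids at each step of a stream.
+    #
+    #   state = DetokenizeState()
+    #   text = ''
+    #   for all_ids in id_stream:
+    #       delta, state = tok.detokenize_incrementally(all_ids, state)
+    #       text += delta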
+
+
 class SentencePieceTokenizer:
     """Tokenizer of sentencepiece.
 
@@ -113,6 +139,34 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
             out_string = self._maybe_add_prefix_space(t, out_string)
         return out_string
 
+    def detokenize_incrementally(self,
+                                 all_input_ids: Sequence[int],
+                                 state: DetokenizeState,
+                                 skip_special_tokens: bool = True,
+                                 spaces_between_special_tokens: bool = True):
+        """Incrementally detokenize the input indexes.
+
+        Args:
+            all_input_ids (List[int]): the full list of token ids received
+                so far; it grows across calls as new sections arrive.
+            state (DetokenizeState): an instance of DetokenizeState holding
+                the incremental decoding states.
+            skip_special_tokens (bool): Whether or not to remove special
+                tokens in the decoding. Defaults to True.
+            spaces_between_special_tokens (bool): Whether or not to add spaces
+                between special tokens. Defaults to True.
+        Returns:
+            str: decoded output string of the current round.
+            state (DetokenizeState): an instance of DetokenizeState holding
+                the updated incremental decoding states.
+        """
+        out_string = self.model.Decode(all_input_ids)
+        if state.prev_tokens is not None:
+            out_string = self._maybe_add_prefix_space(all_input_ids,
+                                                      out_string)
+        state.prev_tokens = []  # not None for the above condition
+        return out_string, state
+
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
 
@@ -289,9 +343,117 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
         out_string = self.model.decode(t,
                                        skip_special_tokens=skip_special_tokens)
         if offset:
+            logger = get_logger('lmdeploy')
+            logger.warning('For incremental detokenization, please use the '
+                           'detokenize_incrementally function instead.')
             out_string = self._maybe_add_prefix_space(t, out_string)
         return out_string
 
+    @staticmethod
+    def _convert_tokens_to_string_with_added_encoders(
+        tokenizer,
+        output_tokens: List[str],
+        skip_special_tokens: bool,
+        spaces_between_special_tokens: bool,
+    ) -> str:
+        if tokenizer.is_fast or not tokenizer.get_added_vocab():
+            return tokenizer.convert_tokens_to_string(output_tokens)
+        # Adapted from
+        # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L68-L99
+        sub_texts = []
+        current_sub_text = []
+        all_special_tokens = set(tokenizer.all_special_tokens)
+        for token in output_tokens:
+            if skip_special_tokens and token in all_special_tokens:
+                continue
+            if token in tokenizer.get_added_vocab():
+                if current_sub_text:
+                    sub_text = tokenizer.convert_tokens_to_string(
+                        current_sub_text)
+                    sub_texts.append(sub_text)
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+            sub_texts.append(sub_text)
+        if spaces_between_special_tokens:
+            return ' '.join(sub_texts)
+        else:
+            return ''.join(sub_texts)
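+
+    # For illustration (hypothetical tokens): if '<added>' is present in
+    # tokenizer.get_added_vocab(), the sequence ['Hello', '<added>', 'world']
+    # is split into the sub-sequences ['Hello'] and ['world'], each converted
+    # via convert_tokens_to_string, while '<added>' itself is stitched in
+    # verbatim between them (space-joined when
+    # spaces_between_special_tokens is True).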
+
+    # Based on
+    # https://github.com/vllm-project/vllm/blob/v0.2.7/vllm/transformers_utils/tokenizer.py#L105-L165
+    def detokenize_incrementally(self,
+                                 all_input_ids: Sequence[int],
+                                 state: DetokenizeState,
+                                 skip_special_tokens: bool = True,
+                                 spaces_between_special_tokens: bool = True):
+        """Incrementally detokenize the input indexes.
+
+        Args:
+            all_input_ids (List[int]): the full list of token ids received
+                so far; it grows across calls as new sections arrive.
+            state (DetokenizeState): an instance of DetokenizeState holding
+                the incremental decoding states.
+            skip_special_tokens (bool): Whether or not to remove special
+                tokens in the decoding. Defaults to True.
+            spaces_between_special_tokens (bool): Whether or not to add spaces
+                between special tokens. Defaults to True.
+        Returns:
+            str: decoded output string of the current round.
+            state (DetokenizeState): an instance of DetokenizeState holding
+                the updated incremental decoding states.
+        """
+        tokenizer = self.model
+        ids_offset, prev_tokens, prefix_offset, read_offset = state.as_tuple()
+        # convert the ids that arrived since the last call into tokens
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            all_input_ids[ids_offset:],
+            skip_special_tokens=skip_special_tokens)
+        if prev_tokens is None:
+            # This is the first iteration for this sequence. Note that vLLM
+            # detokenizes ids one by one, whereas in LMDeploy the number of
+            # new ids per step can vary.
+            if skip_special_tokens and new_tokens and new_tokens[
+                    0] in tokenizer.all_special_ids:
+                read_offset = 1  # skip special token
+            output_tokens = new_tokens
+            prev_tokens = new_tokens
+        else:
+            # extend the running token history with the new tokens
+            output_tokens = prev_tokens + new_tokens
+            prev_tokens += new_tokens
+
+        prefix_text = self._convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+        new_text = self._convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+
+        # update state and get final decoded output
+        if len(new_text) > len(prefix_text) and not new_text.endswith('�'):
+            # A trailing '�' indicates a potentially unfinished byte sequence
+            # from byte-fallback tokenization, so its text is withheld until
+            # the sequence completes; a '�' in the middle is most likely a
+            # genuinely invalid id produced by the model.
+            prefix_offset = read_offset
+            read_offset = len(output_tokens)
+            new_text = new_text[len(prefix_text):]
+        else:
+            new_text = ''
+
+        return new_text, DetokenizeState(len(all_input_ids), prev_tokens,
+                                         prefix_offset, read_offset)
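+
+    # How the offsets move (illustrative): output_tokens[prefix_offset:
+    # read_offset] is the window whose text was already emitted, and
+    # output_tokens[read_offset:] is still pending. Each round decodes both
+    # windows and returns only the part of the new text extending beyond the
+    # previously emitted prefix. If the combined text still ends with '�' (or
+    # adds nothing new), nothing is returned and the offsets stay put, so the
+    # pending tokens are retried together with the ids of the next round.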
+
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
 
@@ -365,6 +527,33 @@ def decode(self, t: Sequence[int], offset: Optional[int] = None):
         """
         return self.model.decode(t, offset)
 
+    def detokenize_incrementally(self,
+                                 all_input_ids: Sequence[int],
+                                 state: DetokenizeState,
+                                 skip_special_tokens: bool = True,
+                                 spaces_between_special_tokens: bool = True):
+        """Incrementally detokenize the input indexes.
+
+        Args:
+            all_input_ids (List[int]): the full list of token ids received
+                so far; it grows across calls as new sections arrive.
+            state (DetokenizeState): an instance of DetokenizeState holding
+                the incremental decoding states.
+            skip_special_tokens (bool): Whether or not to remove special
+                tokens in the decoding. Defaults to True.
+            spaces_between_special_tokens (bool): Whether or not to add spaces
+                between special tokens. Defaults to True.
+        Returns:
+            str: decoded output string of the current round.
+            state (DetokenizeState): an instance of DetokenizeState holding
+                the updated incremental decoding states.
+        """
+        return self.model.detokenize_incrementally(
+            all_input_ids,
+            state=state,
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens)
+
     def __call__(self, s: Union[str, Sequence[str]]):
         """Tokenize prompts.
 
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index 694a3cbc31..e6da93b11e 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -4,6 +4,7 @@
 import random
 
 from lmdeploy.model import ChatTemplateConfig
+from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.turbomind.utils import get_gen_param
 
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
@@ -113,7 +114,7 @@ def main(model_path: str,
                                       step, request_output_len, **kwargs)
 
             print(f'{prompt} ', end='', flush=True)
-            response_size = 0
+            state = DetokenizeState()
             for outputs in generator.stream_infer(
                     session_id=session_id,
                     input_ids=[input_ids],
@@ -123,15 +124,10 @@ def main(model_path: str,
                     random_seed=seed if nth_round == 1 else None):
                 _, res, tokens = outputs
                 # decode res
-                response = tokenizer.decode(res, offset=response_size)
-                # utf-8 char at the end means it's a potential unfinished
-                # byte sequence, continue to concate it with the next
-                # sequence and decode them together
-                if response.endswith('�'):
-                    continue
+                response, state = tokenizer.detokenize_incrementally(
+                    res, state=state)
                 response = valid_str(response)
                 print(f'{response}', end='', flush=True)
-                response_size = tokens
 
             # update step
             step += len(input_ids) + tokens
diff --git a/tests/test_lmdeploy/test_model.py b/tests/test_lmdeploy/test_model.py
index 93acfd17a4..c899b4c9d3 100644
--- a/tests/test_lmdeploy/test_model.py
+++ b/tests/test_lmdeploy/test_model.py
@@ -17,7 +17,7 @@
      ('01-ai/Yi-34B-Chat', ['yi-chat', 'yi-34b', 'yi-200k']),
      ('01-ai/Yi-6B-Chat', ['yi', 'yi-chat']),
      ('WizardLM/WizardLM-70B-V1.0', ['wizardlm']),
-     ('CodeLlama-34b-Instruct-hf', ['codellama']),
+     ('codellama/CodeLlama-34b-Instruct-hf', ['codellama']),
      ('tiiuae/falcon-7b', ['falcon']), ('workspace', [None])])
 @pytest.mark.parametrize('suffix', ['', '-w4', '-4bit', '-16bit'])
 def test_best_match_model(model_path_and_name, suffix):
diff --git a/tests/test_lmdeploy/test_tokenizer.py b/tests/test_lmdeploy/test_tokenizer.py
index d5638193d1..d887334564 100644
--- a/tests/test_lmdeploy/test_tokenizer.py
+++ b/tests/test_lmdeploy/test_tokenizer.py
@@ -1,26 +1,31 @@
 import pytest
 
-from lmdeploy.tokenizer import HuggingFaceTokenizer
+from lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer
 
 
 @pytest.mark.parametrize('model_path', [
     'internlm/internlm-chat-7b', 'Qwen/Qwen-7B-Chat',
-    'baichuan-inc/Baichuan2-7B-Chat', 'codellama/CodeLlama-7b-hf',
-    'upstage/SOLAR-0-70b-16bit'
+    'baichuan-inc/Baichuan2-7B-Chat', 'upstage/SOLAR-0-70b-16bit',
+    'baichuan-inc/Baichuan-7B', 'codellama/CodeLlama-7b-hf',
+    'THUDM/chatglm2-6b', '01-ai/Yi-6B-200k', '01-ai/Yi-34B-Chat',
+    '01-ai/Yi-6B-Chat', 'WizardLM/WizardLM-70B-V1.0',
+    'codellama/CodeLlama-34b-Instruct-hf', 'tiiuae/falcon-7b'
+])
+@pytest.mark.parametrize('input', [
+    'hi, this is a test 😆😆! ' * 5, '為什麼我還在用繁體字 😆😆 gg! ' * 5,
+    ' License at\n#\n#' + ' ' * 100 + 'ht', '   '
 ])
-@pytest.mark.parametrize(
-    'input', ['hi, this is a test 😆😆! ' * 5, '為什麼我還在用繁體字 😆😆 gg! ' * 5])
-def test_tokenizer(model_path, input):
+@pytest.mark.parametrize('interval', [1, 3])
+@pytest.mark.parametrize('skip_special_tokens', [True, False])
+def test_tokenizer(model_path, input, interval, skip_special_tokens):
     tokenizer = HuggingFaceTokenizer(model_path)
-    encoded = tokenizer.encode(input, False)
+    encoded = tokenizer.encode(input, False, add_special_tokens=False)
     output = ''
-    offset = 0
-    for i in range(1, len(encoded) + 1):
-        decoded = tokenizer.decode(encoded[:i], offset)
-        if decoded.endswith('�'):
-            continue
+    state = DetokenizeState()
+    for i in range(0, len(encoded), interval):
+        decoded, state = tokenizer.detokenize_incrementally(
+            encoded[:i + interval], state, skip_special_tokens)
         output += decoded
-        offset = i
     assert input == output, 'input string should equal to output after enc-dec'