Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan committed Nov 19, 2024
1 parent 129d77f commit 66cb69b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
21 changes: 11 additions & 10 deletions lmdeploy/serve/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def __call__(self,
**kwargs)

async def handle_exception(self, session_id: int):
await self.metrics.failure_frame()
self.metrics.failure_frame()
await self.stop_session(session_id)

async def stop_session(self, session_id: int):
Expand Down Expand Up @@ -548,9 +548,9 @@ async def get_inputs_genconfig(gen_config):
tools=tools)
return prompt_input, gen_config

arrival_frame = await self.metrics.insert_frame()
arrival_frame = self.metrics.insert_frame()
prompt_input, gen_config = await get_inputs_genconfig(gen_config)
await self.metrics.update_preprocess(arrival_frame)
self.metrics.update_preprocess(arrival_frame)
prompt = prompt_input['prompt']
input_ids = prompt_input['input_ids']
finish_reason = None
Expand Down Expand Up @@ -586,9 +586,9 @@ async def get_inputs_genconfig(gen_config):
if sequence_end is True and sequence_start is False:
await self.end_session(session_id)
else:
start_frame = await self.metrics.insert_frame()
start_frame = self.metrics.insert_frame()
generator = await self.get_generator(False, session_id)
await self.metrics.update_queue_waiting(start_frame)
self.metrics.update_queue_waiting(start_frame)
iterator = generator.async_stream_infer(
session_id=session_id,
**prompt_input,
Expand All @@ -605,9 +605,8 @@ async def get_inputs_genconfig(gen_config):
start_ids_offset = state.ids_offset
response = ''
async for outputs in iterator:
start_frame = await self.metrics.insert_frame()
if state.prev_tokens is None:
await self.metrics.update_FTL(arrival_frame)
start_frame = self.metrics.insert_frame()
is_first_token = state.prev_tokens is None
# decode res
if is_error(outputs.status):
tokens = 0
Expand All @@ -627,7 +626,9 @@ async def get_inputs_genconfig(gen_config):
if outputs.logprobs:
log_offset = ids_offset - start_ids_offset
logprobs = outputs.logprobs[log_offset:]
await self.metrics.update_postprocess(start_frame)
self.metrics.update_postprocess(start_frame)
if is_first_token:
self.metrics.update_FTL(arrival_frame)
# response, history token len,
# input token len, gen token len
yield GenOut(response, self.id2step[str(session_id)],
Expand Down Expand Up @@ -659,7 +660,7 @@ async def get_inputs_genconfig(gen_config):
# TODO modify pytorch or turbomind api
if self.backend == 'pytorch' and sequence_end:
await self.end_session(session_id)
await self.metrics.last_token_frame(iterator)
self.metrics.last_token_frame(iterator)

def parse_tool_response(self, text, tools, **kwargs):
"""Parse model response containing tool information.
Expand Down
14 changes: 7 additions & 7 deletions lmdeploy/serve/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,42 +188,42 @@ def info(self, backend_config: object) -> None:
}
self.info_backend_config.info(config_dict)

async def failure_frame(self):
def failure_frame(self):
"""log the failaure frame."""
if self.applied:
self.stats.request_failure += 1
self.stats.request_total += 1

async def last_token_frame(self, iterator):
def last_token_frame(self, iterator):
"""log the last token frame."""
if self.applied:
self.stats.duration_infer += iterator.get_duration()
self.stats.request_success += 1
self.stats.request_total += 1
self.log()

async def insert_frame(self):
def insert_frame(self):
"""Insert a frame."""
if self.applied:
return time.time()
return None

async def update_postprocess(self, start_frame):
def update_postprocess(self, start_frame):
"""Update postprocess duration."""
if self.applied:
self.stats.duration_postprocess += time.time() - start_frame

async def update_preprocess(self, start_frame):
def update_preprocess(self, start_frame):
"""Update preprocess duration."""
if self.applied:
self.stats.duration_preprocess += time.time() - start_frame

async def update_queue_waiting(self, start_frame):
def update_queue_waiting(self, start_frame):
"""Update queue waiting time."""
if self.applied:
self.stats.duration_queue += time.time() - start_frame

async def update_FTL(self, start_frame):
def update_FTL(self, start_frame):
"""Update first token latency."""
if self.applied:
self.stats.first_token_latency += time.time() - start_frame
Expand Down

0 comments on commit 66cb69b

Please sign in to comment.