diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 84625a0098..153d82d47c 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -673,13 +673,13 @@ async def __long_context_forward(inputs): if token_count == 0 and slen > max_prefill_token_num: tmp_out = await __long_context_single_forward(inputs, idx) logits_gather.gather(tmp_out) - del tmp_out + tmp_out.pop('logits', None) idx += 1 elif token_count + slen > max_prefill_token_num: tmp_out = await __long_context_batched_forward( inputs, indices[0], idx) logits_gather.gather(tmp_out) - del tmp_out + tmp_out.pop('logits', None) indices = [] token_count = 0 else: