From 299d5224700d043ccf8cbbd86805f7bcbecf5480 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Mon, 18 Mar 2024 14:56:27 +0800
Subject: [PATCH] fix tp lora adapter in pytorch engine (#1300)

---
 lmdeploy/pytorch/models/peft.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py
index cbda491809..54163f5537 100644
--- a/lmdeploy/pytorch/models/peft.py
+++ b/lmdeploy/pytorch/models/peft.py
@@ -198,8 +198,8 @@ def __gather_xa(xa):
         if len(lora_input.ranks) > 1:
             gathered_xa = rearange_all_gather(
                 gathered_xa,
-                q_start_loc=lora_input.q_start_loc,
-                q_seqlens=lora_input.q_seqlens,
+                b_start_loc=lora_input.q_start_loc,
+                b_seq_lens=lora_input.q_seqlens,
                 adapter_ids=lora_input.adapter_ids,
                 ranks=lora_input.ranks,
                 world_size=world_size,
@@ -230,8 +230,8 @@ def __gather_xa(xa):
         if len(lora_input.ranks) > 1:
             gathered_xa = rearange_all_gather(
                 gathered_xa,
-                q_start_loc=lora_input.q_start_loc,
-                q_seqlens=lora_input.q_seqlens,
+                b_start_loc=lora_input.q_start_loc,
+                b_seq_lens=lora_input.q_seqlens,
                 adapter_ids=lora_input.adapter_ids,
                 ranks=lora_input.ranks,
                 world_size=world_size,
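
Note on the failure mode this patch addresses (a hedged reconstruction, not text from the upstream patch): both call sites in the tensor-parallel LoRA path passed q_start_loc=/q_seqlens= to rearange_all_gather, while the post-patch lines show the kernel's keywords are b_start_loc/b_seq_lens, so a TP run taking this branch would raise a TypeError before the gathered activations could be rearranged. Only the keyword names change; the values still come from lora_input.q_start_loc and lora_input.q_seqlens. Below is a minimal, runnable sketch of the mismatch; the stub signature is an assumption that keeps only the two keywords visible in the diff and omits the real kernel's other arguments.

# Minimal sketch of the keyword mismatch fixed above. The stub mirrors
# only the two parameter names implied by the diff (b_start_loc,
# b_seq_lens); the remaining arguments of the real lmdeploy kernel
# (adapter_ids, ranks, world_size, ...) are omitted for brevity.


def rearange_all_gather(gathered_xa, b_start_loc, b_seq_lens):
    """Stand-in with the keyword names the call sites must use."""
    return gathered_xa


xa = [0.0, 0.0, 0.0, 0.0]  # stands in for the gathered activations

# Pre-patch call: q_* keywords are not in the signature -> TypeError
# ("unexpected keyword argument 'q_start_loc'").
try:
    rearange_all_gather(xa, q_start_loc=0, q_seqlens=[4])
except TypeError as err:
    print('pre-patch call fails:', err)

# Post-patch call: the same values under the keywords the kernel expects.
rearange_all_gather(xa, b_start_loc=0, b_seq_lens=[4])
print('post-patch call succeeds')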