WIP: fix tensor parallel issue
tgaddair committed Mar 13, 2024
1 parent e637d89 commit 50395a3
Showing 4 changed files with 24 additions and 7 deletions.
server/lorax_server/models/model.py (6 changes: 4 additions & 2 deletions)
@@ -261,11 +261,13 @@ def load_batched_adapter_weights(
         lora_a_list = [pad_rank(w, dim=1, world_size=self.world_size) for w in lora_a_list]
         lora_b_list = [pad_rank(w, dim=0, world_size=self.world_size) for w in lora_b_list]

-        if lora_b_list:
+        if lora_a_list:
             # update rank if it was padded
-            padded_rank = lora_b_list[0].size(0)
+            print("PADDED RANK", lora_a_list[0].shape, lora_b_list[0].shape)
+            padded_rank = lora_a_list[0].size(1)
             adapter_config.r = padded_rank

+        print("ADAPTER CONFIG", adapter_config.r)
         q_lora_merged = MergedLoraWeights(
             *self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), adapter_config,
         )
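For reference, the hunk now reads the padded rank back from lora_a's second dimension (which should agree with lora_b's first once padding is applied). A minimal sketch of padding the rank to a multiple of world_size; pad_rank_sketch is a hypothetical stand-in, not the repository's pad_rank:

import torch

def pad_rank_sketch(w: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    # Hypothetical stand-in: zero-pad the LoRA rank dimension so it divides
    # evenly across tensor-parallel shards.
    r = w.size(dim)
    padded_r = ((r + world_size - 1) // world_size) * world_size
    if padded_r == r:
        return w
    pad_shape = list(w.shape)
    pad_shape[dim] = padded_r - r
    pad = torch.zeros(pad_shape, dtype=w.dtype, device=w.device)
    return torch.cat([w, pad], dim=dim)

# lora_a: [hidden_size, r], lora_b: [r, hidden_size] (illustrative shapes)
lora_a = pad_rank_sketch(torch.randn(4096, 6), dim=1, world_size=4)  # -> [4096, 8]
lora_b = pad_rank_sketch(torch.randn(6, 4096), dim=0, world_size=4)  # -> [8, 4096]
padded_rank = lora_a.size(1)  # 8, and lora_b.size(0) agrees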
server/lorax_server/server.py (6 changes: 6 additions & 0 deletions)
@@ -128,6 +128,8 @@ async def Decode(self, request: generate_pb2.DecodeRequest, context):
         )

     async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context):
+        import time
+        t0 = time.time()
         adapter_parameters = request.adapter_parameters
         if is_base_model(adapter_parameters):
             logger.info("No adapter to download for base model. Skipping.")
@@ -192,12 +194,15 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context):
                 f"(no reservation limit)")
             adapter_memory_fraction = 0.0

+        print("!!! DownloadAdapter took", time.time() - t0, "seconds")
         return generate_pb2.DownloadAdapterResponse(
             downloaded=True,
             memory_fraction=adapter_memory_fraction
         )

     async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context):
+        import time
+        t0 = time.time()
         adapter_parameters = request.adapter_parameters
         if is_base_model(adapter_parameters):
             logger.info("No adapter to load for base model. Skipping.")
@@ -217,6 +222,7 @@ async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context):

             self.model.load_adapter(adapter_parameters, adapter_source, adapter_index, api_token)

+            print("!!! LoadAdapter took", time.time() - t0, "seconds")
             return generate_pb2.LoadAdapterResponse(loaded=True)
         except Exception:
             logger.exception("Error when loading adapter")
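The import time / t0 = time.time() / print lines above bracket each RPC by hand to measure adapter download and load latency. A context manager captures the same pattern in one place; log_rpc_duration below is a hypothetical helper, not something in the repository:

import time
from contextlib import contextmanager

@contextmanager
def log_rpc_duration(name: str):
    # Time the wrapped block and report it, mirroring the WIP prints above.
    t0 = time.time()
    try:
        yield
    finally:
        print(f"!!! {name} took {time.time() - t0:.3f} seconds")

# Hypothetical usage inside a handler:
# async def DownloadAdapter(self, request, context):
#     with log_rpc_duration("DownloadAdapter"):
#         ...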
server/lorax_server/utils/layers.py (10 changes: 7 additions & 3 deletions)
@@ -507,6 +507,7 @@ def __init__(self, base_layer, layer_id, process_group):
         self.base_layer = base_layer
         self.layer_id = layer_id
         self.process_group = process_group
+        self.use_sgmv = False

     def forward_layer_type(
         self,
@@ -520,6 +521,8 @@ def forward_layer_type(
         data = adapter_data.data.get(layer_type)

         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            if not self.use_sgmv:
+                print("!!! USE SGMV")
             if end_idx - start_idx != result.shape[1]:
                 proj = torch.zeros_like(result[:, start_idx:end_idx])
             else:
@@ -536,7 +539,7 @@
                     rank_segments.segment_starts,
                     rank_segments.segment_ends,
                     self.layer_id,
-                    r // self.process_group.size(),
+                    r,
                 )

             if self.process_group.size() > 1:
@@ -571,13 +574,14 @@ def forward_lora(
         adapter_mask: torch.Tensor,
     ) -> torch.Tensor:
         lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
-        lora_a = orient_for_rank(lora_a, data.adapter_index_configs[adapter_index].r)
+        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

+        lora_a = orient_for_rank(lora_a, lora_b.size(0))
         a_out = input @ lora_a

         if self.process_group.size() > 1:
             a_out = self.collect_lora_a(a_out)

-        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
         result = (a_out @ lora_b) * adapter_mask
         return result
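Shape-wise, the loop path above is the usual LoRA update input @ lora_a @ lora_b, except the rank passed to orient_for_rank now comes from lora_b.size(0), i.e. the padded rank, rather than the adapter config's original r. A rough sketch of that data flow; orient_for_rank_sketch is a simplified stand-in for the real helper:

import torch

def orient_for_rank_sketch(w: torch.Tensor, rank: int) -> torch.Tensor:
    # Simplified stand-in: ensure lora_a's trailing dim is the rank so that
    # `input @ lora_a` yields a [batch, rank] activation.
    return w if w.size(-1) == rank else w.transpose(0, 1)

batch, hidden, r = 2, 4096, 8
x = torch.randn(batch, hidden)
lora_a = torch.randn(r, hidden)   # orientation may vary depending on the source checkpoint
lora_b = torch.randn(r, hidden)

lora_a = orient_for_rank_sketch(lora_a, lora_b.size(0))  # rank taken from lora_b, as in the hunk
a_out = x @ lora_a                # [batch, r]
# under tensor parallelism, a_out is gathered/reduced across ranks at this point
result = a_out @ lora_b           # [batch, hidden]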

server/lorax_server/utils/lora.py (9 changes: 7 additions & 2 deletions)
@@ -46,6 +46,7 @@ def has_adapter(self, adapter_index: int) -> bool:
         return adapter_index in self.adapter_index_configs

     def can_vectorize(self, pg: ProcessGroup) -> bool:
+        # print("Checking if we can vectorize", [(rank_data.rank, pg.size()) for rank_data in self.rank_data.values()])
         return all(
             rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
             for rank_data in self.rank_data.values()
@@ -67,6 +68,7 @@ class AdapterBatchData:

     @staticmethod
     def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, "BatchedLoraWeights"]) -> "AdapterBatchData":
+        print("!!! FROM META")
         data = {}
         for k, v in weights.items():
             if v.is_empty():
@@ -98,14 +100,15 @@ def __init__(
     ):
         # [num_layers, hidden_size, r]
         weights_a = [
-            orient_for_rank(w, adapter_config.r).contiguous()
+            orient_for_rank(w, w.size(1)).contiguous()
             for w in weights_a
         ]
         self.weights_a = torch.stack(weights_a)

         # [num_layers, r, hidden_size]
         self.weights_b = torch.stack(weights_b)

+        print("MERGED SHAPE", self.weights_a.shape, self.weights_a.shape)
         self.adapter_config = adapter_config
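For context, the per-layer A and B matrices are stacked into single tensors, and the rank is now taken from the weights themselves (w.size(1) for each per-layer A, and the stacked B's dim 1) rather than from adapter_config.r. A small shape sketch with illustrative values:

import torch

num_layers, hidden, r = 32, 4096, 8
# One (A, B) pair per layer, already padded and oriented as above (illustrative shapes).
weights_a = [torch.randn(hidden, r) for _ in range(num_layers)]
weights_b = [torch.randn(r, hidden) for _ in range(num_layers)]

stacked_a = torch.stack(weights_a)   # [num_layers, hidden_size, r]
stacked_b = torch.stack(weights_b)   # [num_layers, r, hidden_size]
rank = stacked_b.size(1)             # 8, the (possibly padded) rank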


@@ -184,11 +187,13 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData:
         for segment_idx, adapter_idx in enumerate(segment_indices):
             if adapter_idx not in self.lora_weights:
                 continue
-            rank_indices[self.lora_weights[adapter_idx].weights_a.size(2)].append(segment_idx)
+            print("rank indices for segment", self.lora_weights[adapter_idx].weights_b.shape)
+            rank_indices[self.lora_weights[adapter_idx].weights_b.size(1)].append(segment_idx)

         rank_data = {}
         for rank, indices in rank_indices.items():
             lora_a_ptr_indices = lora_a_ptr[indices]
+            print("Rank", rank, "Indices", lora_a_ptr_indices.shape)
             tmp_shrink, tmp_expand = get_tmp_tensors(
                 lora_a_ptr_indices.size(0),
                 rank,
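The change above keys the rank buckets off weights_b.size(1), the rank dimension of the stacked B matrices, instead of weights_a.size(2). A minimal sketch of that bucketing pattern with illustrative values (adapter_ranks and segment_indices are made up for the example):

from collections import defaultdict

# adapter_idx -> padded rank of that adapter's weights (illustrative)
adapter_ranks = {0: 8, 1: 16, 2: 8}
# adapter index assigned to each batch segment (illustrative)
segment_indices = [0, 1, 0, 2]

# Group segment indices by rank, mirroring how get_data buckets segments
# before sizing the SGMV temporary tensors per rank.
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
    rank_indices[adapter_ranks[adapter_idx]].append(segment_idx)

assert dict(rank_indices) == {8: [0, 2, 3], 16: [1]}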