
Commit

Fixed
tgaddair committed Mar 13, 2024
1 parent 50395a3 commit 41db5b0
Showing 2 changed files with 27 additions and 4 deletions.
25 changes: 23 additions & 2 deletions server/lorax_server/utils/layers.py
@@ -507,7 +507,6 @@ def __init__(self, base_layer, layer_id, process_group):
         self.base_layer = base_layer
         self.layer_id = layer_id
         self.process_group = process_group
-        self.use_sgmv = False
 
     def forward_layer_type(
         self,
@@ -521,7 +520,7 @@ def forward_layer_type(
         data = adapter_data.data.get(layer_type)
 
         if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
-            if not self.use_sgmv:
+            if self.process_group.rank() == 0 and self.layer_id == 0:
                 print("!!! USE SGMV")
             if end_idx - start_idx != result.shape[1]:
                 proj = torch.zeros_like(result[:, start_idx:end_idx])
@@ -541,9 +540,14 @@
                         self.layer_id,
                         r,
                     )
+                    if self.process_group.rank() == 0 and self.layer_id == 0:
+                        print("V", v.shape, v.norm().item())
 
                     if self.process_group.size() > 1:
                         v = self.collect_lora_a(v)
+
+                    if self.process_group.rank() == 0 and self.layer_id == 0:
+                        print("V collect", v.shape, v.norm().item())
 
                     lora_b_sgmv_cutlass(
                         proj,
@@ -554,6 +558,9 @@
                         rank_segments.segment_ends,
                         self.layer_id,
                     )
+
+                    if self.process_group.rank() == 0 and self.layer_id == 0:
+                        print("proj", proj.shape, proj.norm().item())
 
             if end_idx - start_idx != result.shape[1]:
                 result[:, start_idx:end_idx] += proj
@@ -577,12 +584,26 @@ def forward_lora(
         lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
 
         lora_a = orient_for_rank(lora_a, lora_b.size(0))
 
+        if self.process_group.rank() == 0 and self.layer_id == 0:
+            print("lora_a", lora_a.shape, lora_a.norm().item())
+            print("lora_b", lora_b.shape, lora_b.norm().item())
+
         a_out = input @ lora_a
+        if self.process_group.rank() == 0 and self.layer_id == 0:
+            print("V", a_out.shape, a_out.norm().item())
+
         if self.process_group.size() > 1:
             a_out = self.collect_lora_a(a_out)
 
+        if self.process_group.rank() == 0 and self.layer_id == 0:
+            print("V collect", a_out.shape, a_out.norm().item())
+
         result = (a_out @ lora_b) * adapter_mask
+
+        if self.process_group.rank() == 0 and self.layer_id == 0:
+            print("proj", result.shape, result.norm().item())
+
         return result
 
     def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
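For context, the forward_lora path above applies a single adapter: a rank-r down-projection (input @ lora_a), an optional cross-rank gather of the intermediate, then an up-projection (@ lora_b) scaled by a per-token adapter mask, with the new prints logging the norm of each stage on rank 0 / layer 0. Below is a minimal single-process sketch of that computation and the debug logging; names and shapes are illustrative, not the LoraLinear class itself, and the distributed collect_lora_a step is omitted.

import torch

def lora_forward(x: torch.Tensor,            # [batch, hidden_in]
                 lora_a: torch.Tensor,       # [hidden_in, r]
                 lora_b: torch.Tensor,       # [r, hidden_out]
                 adapter_mask: torch.Tensor, # [batch, 1], 1.0 where the adapter applies
                 debug: bool = False) -> torch.Tensor:
    a_out = x @ lora_a                       # "V": down-projection onto the rank-r subspace
    if debug:
        print("V", a_out.shape, a_out.norm().item())
    result = (a_out @ lora_b) * adapter_mask # "proj": up-projection, masked per request
    if debug:
        print("proj", result.shape, result.norm().item())
    return result

batch, hidden, r = 4, 64, 8
x = torch.randn(batch, hidden)
A = torch.randn(hidden, r) * 0.01
B = torch.randn(r, hidden) * 0.01   # in training, B starts at zero; nonzero here so the norms are visible
mask = torch.ones(batch, 1)
print(lora_forward(x, A, B, mask, debug=True).shape)  # torch.Size([4, 64])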
6 changes: 4 additions & 2 deletions server/lorax_server/utils/lora.py
@@ -98,6 +98,9 @@ def __init__(
         weights_b: List[torch.Tensor],
         adapter_config: LoraConfig,
     ):
+        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
+        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
+
         # [num_layers, hidden_size, r]
         weights_a = [
             orient_for_rank(w, w.size(1)).contiguous()
@@ -187,8 +190,7 @@ def get_data(self, meta: AdapterBatchMetadata) -> AdapterWeightData:
         for segment_idx, adapter_idx in enumerate(segment_indices):
             if adapter_idx not in self.lora_weights:
                 continue
-            print("rank indices for segment", self.lora_weights[adapter_idx].weights_b.shape)
-            rank_indices[self.lora_weights[adapter_idx].weights_b.size(1)].append(segment_idx)
+            rank_indices[self.lora_weights[adapter_idx].lora_a_r].append(segment_idx)
 
         rank_data = {}
         for rank, indices in rank_indices.items():
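The lora.py change records each adapter's LoRA rank at load time (lora_a_r / lora_b_r) and uses it to bucket batch segments, rather than re-deriving the rank from the stacked weights_b tensor. A standalone sketch of that bucketing idea, with hypothetical helper and argument names rather than the lorax_server API:

from collections import defaultdict
from typing import Dict, List

def group_segments_by_rank(segment_indices: List[int],
                           adapter_rank: Dict[int, int]) -> Dict[int, List[int]]:
    # Bucket segment positions by the LoRA rank of the adapter each segment uses,
    # so one SGMV launch can cover all segments sharing a rank.
    rank_indices: Dict[int, List[int]] = defaultdict(list)
    for segment_idx, adapter_idx in enumerate(segment_indices):
        if adapter_idx not in adapter_rank:  # adapter not loaded; skip its segment
            continue
        rank_indices[adapter_rank[adapter_idx]].append(segment_idx)
    return dict(rank_indices)

# Two adapters with rank 8 and one with rank 16 -> two buckets.
print(group_segments_by_rank([0, 1, 0, 2], {0: 8, 1: 8, 2: 16}))
# {8: [0, 1, 2], 16: [3]}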

