Skip to content

Commit

Permalink
Pad segments to power of 2
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Jan 3, 2024
1 parent 262af45 commit c65a6fa
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions server/lorax_server/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def pad_and_fill(dest: torch.Tensor, src: torch.Tensor, pad_value: int):
dest[src.shape[0]:].fill_(pad_value)


def next_pow_2(x: int) -> int:
assert x > 0
return 1 << (x-1).bit_length()


@dataclass
class GraphState:
input_ids: torch.Tensor
Expand Down Expand Up @@ -145,9 +150,16 @@ def trace(
) -> "GraphWrapper":
max_input_state = get_max_graph_state(device, adapter_layers)

# WARNING: for some reason the SGMV kernel can hang if we don't use a power of 2
# as the segment size. This is a workaround until we can figure out why.
# Specifically, this issue has been observed with batch_size=96.
# I suspect it is related to synchronization and the chunk size (256) used in the kernel.
# But we need to investigate further.
segment_size = next_pow_2(batch_size)

adapter_weight_data = {}
for layer_name, weight_data in max_input_state.adapter_data.data.items():
tmp_expand_size = get_tmp_expand_size(batch_size)
tmp_expand_size = get_tmp_expand_size(segment_size)

tmp_shrink = weight_data.rank_data[MAX_RANK].tmp_shrink
if use_cutlass_shrink(max_rank):
Expand All @@ -163,10 +175,10 @@ def trace(
rank=max_rank,
tmp_shrink=tmp_shrink,
tmp_expand=weight_data.rank_data[MAX_RANK].tmp_expand[:tmp_expand_size],
lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:batch_size],
lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:batch_size],
segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[:batch_size],
segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[:batch_size],
lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size],
lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size],
segment_starts=weight_data.rank_data[MAX_RANK].segment_starts[:segment_size],
segment_ends=weight_data.rank_data[MAX_RANK].segment_ends[:segment_size],
),
} if max_rank > 0 else {},
)
Expand Down

0 comments on commit c65a6fa

Please sign in to comment.