Fix Qwen tensor parallelism (#120)
tgaddair authored Dec 11, 2023
1 parent 6ab05be commit b98884f
Showing 6 changed files with 131 additions and 37 deletions.
15 changes: 10 additions & 5 deletions server/lorax_server/models/custom_modeling/flash_qwen_modeling.py
@@ -142,19 +142,23 @@ def forward(self, hidden_states, residual=None):


def load_attention(config, prefix, weights, layer_id):
base_layer = load_attention_multi(config, prefix, weights)
projection_size = config.kv_channels * config.num_attention_heads
projection_size = (config.hidden_size // config.num_attention_heads) * config.num_attention_heads
base_layer = load_attention_multi(config, prefix, weights, projection_size)
return TensorParallelMultiAdapterLinear.load(
base_layer, layer_id, [ATTN_C_ATTN], sizes=[
3 * projection_size,
], process_group=weights.process_group
)


def load_attention_multi(config, prefix, weights):
def load_attention_multi(config, prefix, weights, projection_size):
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.c_attn"],
prefixes=[
(f"{prefix}.c_attn", (0, projection_size)),
(f"{prefix}.c_attn", (projection_size, projection_size)),
(f"{prefix}.c_attn", (2 * projection_size, projection_size)),
],
dim=0,
weights=weights,
bias=True,
@@ -173,7 +177,8 @@ def __init__(
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.projection_size = config.kv_channels * config.num_attention_heads
self.projection_size = (self.head_size * config.num_attention_heads) // weights.process_group.size()
self.process_group = weights.process_group

self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
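
For reference, a minimal standalone sketch of what the new (offset, size) prefixes accomplish, mirroring the range arithmetic added to weights.py below. The hidden size, head count, and world size are illustrative values, not taken from the commit.

import torch

hidden_size = 4096          # assumed example value
num_attention_heads = 32    # assumed example value
world_size = 2              # assumed tensor-parallel degree

projection_size = (hidden_size // num_attention_heads) * num_attention_heads  # 4096 here

# Fused c_attn weight as stored on disk: [3 * projection_size, hidden_size]
c_attn_weight = torch.randn(3 * projection_size, hidden_size)

# The three (offset, size) ranges that load_attention_multi now passes as prefixes
ranges = [
    (0, projection_size),
    (projection_size, projection_size),
    (2 * projection_size, projection_size),
]

for rank in range(world_size):
    shards = []
    for offset, size in ranges:
        block = size // world_size
        start = offset + rank * block
        stop = offset + (rank + 1) * block
        shards.append(c_attn_weight[start:stop])  # shard each of q, k, v along dim 0
    qkv_local = torch.cat(shards, dim=0)
    print(rank, tuple(qkv_local.shape))  # (6144, 4096) on every rank
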
38 changes: 31 additions & 7 deletions server/lorax_server/models/flash_causal_lm.py
@@ -32,6 +32,7 @@
from lorax_server.utils.dist import MEMORY_FRACTION
from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
from lorax_server.utils.weights import shard_on_dim

tracer = trace.get_tracer(__name__)

@@ -755,6 +756,27 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):

self.adapter_id = adapter_id

def shard_lora_weights(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
layer_type: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
split_dim = 0 if self.is_row_parallel(layer_type) else 1
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=self.process_group)
for w in weights_a
]

# [r, hidden_size]
weights_b = [
shard_on_dim(w, dim=1, process_group=self.process_group)
for w in weights_b
]

return weights_a, weights_b

def load_batched_adapter_weights(
self,
module_map: Dict[str, Dict],
@@ -795,7 +817,7 @@ def load_batched_adapter_weights(
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

q_lora_merged = MergedLoraWeights(
lora_a_list, lora_b_list, adapter_config, layer_type, self.process_group, self.is_row_parallel(layer_type),
*self.shard_lora_weights(lora_a_list, lora_b_list, layer_type), adapter_config,
)
q_lora_weights = self.batched_lora_weights[layer_type]
q_lora_weights.add_adapter(adapter_index, q_lora_merged)
@@ -828,7 +850,6 @@ def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch

def warmup(self, batch: FlashCausalLMBatch):

torch.cuda.empty_cache()
try:
cache_manager = set_cache_manager(
@@ -841,11 +862,14 @@ def warmup(self, batch: FlashCausalLMBatch):
self.device,
)
_, batch = self.generate_token(batch)
except Exception as e:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
except RuntimeError as e:
if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
else:
raise

torch.cuda.synchronize(self.device)

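
The following is a rough, self-contained sketch of the mechanics of the new base-class shard_lora_weights: LoRA A is sharded along dim 0 for row-parallel layer types and along dim 1 otherwise, while LoRA B is always sharded along dim 1. The shapes, world size, and the shard_on_dim_local helper are illustrative stand-ins, not the actual lorax_server utilities.

import torch

def shard_on_dim_local(t, dim, rank, world_size):
    # Simplified stand-in for lorax_server.utils.weights.shard_on_dim
    block = t.shape[dim] // world_size
    start, stop = rank * block, (rank + 1) * block
    return t[start:stop] if dim == 0 else t[:, start:stop]

hidden_size, r, world_size = 4096, 16, 2   # illustrative values
lora_a = torch.randn(hidden_size, r)       # [hidden_size, r]
lora_b = torch.randn(r, hidden_size)       # [r, hidden_size]

is_row_parallel = False                    # e.g. a column-parallel projection
split_dim = 0 if is_row_parallel else 1

for rank in range(world_size):
    a_shard = shard_on_dim_local(lora_a, split_dim, rank, world_size)
    b_shard = shard_on_dim_local(lora_b, 1, rank, world_size)
    print(rank, tuple(a_shard.shape), tuple(b_shard.shape))  # (4096, 8) (16, 2048)
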
41 changes: 41 additions & 0 deletions server/lorax_server/models/flash_qwen.py
@@ -24,6 +24,7 @@
)
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID
from lorax_server.utils.lora import LM_HEAD
from lorax_server.utils.weights import shard_on_dim

tracer = trace.get_tracer(__name__)

@@ -95,6 +96,7 @@ def __init__(

self.model_id = model_id
model = FlashQwenForCausalLM(config, weights)
self.config = config

torch.distributed.barrier(group=self.process_group)
super(FlashQwen, self).__init__(
@@ -137,3 +139,42 @@ def get_num_layers_for_type(self, layer_type: str) -> int:

def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

def split_lora_b_qkv(self, t: torch.Tensor, projection_size: int) -> torch.Tensor:
# Because we're splitting on the hidden size dimension, we need to
# account for the separate q, k, and v matrices.
chunks = torch.split(t, projection_size, dim=1)
assert len(chunks) == 3
chunks = [
shard_on_dim(w, dim=1, process_group=self.process_group)
for w in chunks
]
return torch.cat(chunks, dim=1)

def shard_lora_weights(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
layer_type: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# TODO(travis): generalize this for other layers and architectures
if layer_type == ATTN_C_ATTN:
# [hidden_size, r]
split_dim = 0 if self.is_row_parallel(layer_type) else 1
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=self.process_group)
for w in weights_a
]

# [r, hidden_size]
# Because we're splitting on the hidden size dimension, we need to
# account for the separate q, k, and v matrices.
projection_size = (self.config.hidden_size // self.config.num_attention_heads) * self.config.num_attention_heads
weights_b = [
self.split_lora_b_qkv(w, projection_size)
for w in weights_b
]

return weights_a, weights_b
else:
return super().shard_lora_weights(weights_a, weights_b, layer_type)
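
Below is a sketch of the split_lora_b_qkv idea with concrete shapes; r, projection_size, world_size, and the shard_cols helper are assumed for illustration. Because q, k, and v sit side by side in the LoRA B matrix for c_attn, each third is sharded on its own and the local pieces are concatenated back together.

import torch

r, projection_size, world_size = 16, 4096, 2   # illustrative values
lora_b = torch.randn(r, 3 * projection_size)   # q, k, v concatenated along dim 1

def shard_cols(t, rank, world_size):
    # Simplified stand-in for shard_on_dim(..., dim=1, ...)
    block = t.shape[1] // world_size
    return t[:, rank * block:(rank + 1) * block]

for rank in range(world_size):
    chunks = torch.split(lora_b, projection_size, dim=1)  # three [r, projection_size] chunks
    local = torch.cat([shard_cols(c, rank, world_size) for c in chunks], dim=1)
    print(rank, tuple(local.shape))  # (16, 6144), i.e. [r, 3 * projection_size / world_size]
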
6 changes: 3 additions & 3 deletions server/lorax_server/utils/layers.py
@@ -5,7 +5,7 @@

from torch import nn
from torch.nn import functional as F
from typing import List
from typing import List, Tuple, Union

HAS_BITS_AND_BYTES = True
try:
@@ -368,7 +368,7 @@ def load(cls, config, prefix: str, weights, bias: bool, fan_in_fan_out: bool = F
def load_multi(
cls,
config,
prefixes: List[str],
prefixes: List[Union[str, Tuple]],
weights,
bias: bool,
dim: int,
@@ -379,7 +379,7 @@ def load_multi(
)

if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
b = weights.get_sharded_list("bias", prefixes, dim=0)
bias = torch.cat(b, dim=dim)
else:
bias = None
7 changes: 1 addition & 6 deletions server/lorax_server/utils/lora.py
@@ -82,20 +82,15 @@ def __init__(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
layer_type: str,
process_group: ProcessGroup,
is_row_parallel: bool,
):
# [num_layers, hidden_size, r]
split_dim = 0 if is_row_parallel else 1
weights_a = [
orient_for_rank(shard_on_dim(w, dim=split_dim, process_group=process_group), adapter_config.r)
orient_for_rank(w, adapter_config.r)
for w in weights_a
]
self.weights_a = torch.stack(weights_a)

# [num_layers, r, hidden_size]
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
self.weights_b = torch.stack(weights_b)

self.adapter_config = adapter_config
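
For context, a minimal sketch of what the slimmed-down MergedLoraWeights constructor is left to do; the shapes follow the code comments, the values are illustrative, and the orient_for_rank step is omitted. Since sharding now happens in the model's shard_lora_weights before the constructor runs, the constructor only stacks the per-layer tensors.

import torch

num_layers, hidden_size, r = 4, 4096, 16   # illustrative values

# One LoRA pair per layer, already sharded upstream
weights_a = [torch.randn(hidden_size, r) for _ in range(num_layers)]  # [hidden_size, r] each
weights_b = [torch.randn(r, hidden_size) for _ in range(num_layers)]  # [r, hidden_size] each

stacked_a = torch.stack(weights_a)  # [num_layers, hidden_size, r]
stacked_b = torch.stack(weights_b)  # [num_layers, r, hidden_size]
print(tuple(stacked_a.shape), tuple(stacked_b.shape))
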
61 changes: 45 additions & 16 deletions server/lorax_server/utils/weights.py
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from typing import List, Dict, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
@@ -114,15 +114,31 @@ def get_tensor(self, tensor_name: str):
tensor = tensor.to(device=self.device)
return tensor

def get_partial_sharded(self, tensor_name: str, dim: int):
def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None):
"""Loads tensor with the given name and shards it along the given dimension.
The optional range argument can be used to load and split on only a subset of the tensor.
This is useful in cases where the tensor is stored as one contiguous block, but is logically
split into different components that need to be sharded separately. For example, when storing
QKV weights together as a single tensor on disk.
Args:
tensor_name (str): Name of the tensor to load.
dim (int): Dimension to shard along.
range (Optional[Tuple[int, int]]): Range of indices to load and shard as (offset, size).
"""
filename, tensor_name = self.get_filename(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()

f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
size = slice_.get_shape()[dim]
start, stop = get_start_stop_idxs_for_rank(size, rank, world_size)
if range is not None:
offset, size = range
else:
offset = 0
size = slice_.get_shape()[dim]
start, stop = get_start_stop_idxs_for_rank(offset, size, rank, world_size)

if dim == 0:
tensor = slice_[start:stop]
@@ -137,35 +153,48 @@ def get_partial_sharded(self, tensor_name: str, dim: int):
tensor = tensor.to(device=self.device)
return tensor

def get_sharded(self, tensor_name: str, dim: int):
def get_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[int, int]] = None):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
size = slice_.get_shape()[dim]
size = slice_.get_shape()[dim] if range is None else range[1]
assert (
size % world_size == 0
), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
return self.get_partial_sharded(tensor_name, dim, range=range)

def get_sharded_prefix(self, module_name: str, prefix: Union[str, Tuple], dim: int):
if isinstance(prefix, str):
return self.get_sharded(f"{prefix}.{module_name}", dim=dim)
else:
assert isinstance(prefix, tuple)
assert len(prefix) == 2
return self.get_sharded(f"{prefix[0]}.{module_name}", dim=dim, range=prefix[1])

def get_sharded_list(self, module_name: str, prefixes: List[Union[str, Tuple]], dim: int):
return [self.get_sharded_prefix(module_name, p, dim=dim) for p in prefixes]

def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str, dim: int):
if quantize in ["gptq", "awq"]:
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
self.get_sharded_list("qweight", prefixes, dim=1), dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `lorax-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)

qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
self.get_sharded_list("qzeros", prefixes, dim=1), dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
self.get_sharded_list("scales", prefixes, dim=1), dim=1
)
if quantize == "gptq":
# no tensor parallelism, so remove the range if provided
prefixes = [p[0] if isinstance(p, tuple) else p for p in prefixes]
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
@@ -176,7 +205,7 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
w = self.get_sharded_list("weight", prefixes, dim=0)
weight = torch.cat(w, dim=dim)
return weight

@@ -314,10 +343,10 @@ def _set_gptq_params(self, model_id):
except Exception:
pass

def get_start_stop_idxs_for_rank(size, rank, world_size):
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop


@@ -326,7 +355,7 @@ def shard_on_dim(t: torch.Tensor, dim: int, process_group: torch.distributed.Pro
rank = process_group.rank()

size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(size, rank, world_size)
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)

if dim == 0:
tensor = t[start:stop]
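
A worked example of the offset-aware sharding helper introduced above; the projection_size and world_size values are illustrative. Given an (offset, size) range, each rank takes its slice of that sub-block rather than of the whole tensor.

def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    block_size = size // world_size
    start = offset + rank * block_size
    stop = offset + (rank + 1) * block_size
    return start, stop

projection_size = 4096  # assumed value for illustration
world_size = 2

# Shard only the k block of a fused qkv tensor, which starts at offset projection_size
for rank in range(world_size):
    print(rank, get_start_stop_idxs_for_rank(projection_size, projection_size, rank, world_size))
# rank 0 -> (4096, 6144), rank 1 -> (6144, 8192)
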
