 
 import itertools
 from abc import abstractmethod
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
-import torch.nn as nn
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm.distributed import (
@@ -1440,237 +1439,3 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
-
-
-@CustomOp.register("qkv_cross_parallel_linear")
-class QKVCrossParallelLinear(LinearBase):
-    """Linear layers for efficient cross-attention's QKV transformation.
-
-    Args:
-        hidden_size: input hidden state size of the transformer.
-        head_size: size of each attention head.
-        total_num_heads: total number of attention query heads.
-        total_num_kv_heads: total number of attention key/value heads. If
-                            None, assume total_num_kv_heads = total_num_heads.
-        bias: If true, add bias.
-        skip_bias_add: This was added to enable performance optimizations where
-                       bias can be fused with other element-wise operations. we
-                       skip adding bias but instead return it.
-        params_dtype: Data type for the parameters.
-        quant_config: Quantization configure.
-        prefix: The name of the layer in the state dict, including all parents
-                (e.g. model.layers.0.qkv_proj)
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        head_size: int,
-        total_num_heads: int,
-        total_num_kv_heads: Optional[int] = None,
-        bias: bool = True,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        # input_size and output_size are not used, just for alignment
-        input_size = hidden_size
-        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
-        super().__init__(
-            input_size=input_size,
-            output_size=output_size,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            quant_config=quant_config,
-            prefix=prefix,
-        )
-
-        self.quant_config = quant_config
-
-        # Empty placeholders for loading as a single module.
-        placeholder_size = 0
-        assert self.quant_method is not None
-        self.quant_method.create_weights(
-            self,
-            placeholder_size,
-            [placeholder_size],
-            placeholder_size,
-            placeholder_size,
-            self.params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        # Use a dictionary to avoid submodules parameters auto-registration:
-        # drop-in replacement for a `QKVParallelLinear` module.
-        self.proj = dict()
-        self.proj["q_proj_decoder"] = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=total_num_heads * head_size,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.q_proj_decoder",
-        )
-
-        self.proj["kv_proj_encoder"] = QKVParallelLinear(
-            hidden_size=hidden_size,
-            head_size=head_size,
-            total_num_heads=0,
-            total_num_kv_heads=total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.kv_proj_encoder",
-        )
-
-        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
-        self.q_size = self.q_proj_decoder.output_size_per_partition
-        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
-
-        if bias:
-            self.bias = torch.nn.Parameter()
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": 0,
-                    "weight_loader": self.weight_loader_v1,
-                },
-            )
-        else:
-            self.bias = None
-
-    def process_weights_after_loading(self):
-        for layer in self.proj.values():
-            if self.quant_method is not None:
-                self.quant_method.process_weights_after_loading(layer)
-
-    @property
-    def q_proj_decoder(self) -> ColumnParallelLinear:
-        layer = self.proj["q_proj_decoder"]
-        for name, param in self.named_parameters():
-            target_param = getattr(layer, name, None)
-            if target_param is not None:
-                self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
-        return layer
-
-    @property
-    def kv_proj_encoder(self) -> QKVParallelLinear:
-        layer = self.proj["kv_proj_encoder"]
-        for name, param in self.named_parameters():
-            target_param = getattr(layer, name, None)
-            if target_param is not None:
-                self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
-        return layer
-
-    def sync_weight_attrs(
-        self,
-        src_param: nn.Parameter,
-        tgt_param: nn.Parameter,
-        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
-    ):
-        missing_attrs_dict = {
-            k: getattr(src_param, k)
-            for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys()))
-        }
-        # TODO(Isotr0py): handle bitsandbytes 8bit
-        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
-        if missing_attrs_dict and use_bitsandbytes_4bit:
-            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
-                missing_attrs_dict
-            )
-            if mode == "q_proj_decoder":
-                set_weight_attrs(tgt_param, q_proj_attrs)
-            elif mode == "kv_proj_encoder":
-                set_weight_attrs(tgt_param, kv_proj_attrs)
-        else:
-            set_weight_attrs(tgt_param, missing_attrs_dict)
-
-    def _is_same_param(
-        self,
-        src_param: torch.nn.Parameter,
-        map_param: torch.nn.Parameter,
-    ) -> bool:
-        """Check if two parameters are exactly pointing to same things."""
-        # ignore weight_loader because it's always different
-        key_to_ignore = ["weight_loader", "_weight_loader"]
-        has_same_type_name = type(src_param) is type(map_param)
-        src_param_attrs = {
-            k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore
-        }
-        map_param_attrs = {
-            k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore
-        }
-        has_same_attrs = src_param_attrs == map_param_attrs
-        return has_same_type_name and has_same_attrs
-
-    def select_proj_params(
-        self,
-        layer: nn.Module,
-        param: nn.Parameter,
-    ) -> nn.Parameter:
-        """
-        Given the placeholder param,
-        return the corresponding param in the proj layers.
-        """
-        target_param_list = [
-            v for _, v in layer.named_parameters() if self._is_same_param(param, v)
-        ]
-        assert len(target_param_list) == 1
-        target_param = target_param_list[0]
-        return target_param
-
-    def forward(  # type: ignore[override]
-        self,
-        decoder_hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, ...]:
-        q, _ = self.q_proj_decoder(decoder_hidden_states)
-        if encoder_hidden_states is None:
-            # Encoder KV already cached.
-            k = None
-            v = None
-        else:
-            # Prefill phase, encoder KV cached here.
-            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
-            # Split kv in half
-            k, v = kv_enc.split(self.kv_size, dim=-1)
-        return q, k, v
-
-    def weight_loader_v1(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        # just like all other parameters, does not yet
-        # support loading bias with weight_loader_v2
-        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
-        target_param = self.select_proj_params(layer, param)
-        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
-        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
-        target_param = self.select_proj_params(layer, param)
-        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
-        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
-            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
-        else:
-            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
-
-    def extra_repr(self) -> str:
-        s = f"in_features={self.input_size}"
-        s += f", q_size={self.q_size}"
-        s += f", kv_size={self.kv_size}"
-        s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
-        s += ", gather_output=False"
-        return s