57 changes: 57 additions & 0 deletions kt-kernel/README.md
@@ -84,6 +84,63 @@ pip install .
```bash
python -c "from kt_kernel import AMXMoEWrapper; print('✓ kt-kernel installed successfully')"
```

## Weight Quantization

KT-Kernel provides a weight conversion tool that quantizes model weights from FP8/FP16/BF16 to an INT4/INT8 format optimized for AMX inference.

### Quantization Methods

- **INT4**: 4-bit quantization for maximum memory efficiency
- **INT8**: 8-bit quantization for better accuracy
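
Conceptually, both methods store quantized integer weights alongside floating-point scales. The sketch below illustrates symmetric per-group quantization; the group size of 128 and the packing are assumptions for illustration, not the exact AMX layout used by the kernels.

```python
import torch

def quantize_symmetric(w: torch.Tensor, bits: int, group_size: int = 128):
    """Illustrative symmetric per-group quantization; not the kernels' exact AMX packing."""
    qmax = 2 ** (bits - 1) - 1                          # 7 for INT4, 127 for INT8
    groups = w.reshape(-1, group_size)                  # assumes w.numel() is divisible by group_size
    scale = groups.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / qmax
    q = torch.clamp(torch.round(groups / scale), -qmax - 1, qmax).to(torch.int8)
    return q, scale                                     # dequantize with q.float() * scale
```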

### Supported Input Formats

- **FP8**: 8-bit floating point with automatic dequantization
- **FP16**: 16-bit floating point
- **BF16**: BFloat16 format
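
For FP8 inputs, the converter first dequantizes to a higher-precision type before re-quantizing to INT4/INT8. The following is only a rough sketch of such a dequantization step, assuming DeepSeek-style per-128×128-block inverse scales; the actual tensor names, block size, and layout handled by the tool may differ.

```python
import torch

def dequant_fp8_blockwise(w_fp8: torch.Tensor, scale_inv: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Illustrative FP8 -> BF16 dequantization with one scale per (block x block) tile (assumed layout)."""
    w = w_fp8.to(torch.float32)
    # Broadcast each tile's scale over the weights it covers, then trim any padding.
    scales = scale_inv.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return (w * scales[: w.shape[0], : w.shape[1]]).to(torch.bfloat16)
```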

### Basic Usage

```bash
# Quantize BF16 model to INT4
python scripts/convert_weights.py \
--input-path /path/to/bf16/model \
--input-type bf16 \
--output /path/to/output \
--quant-method int4

# Quantize FP16 model to INT8
python scripts/convert_weights.py \
--input-path /path/to/fp16/model \
--input-type fp16 \
--output /path/to/output \
--quant-method int8

# Quantize FP8 model to INT4
python scripts/convert_weights.py \
--input-path /path/to/fp8/model \
--input-type fp8 \
--output /path/to/output \
--quant-method int4
```

### Output Format

The converted weights are saved in SafeTensors format with a NUMA-aware layout:
```
output_dir/
├── model-00001-of-00050.safetensors
├── model-00002-of-00050.safetensors
├── ...
├── config.json
└── tokenizer files...
```

Each expert's weights are split across NUMA nodes for optimal memory access:
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.weight`: Quantized weights
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.scale`: Quantization scales
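
To sanity-check a converted model, the shards can be opened with the `safetensors` library and filtered by the key pattern above. A small sketch (the shard filename is just the example from the tree above):

```python
from safetensors import safe_open

# Example shard from the layout above; the shard count depends on the model.
with safe_open("output_dir/model-00001-of-00050.safetensors", framework="pt") as f:
    numa_keys = [k for k in f.keys() if ".numa." in k]
    for key in sorted(numa_keys)[:8]:
        print(key, f.get_slice(key).get_shape())
```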

## Before Commit!
Your commit message should follow Conventional Commits (https://www.conventionalcommits.org/),<br>and format your code before committing:
```shell
70 changes: 64 additions & 6 deletions kt-kernel/python/experts.py
@@ -191,7 +191,7 @@ def __init__(
moe_intermediate_size: int,
num_gpu_experts: int,
cpuinfer_threads: int,
subpool_count: int,
threadpool_count: int,
amx_weight_path: str,
chunked_prefill_size: int,
cpu_save: bool = False,
@@ -207,7 +207,7 @@ def __init__(
moe_intermediate_size: MoE intermediate size
num_gpu_experts: Number of experts to run on GPU
cpuinfer_threads: Number of CPU inference threads
subpool_count: Number of NUMA subpools
threadpool_count: Number of NUMA subpools
amx_weight_path: Path to AMX weights
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
@@ -227,13 +227,13 @@ def __init__(
if AMXMoEWrapper._cpu_infer_instance is None:
worker_config = cpuinfer_ext.WorkerPoolConfig()

subpool_numa_map = list(range(subpool_count))
subpool_numa_map = list(range(threadpool_count))
subpool_thread_count = [
cpuinfer_threads // subpool_count + (1 if i < cpuinfer_threads % subpool_count else 0)
for i in range(subpool_count)
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
for i in range(threadpool_count)
]

worker_config.subpool_count = subpool_count
worker_config.subpool_count = threadpool_count
worker_config.subpool_numa_map = subpool_numa_map
worker_config.subpool_thread_count = subpool_thread_count
AMXMoEWrapper._cpu_infer_instance = cpuinfer_ext.CPUInfer(worker_config)
@@ -261,6 +261,64 @@ def __init__(
self.up_scales = None
self.down_scales = None

def load_weights_from_tensors(
self,
gate_proj: torch.Tensor,
up_proj: torch.Tensor,
down_proj: torch.Tensor,
physical_to_logical_map_cpu: torch.Tensor,
):
"""
Load and quantize weights from BF16/FP16 tensors (online quantization).

Args:
gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
"""
# Store tensors as instance variables to keep them alive
self.gate_proj = gate_proj.contiguous()
self.up_proj = up_proj.contiguous()
self.down_proj = down_proj.contiguous()

# Configure MoE with online quantization (cpu_save mode)
moe_config = MOEConfig(
self.num_experts,
self.num_experts_per_tok,
self.hidden_size,
self.moe_intermediate_size,
self.num_gpu_experts,
)
moe_config.layer_idx = self.layer_idx
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size

# Enable save mode for online quantization
moe_config.save = True
moe_config.load = False

# Set weight pointers
moe_config.gate_proj = self.gate_proj.data_ptr()
moe_config.up_proj = self.up_proj.data_ptr()
moe_config.down_proj = self.down_proj.data_ptr()

# Set output path for quantized weights
moe_config.path = self.amx_weight_path

# Create MoE module based on AMX method
amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
if amx_method == "AMXINT4":
self.moe = AMXInt4_MOE(moe_config)
elif amx_method == "AMXINT8":
self.moe = AMXInt8_MOE(moe_config)
else:
raise NotImplementedError(f"Unsupported AMX method: {amx_method}")

# Submit quantization and save task
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
Contributor comment on lines +264 to +320 (medium):

The quantization method is determined by reading the AMX_METHOD environment variable inside this function. This makes the function's behavior implicit and dependent on external state, which can be difficult to debug and test. It would be more robust to pass the quantization method as an explicit argument.

You would also need to update the call site in scripts/convert_weights.py to pass self.quant_method.

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
        quant_method: str = "int4",
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
            quant_method: The quantization method to use ('int4' or 'int8').
        """
        # Store tensors as instance variables to keep them alive
        self.gate_proj = gate_proj.contiguous()
        self.up_proj = up_proj.contiguous()
        self.down_proj = down_proj.contiguous()

        # Configure MoE with online quantization (cpu_save mode)
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size

        # Enable save mode for online quantization
        moe_config.save = True
        moe_config.load = False

        # Set weight pointers
        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()

        # Set output path for quantized weights
        moe_config.path = self.amx_weight_path

        # Create MoE module based on quantization method
        if quant_method == "int4":
            self.moe = AMXInt4_MOE(moe_config)
        elif quant_method == "int8":
            self.moe = AMXInt8_MOE(moe_config)
        else:
            raise NotImplementedError(f"Unsupported quantization method: {quant_method}")

        # Submit quantization and save task
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()
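
For illustration, the matching call site in scripts/convert_weights.py could then look roughly like this; the `wrapper` variable, the surrounding conversion loop, and the `self.quant_method` attribute are assumptions for the sketch, since the script itself is not shown in this diff.

```python
# Hypothetical call-site update in scripts/convert_weights.py (names are illustrative).
wrapper.load_weights_from_tensors(
    gate_proj=gate_proj,
    up_proj=up_proj,
    down_proj=down_proj,
    physical_to_logical_map_cpu=physical_to_logical_map_cpu,
    quant_method=self.quant_method,  # "int4" or "int8", e.g. parsed from --quant-method
)
```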


def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
"""
Load weights for this layer and initialize the MoE module.