
Conversation

@ovowei
Contributor

@ovowei ovowei commented Oct 22, 2025

No description provided.

@ovowei ovowei merged commit 5180fe8 into main Oct 22, 2025
6 checks passed
@gemini-code-assist
Contributor

Summary of Changes

Hello @ovowei, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a robust online weight quantization feature to the kt-kernel library. It enables users to convert model weights from floating-point formats (FP8, FP16, BF16) to optimized integer formats (INT4, INT8) for efficient AMX inference. The changes involve a major refactoring of the weight conversion script to integrate with the AMXMoEWrapper for direct, NUMA-aware quantization and saving, along with comprehensive documentation updates to guide users through the new process.

Highlights

  • Online Weight Quantization: Implemented a new online quantization process to convert FP8/FP16/BF16 model weights directly to INT4/INT8 formats optimized for AMX inference.
  • Refactored Weight Conversion Script: The scripts/convert_weights.py script has been significantly updated to leverage the AMXMoEWrapper for online quantization, supporting various input and output quantization types.
  • NUMA-Aware Weight Handling: The quantization process now incorporates NUMA-aware memory management and automatically saves quantized weights in a NUMA-optimized layout.
  • Documentation Update: The README.md has been expanded with a detailed "Weight Quantization" section, covering supported methods, input formats, basic usage examples, and the output format.
  • Parameter Renaming: The subpool_count parameter in AMXMoEWrapper has been renamed to threadpool_count for clarity, reflecting its role in thread distribution across NUMA subpools.
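
For reference, a representative invocation of the new conversion script, adapted from the example usage quoted later in this review (model paths are illustrative; flag names are as given in that example):

python scripts/convert_weights.py \
    --input-path /path/to/DeepSeek-V3.1 \
    --input-type fp8 \
    --output /path/to/DeepSeek-V3.1-INT4 \
    --quant-method int4 \
    --cpuinfer-threads 60 \
    --subpool-count 2
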
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding the relevant files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant new feature: a weight conversion tool for quantizing FP8/FP16/BF16 models to INT4/INT8 for optimized AMX inference. The implementation involves a new OnlineQuantConverter that leverages the AMXMoEWrapper to perform online quantization. The changes are extensive and well-structured.

My review focuses on improving robustness, resource management, and code clarity. Key feedback includes:

  • Refactoring the quantization method selection in AMXMoEWrapper to use an explicit function argument instead of an environment variable.
  • Improving the robustness of tensor loading by passing dtype explicitly instead of inferring it from filenames.
  • Addressing a potential resource leak by changing how safetensors files are handled to ensure they are properly closed.
  • Minor style improvements, such as moving an import to the top of the file and converting a dangling string to a proper comment.

Overall, this is a great addition. Addressing these points will make the new conversion script more robust and maintainable.

Comment on lines +237 to +238
        self.tensor_file_map: Dict[str, str] = {}  # key -> filename
        self.file_handle_map: Dict[str, any] = {}  # filename -> file

high

The current implementation in _load_input_files opens all .safetensors files and keeps their file handles in self.file_handle_map for the duration of the script. This can lead to a 'too many open files' error, especially for models with many shards. Since safetensors.safe_open returns a context manager, it's best to use it within a with statement to ensure files are closed properly.

A better approach would be to scan for keys without keeping files open, and then open each file on-demand when a tensor is needed. This would involve:

  1. Removing self.file_handle_map.
  2. In _load_input_files, use with safe_open(...) to scan keys and store the full file_path in self.tensor_file_map.
  3. In _load_tensor (which is not in this diff), use with safe_open(...) again to load the tensor on demand.
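
Putting points 1-3 together, here is a minimal sketch of the on-demand pattern, assuming the converter class is OnlineQuantConverter and that safetensors' "pt" framework is used; the method bodies shown are illustrative, not the diff's actual code:

import glob
import os
from typing import Dict

import torch
from safetensors import safe_open


class OnlineQuantConverter:
    # Only the two methods relevant to this suggestion are sketched here.

    def _load_input_files(self, input_path: str) -> None:
        """Scan every shard once, recording key -> file path; no handles are kept open."""
        self.tensor_file_map: Dict[str, str] = {}  # key -> full path of the owning shard
        for file_path in sorted(glob.glob(os.path.join(input_path, "*.safetensors"))):
            with safe_open(file_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.tensor_file_map[key] = file_path

    def _load_tensor(self, key: str) -> torch.Tensor:
        """Open the owning shard on demand and read a single tensor."""
        with safe_open(self.tensor_file_map[key], framework="pt", device="cpu") as f:
            return f.get_tensor(key)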

Comment on lines +264 to +320
    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        """
        # Store tensors as instance variables to keep them alive
        self.gate_proj = gate_proj.contiguous()
        self.up_proj = up_proj.contiguous()
        self.down_proj = down_proj.contiguous()

        # Configure MoE with online quantization (cpu_save mode)
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size

        # Enable save mode for online quantization
        moe_config.save = True
        moe_config.load = False

        # Set weight pointers
        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()

        # Set output path for quantized weights
        moe_config.path = self.amx_weight_path

        # Create MoE module based on AMX method
        amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
        if amx_method == "AMXINT4":
            self.moe = AMXInt4_MOE(moe_config)
        elif amx_method == "AMXINT8":
            self.moe = AMXInt8_MOE(moe_config)
        else:
            raise NotImplementedError(f"Unsupported AMX method: {amx_method}")

        # Submit quantization and save task
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()

medium

The quantization method is determined by reading the AMX_METHOD environment variable inside this function. This makes the function's behavior implicit and dependent on external state, which can be difficult to debug and test. It would be more robust to pass the quantization method as an explicit argument.

You would also need to update the call site in scripts/convert_weights.py to pass self.quant_method.

    def load_weights_from_tensors(
        self,
        gate_proj: torch.Tensor,
        up_proj: torch.Tensor,
        down_proj: torch.Tensor,
        physical_to_logical_map_cpu: torch.Tensor,
        quant_method: str = "int4",
    ):
        """
        Load and quantize weights from BF16/FP16 tensors (online quantization).

        Args:
            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
            quant_method: The quantization method to use ('int4' or 'int8').
        """
        # Store tensors as instance variables to keep them alive
        self.gate_proj = gate_proj.contiguous()
        self.up_proj = up_proj.contiguous()
        self.down_proj = down_proj.contiguous()

        # Configure MoE with online quantization (cpu_save mode)
        moe_config = MOEConfig(
            self.num_experts,
            self.num_experts_per_tok,
            self.hidden_size,
            self.moe_intermediate_size,
            self.num_gpu_experts,
        )
        moe_config.layer_idx = self.layer_idx
        moe_config.pool = self.cpu_infer.backend_
        moe_config.max_len = self.chunked_prefill_size

        # Enable save mode for online quantization
        moe_config.save = True
        moe_config.load = False

        # Set weight pointers
        moe_config.gate_proj = self.gate_proj.data_ptr()
        moe_config.up_proj = self.up_proj.data_ptr()
        moe_config.down_proj = self.down_proj.data_ptr()

        # Set output path for quantized weights
        moe_config.path = self.amx_weight_path

        # Create MoE module based on quantization method
        if quant_method == "int4":
            self.moe = AMXInt4_MOE(moe_config)
        elif quant_method == "int8":
            self.moe = AMXInt8_MOE(moe_config)
        else:
            raise NotImplementedError(f"Unsupported quantization method: {quant_method}")

        # Submit quantization and save task
        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
        self.cpu_infer.sync()
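
To complement the suggested signature above, the call site in scripts/convert_weights.py would then pass the method explicitly. A hypothetical sketch (the surrounding converter code and exact variable names are not part of this diff):

# Inside the converter's per-layer loop; `wrapper` is the AMXMoEWrapper for this layer.
wrapper.load_weights_from_tensors(
    gate_proj,
    up_proj,
    down_proj,
    physical_to_logical_map_cpu,
    quant_method=self.quant_method,  # "int4" or "int8", taken from --quant-method
)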

Comment on lines +540 to +564
    def _load_binary_tensor(self, file_path: str) -> torch.Tensor:
        """Load .kt format binary tensor file

        Args:
            file_path: Path to .kt binary file

        Returns:
            torch.Tensor: Loaded tensor
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        with open(file_path, 'rb') as f:
            binary_data = f.read()

        # Determine dtype based on file name
        if 'scale' in file_path:
            # Scale tensors are typically float32
            np_array = np.frombuffer(binary_data, dtype=np.float32)
        else:
            # Quant tensors are typically int8
            np_array = np.frombuffer(binary_data, dtype=np.int8)

        tensor = torch.from_numpy(np_array.copy())
        return tensor

medium

Determining the tensor's data type based on the presence of the substring 'scale' in the file path is brittle. This can lead to errors if file naming conventions change. A more robust approach is to pass the expected dtype as an argument to this function from the call site, which already has the context to know whether it's loading a weight or a scale tensor.

You would then update the calls in _load_layer_tensors_from_disk to pass the appropriate dtype, for example:

# in _load_layer_tensors_from_disk
if quant_files:
    tensors[weight_key] = self._load_binary_tensor(quant_files[0], dtype=np.int8)
if scale_files:
    tensors[scale_key] = self._load_binary_tensor(scale_files[0], dtype=np.float32)

    def _load_binary_tensor(self, file_path: str, dtype) -> torch.Tensor:
        """Load .kt format binary tensor file

        Args:
            file_path: Path to .kt binary file
            dtype: The numpy dtype of the tensor.

        Returns:
            torch.Tensor: Loaded tensor
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        with open(file_path, 'rb') as f:
            binary_data = f.read()

        np_array = np.frombuffer(binary_data, dtype=dtype)

        tensor = torch.from_numpy(np_array.copy())
        return tensor

        Args:
            layer_idx: Layer index
        """
        import shutil

medium

The shutil module is imported inside this function. According to Python style guidelines (PEP 8), imports should be placed at the top of the file. This improves readability and makes the file's dependencies clear at a glance. Please move import shutil to the top of the file.
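
For instance, the import would simply join the other top-level imports of scripts/convert_weights.py (the file's actual import list is not shown in this diff, so this is only indicative), and the local import inside the method is then removed:

# Top of scripts/convert_weights.py
import os
import shutil  # previously imported inside the per-layer save helper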

Comment on lines +772 to +776
"""
Example usage(test passed):
python convert_weights.py --input-path /mnt/data3/models/DeepSeek-V3.1 --input-type fp8 --output /mnt/data3/models/DeepSeek-V3.1-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
python convert_weights.py --input-path /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct --input-type bf16 --output /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
"""

medium

This multi-line string containing example usage is a dangling literal that has no effect at runtime. For clarity and correctness, it should be converted into a proper comment block.

# Example usage(test passed):
# python convert_weights.py --input-path /mnt/data3/models/DeepSeek-V3.1 --input-type fp8 --output /mnt/data3/models/DeepSeek-V3.1-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
# python convert_weights.py --input-path /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct --input-type bf16 --output /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
