fix kt-kernel #1527
Conversation
Code Review
This pull request introduces a significant new feature: a weight conversion tool for quantizing FP8/FP16/BF16 models to INT4/INT8 for optimized AMX inference. The implementation involves a new OnlineQuantConverter that leverages the AMXMoEWrapper to perform online quantization. The changes are extensive and well-structured.
My review focuses on improving robustness, resource management, and code clarity. Key feedback includes:
- Refactoring the quantization method selection in `AMXMoEWrapper` to use an explicit function argument instead of an environment variable.
- Improving the robustness of tensor loading by passing `dtype` explicitly instead of inferring it from filenames.
- Addressing a potential resource leak by changing how `safetensors` files are handled to ensure they are properly closed.
- Minor style improvements, such as moving an import to the top of the file and converting a dangling string to a proper comment.
Overall, this is a great addition. Addressing these points will make the new conversion script more robust and maintainable.
```python
self.tensor_file_map: Dict[str, str] = {}  # key -> filename
self.file_handle_map: Dict[str, any] = {}  # filename -> file
```
The current implementation in _load_input_files opens all .safetensors files and keeps their file handles in self.file_handle_map for the duration of the script. This can lead to a 'too many open files' error, especially for models with many shards. Since safetensors.safe_open returns a context manager, it's best to use it within a with statement to ensure files are closed properly.
A better approach would be to scan for keys without keeping files open, and then open each file on-demand when a tensor is needed. This would involve:
- Removing `self.file_handle_map`.
- In `_load_input_files`, use `with safe_open(...)` to scan keys and store the full `file_path` in `self.tensor_file_map`.
- In `_load_tensor` (which is not in this diff), use `with safe_open(...)` again to load the tensor on demand, as sketched below.
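A minimal sketch of that on-demand pattern, assuming the converter keeps a list of shard paths (the attribute name `self.safetensor_files` is a placeholder) and that `_load_tensor` takes a tensor key:

```python
import torch
from safetensors import safe_open

def _load_input_files(self):
    """Scan every shard once, recording which file owns each key; no handles stay open."""
    for file_path in self.safetensor_files:  # placeholder: list of *.safetensors shard paths
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                self.tensor_file_map[key] = file_path

def _load_tensor(self, key: str) -> torch.Tensor:
    """Open the owning shard on demand, read one tensor, and let the context manager close the file."""
    with safe_open(self.tensor_file_map[key], framework="pt", device="cpu") as f:
        return f.get_tensor(key)
```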
```python
def load_weights_from_tensors(
    self,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    physical_to_logical_map_cpu: torch.Tensor,
):
    """
    Load and quantize weights from BF16/FP16 tensors (online quantization).

    Args:
        gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
        up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
        down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
        physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
    """
    # Store tensors as instance variables to keep them alive
    self.gate_proj = gate_proj.contiguous()
    self.up_proj = up_proj.contiguous()
    self.down_proj = down_proj.contiguous()

    # Configure MoE with online quantization (cpu_save mode)
    moe_config = MOEConfig(
        self.num_experts,
        self.num_experts_per_tok,
        self.hidden_size,
        self.moe_intermediate_size,
        self.num_gpu_experts,
    )
    moe_config.layer_idx = self.layer_idx
    moe_config.pool = self.cpu_infer.backend_
    moe_config.max_len = self.chunked_prefill_size

    # Enable save mode for online quantization
    moe_config.save = True
    moe_config.load = False

    # Set weight pointers
    moe_config.gate_proj = self.gate_proj.data_ptr()
    moe_config.up_proj = self.up_proj.data_ptr()
    moe_config.down_proj = self.down_proj.data_ptr()

    # Set output path for quantized weights
    moe_config.path = self.amx_weight_path

    # Create MoE module based on AMX method
    amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
    if amx_method == "AMXINT4":
        self.moe = AMXInt4_MOE(moe_config)
    elif amx_method == "AMXINT8":
        self.moe = AMXInt8_MOE(moe_config)
    else:
        raise NotImplementedError(f"Unsupported AMX method: {amx_method}")

    # Submit quantization and save task
    self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
    self.cpu_infer.sync()
```
The quantization method is determined by reading the AMX_METHOD environment variable inside this function. This makes the function's behavior implicit and dependent on external state, which can be difficult to debug and test. It would be more robust to pass the quantization method as an explicit argument.
You would also need to update the call site in scripts/convert_weights.py to pass self.quant_method.
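For illustration, the updated call site might look roughly like this; `self.wrapper` and the local tensor variables are placeholders, not names taken from the diff:

```python
# Hypothetical call site in scripts/convert_weights.py
self.wrapper.load_weights_from_tensors(
    gate_proj,
    up_proj,
    down_proj,
    physical_to_logical_map_cpu,
    quant_method=self.quant_method,  # "int4" or "int8", replacing the AMX_METHOD env var
)
```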
Suggested change:

```python
def load_weights_from_tensors(
    self,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    physical_to_logical_map_cpu: torch.Tensor,
    quant_method: str = "int4",
):
    """
    Load and quantize weights from BF16/FP16 tensors (online quantization).

    Args:
        gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
        up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
        down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
        physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
        quant_method: The quantization method to use ('int4' or 'int8').
    """
    # Store tensors as instance variables to keep them alive
    self.gate_proj = gate_proj.contiguous()
    self.up_proj = up_proj.contiguous()
    self.down_proj = down_proj.contiguous()

    # Configure MoE with online quantization (cpu_save mode)
    moe_config = MOEConfig(
        self.num_experts,
        self.num_experts_per_tok,
        self.hidden_size,
        self.moe_intermediate_size,
        self.num_gpu_experts,
    )
    moe_config.layer_idx = self.layer_idx
    moe_config.pool = self.cpu_infer.backend_
    moe_config.max_len = self.chunked_prefill_size

    # Enable save mode for online quantization
    moe_config.save = True
    moe_config.load = False

    # Set weight pointers
    moe_config.gate_proj = self.gate_proj.data_ptr()
    moe_config.up_proj = self.up_proj.data_ptr()
    moe_config.down_proj = self.down_proj.data_ptr()

    # Set output path for quantized weights
    moe_config.path = self.amx_weight_path

    # Create MoE module based on quantization method
    if quant_method == "int4":
        self.moe = AMXInt4_MOE(moe_config)
    elif quant_method == "int8":
        self.moe = AMXInt8_MOE(moe_config)
    else:
        raise NotImplementedError(f"Unsupported quantization method: {quant_method}")

    # Submit quantization and save task
    self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
    self.cpu_infer.sync()
```
| """Load .kt format binary tensor file | ||
| Args: | ||
| file_path: Path to .kt binary file | ||
| Returns: | ||
| torch.Tensor: Loaded tensor | ||
| """ | ||
| if not os.path.exists(file_path): | ||
| raise FileNotFoundError(f"File not found: {file_path}") | ||
|
|
||
| with open(file_path, 'rb') as f: | ||
| binary_data = f.read() | ||
|
|
||
| # Determine dtype based on file name | ||
| if 'scale' in file_path: | ||
| # Scale tensors are typically float32 | ||
| np_array = np.frombuffer(binary_data, dtype=np.float32) | ||
| else: | ||
| # Quant tensors are typically int8 | ||
| np_array = np.frombuffer(binary_data, dtype=np.int8) | ||
|
|
||
| tensor = torch.from_numpy(np_array.copy()) | ||
| return tensor |
Determining the tensor's data type based on the presence of the substring 'scale' in the file path is brittle. This can lead to errors if file naming conventions change. A more robust approach is to pass the expected dtype as an argument to this function from the call site, which already has the context to know whether it's loading a weight or a scale tensor.
You would then update the calls in _load_layer_tensors_from_disk to pass the appropriate dtype, for example:
```python
# in _load_layer_tensors_from_disk
if quant_files:
    tensors[weight_key] = self._load_binary_tensor(quant_files[0], dtype=np.int8)
if scale_files:
    tensors[scale_key] = self._load_binary_tensor(scale_files[0], dtype=np.float32)
```

Suggested change:

```python
def _load_binary_tensor(self, file_path: str, dtype) -> torch.Tensor:
    """Load .kt format binary tensor file

    Args:
        file_path: Path to .kt binary file
        dtype: The numpy dtype of the tensor.

    Returns:
        torch.Tensor: Loaded tensor
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, 'rb') as f:
        binary_data = f.read()

    np_array = np.frombuffer(binary_data, dtype=dtype)
    tensor = torch.from_numpy(np_array.copy())
    return tensor
```

```python
    Args:
        layer_idx: Layer index
    """
    import shutil
```
The import of shutil here is inside the function body; it should be moved to the top of the file with the other imports.
| """ | ||
| Example usage(test passed): | ||
| python convert_weights.py --input-path /mnt/data3/models/DeepSeek-V3.1 --input-type fp8 --output /mnt/data3/models/DeepSeek-V3.1-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2 | ||
| python convert_weights.py --input-path /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct --input-type bf16 --output /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2 | ||
| """ |
This multi-line string containing example usage is a dangling literal that has no effect at runtime. For clarity and correctness, it should be converted into a proper comment block.
```python
# Example usage(test passed):
# python convert_weights.py --input-path /mnt/data3/models/DeepSeek-V3.1 --input-type fp8 --output /mnt/data3/models/DeepSeek-V3.1-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
# python convert_weights.py --input-path /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct --input-type bf16 --output /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
```