From 16b587c104708c0b5b9e05142031d7240b16213a Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 09:48:27 -0800 Subject: [PATCH 01/15] fix: hidden states handling in batch expansion for spec decoding (#839) --- aphrodite/spec_decode/batch_expansion.py | 83 ++++++++++++++------- aphrodite/spec_decode/spec_decode_worker.py | 4 +- aphrodite/spec_decode/top1_proposer.py | 2 +- aphrodite/spec_decode/util.py | 18 ++++- tests/spec_decode/e2e/conftest.py | 23 +++--- 5 files changed, 89 insertions(+), 41 deletions(-) diff --git a/aphrodite/spec_decode/batch_expansion.py b/aphrodite/spec_decode/batch_expansion.py index 0f803ab74..ca6cb6114 100644 --- a/aphrodite/spec_decode/batch_expansion.py +++ b/aphrodite/spec_decode/batch_expansion.py @@ -1,6 +1,6 @@ from array import array from itertools import chain, count -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple import torch @@ -90,21 +90,22 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs, spec_logprobs = self._contract_batch( - contracted_bs=len(execute_model_req.seq_group_metadata_list), - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) + (all_tokens, all_probs, spec_logprobs, + all_hidden_states) = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, logprobs=spec_logprobs, - hidden_states=target_sampler_output.hidden_states, + hidden_states=all_hidden_states, ) def _expand_batch( @@ -147,10 +148,11 @@ def _expand_batch( num_scoring_tokens) def _contract_batch( - self, contracted_bs: int, target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, contracted_bs: int, target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], k: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -158,9 +160,10 @@ def _contract_batch( contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. 
""" - (target_token_ids, target_probs, target_logprobs, + (target_token_ids, target_probs, target_logprobs, target_hidden_states, non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = self._split_scoring_output( + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -178,23 +181,36 @@ def _contract_batch( self._vocab_size) target_logprobs = target_logprobs.reshape(target_probs.shape) + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), fill_value=-1) all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) all_logprobs = target_logprobs.new_full(size=all_probs.shape, fill_value=-float("inf")) + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs + if all_hidden_states is not None: + all_hidden_states[ + non_spec_indices, :1, :] = non_spec_target_hidden_states if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs, all_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + return all_tokens, all_probs, all_logprobs, all_hidden_states def _create_scoring_model_input( self, @@ -329,8 +345,9 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: """Split the target model output into speculative and non-speculative output. """ @@ -354,24 +371,38 @@ def _split_scoring_output( non_spec_logprobs, ) = sampler_output.logprobs.split(split_sizes) + if sampler_output.hidden_states is not None: + ( + spec_hidden_states, + non_spec_hidden_states, + ) = sampler_output.hidden_states.split(split_sizes) + else: + spec_hidden_states, non_spec_hidden_states = None, None # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens sampler_output.logprobs = spec_logprobs - (target_token_ids, target_probs, - target_logprobs) = sampler_output_to_torch([sampler_output], True) + sampler_output.hidden_states = spec_hidden_states + (target_token_ids, target_probs, target_logprobs, + target_hidden_states) = sampler_output_to_torch([sampler_output], + True) # Convert non-speculative output tokens to tensors. 
sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens sampler_output.logprobs = non_spec_logprobs + sampler_output.hidden_states = non_spec_hidden_states (non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], - True) + non_spec_target_logprobs, + non_spec_target_hidden_states) = sampler_output_to_torch( + [sampler_output], True) return (target_token_ids, target_probs, target_logprobs, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs) + target_hidden_states, non_spec_target_token_ids, + non_spec_target_probs, non_spec_target_logprobs, + non_spec_target_hidden_states) + + def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/aphrodite/spec_decode/spec_decode_worker.py b/aphrodite/spec_decode/spec_decode_worker.py index 535713695..986a6126e 100644 --- a/aphrodite/spec_decode/spec_decode_worker.py +++ b/aphrodite/spec_decode/spec_decode_worker.py @@ -645,9 +645,7 @@ def _verify_tokens( hidden_states = proposal_scores.hidden_states if hidden_states is not None: # Contract hidden states based on accepted tokens - hs_size = hidden_states.shape[1] - hidden_states = hidden_states.reshape(-1, max_proposal_len + 1, - hs_size) + hs_size = hidden_states.shape[-1] accepted_index = accepted_token_ids + 1 # Convert -1 to 0 accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) index = accepted_index[:, None, None].expand(-1, 1, hs_size) diff --git a/aphrodite/spec_decode/top1_proposer.py b/aphrodite/spec_decode/top1_proposer.py index 22464251e..1b4acbaf2 100644 --- a/aphrodite/spec_decode/top1_proposer.py +++ b/aphrodite/spec_decode/top1_proposer.py @@ -242,7 +242,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/aphrodite/spec_decode/util.py b/aphrodite/spec_decode/util.py index d881fa328..740b3a62b 100644 --- a/aphrodite/spec_decode/util.py +++ b/aphrodite/spec_decode/util.py @@ -124,7 +124,7 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( sampler_output_list: List[SamplerOutput], sampler_transposed: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Utility function which converts a list of SamplerOutput to tensors. 
sampler_transposed here is used as the indicator for whether @@ -170,7 +170,21 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs, sampled_token_logprobs + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 8c2cf6f50..99d04e089 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -286,15 +286,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ensure_all_accepted=ensure_all_accepted) -def run_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - temperature: float, - seeded: bool, - print_tokens: bool = False, - ensure_all_accepted: bool = False): +def run_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + temperature: float, + seeded: bool, + print_tokens: bool = False, + ensure_all_accepted: bool = False, + expected_acceptance_rate: Optional[float] = None): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero (or when temperature is > 0 and seeded). 
@@ -355,5 +357,8 @@ def run_equality_correctness_test(baseline_llm_generator, print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + print(f'{acceptance_rate=}') if ensure_all_accepted: assert acceptance_rate == 1.0 + if expected_acceptance_rate is not None: + assert acceptance_rate >= expected_acceptance_rate - 1e-2 From 7c7ec12f3608caef39f4c0ef60617685b2007f94 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:07:13 -0800 Subject: [PATCH 02/15] chore: refactor executor classes for easier inheritance (#840) --- aphrodite/executor/gpu_executor.py | 29 +++++++++++++++----------- aphrodite/executor/ray_gpu_executor.py | 19 ++++++++--------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/aphrodite/executor/gpu_executor.py b/aphrodite/executor/gpu_executor.py index eaa0dfbbf..5249545bf 100644 --- a/aphrodite/executor/gpu_executor.py +++ b/aphrodite/executor/gpu_executor.py @@ -61,6 +61,18 @@ def _get_worker_kwargs( or (rank % self.parallel_config.tensor_parallel_size == 0), ) + def _get_worker_module_and_class(self) -> Tuple[str, str]: + if self.scheduler_config.is_multi_step: + worker_module_name = "aphrodite.task_handler.multi_step_worker" + worker_class_name = "MultiStepWorker" + elif self.speculative_config: + worker_module_name = "aphrodite.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "aphrodite.task_handler.worker" + worker_class_name = "Worker" + return (worker_module_name, worker_class_name) + def _get_create_worker_kwargs( self, local_rank: int = 0, @@ -68,18 +80,11 @@ def _get_create_worker_kwargs( distributed_init_method: Optional[str] = None) -> Dict: worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if self.scheduler_config.is_multi_step: - worker_kwargs.update( - worker_module_name="aphrodite.task_handler.multi_step_worker", - worker_class_name="MultiStepWorker") - elif self.speculative_config: - worker_kwargs.update( - worker_module_name="aphrodite.spec_decode.spec_decode_worker", - worker_class_name="create_spec_worker") - else: - worker_kwargs.update( - worker_module_name="aphrodite.task_handler.worker", - worker_class_name="Worker") + (worker_module_name, + worker_class_name) = self._get_worker_module_and_class() + worker_kwargs.update(worker_module_name=worker_module_name, + worker_class_name=worker_class_name) + return worker_kwargs def _create_worker(self, diff --git a/aphrodite/executor/ray_gpu_executor.py b/aphrodite/executor/ray_gpu_executor.py index ab22ba3a4..de6979132 100644 --- a/aphrodite/executor/ray_gpu_executor.py +++ b/aphrodite/executor/ray_gpu_executor.py @@ -101,15 +101,8 @@ def _configure_ray_workers_use_nsight(self, return ray_remote_kwargs def _get_worker_wrapper_args(self) -> Dict[str, Any]: - if self.speculative_config is not None: - worker_module_name = "aphrodite.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - elif self.scheduler_config.is_multi_step: - worker_module_name = "aphrodite.task_handler.multi_step_worker" - worker_class_name = "MultiStepWorker" - else: - worker_module_name = "aphrodite.task_handler.worker" - worker_class_name = "Worker" + (worker_module_name, + worker_class_name) = self._get_worker_module_and_class() return dict( worker_module_name=worker_module_name, @@ -117,6 +110,10 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]: trust_remote_code=self.model_config.trust_remote_code, ) + # child class 
could overwrite this to return actual env vars. + def _get_env_vars_to_be_updated(self): + return self._env_vars_for_all_workers + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if (self.parallel_config.tensor_parallel_size == 1 @@ -240,8 +237,10 @@ def sort_by_driver_then_worker_ip(worker): "APHRODITE_TRACE_FUNCTION": str(APHRODITE_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) self._run_workers("update_environment_variables", - all_args=all_args_to_update_environment_variables) + all_args=self._get_env_vars_to_be_updated()) if len(node_gpus) == 1: # in single node case, we don't need to get the IP address. From 563e8f7ac85bcf29081f7b707d3fea2571ba7620 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:46:02 -0800 Subject: [PATCH 03/15] fix: latency and serving benchmarks (#841) --- tests/benchmarks/backend_request_func.py | 5 +++-- tests/benchmarks/engine/latency.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/benchmarks/backend_request_func.py b/tests/benchmarks/backend_request_func.py index 9e6bb17d9..68ea48842 100644 --- a/tests/benchmarks/backend_request_func.py +++ b/tests/benchmarks/backend_request_func.py @@ -276,8 +276,9 @@ async def async_request_openai_completions( output.ttft = ttft # Decoding phase - output.itl.append(timestamp - - most_recent_timestamp) + else: + output.itl.append(timestamp - + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += data["choices"][0]["text"] diff --git a/tests/benchmarks/engine/latency.py b/tests/benchmarks/engine/latency.py index 1efb0eba0..c0c06c192 100644 --- a/tests/benchmarks/engine/latency.py +++ b/tests/benchmarks/engine/latency.py @@ -12,7 +12,7 @@ from aphrodite import LLM, SamplingParams from aphrodite.common.utils import FlexibleArgumentParser from aphrodite.engine.args_tools import EngineArgs -from aphrodite.inputs import PromptStrictInputs +from aphrodite.inputs import PromptInputs from aphrodite.quantization import QUANTIZATION_METHODS @@ -62,7 +62,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptStrictInputs] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] From 93bc86359144bfcb2f381664279aed166fa08d5a Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 19:47:38 -0800 Subject: [PATCH 04/15] feat: Machete Kernels for Hopper GPUs (#842) * add cute and torch utils * add cutlass extensions for aphrodite * add machete kernels * cmakelists fixes * prepack_layout -> prepacked_layout * fix: cute and torch utils * lmao! 
* fix numeric conversion compilation * add custom ops for machete * add tests and benchmark scripts * integrate with gptq * gptq activation ordering --- .gitignore | 3 + CMakeLists.txt | 41 + aphrodite/_custom_ops.py | 29 + aphrodite/modeling/parameter.py | 57 + aphrodite/quantization/awq_marlin.py | 9 +- .../compressed_tensors/compressed_tensors.py | 3 +- .../schemes/compressed_tensors_wNa16.py | 130 +- .../quantization/compressed_tensors/utils.py | 26 +- aphrodite/quantization/gptq_marlin.py | 108 +- .../quantization/kernels/MPLinearKernel.py | 83 + aphrodite/quantization/kernels/__init__.py | 67 + aphrodite/quantization/kernels/machete.py | 118 ++ aphrodite/quantization/kernels/marlin.py | 132 ++ aphrodite/quantization/utils/__init__.py | 3 + aphrodite/quantization/utils/layer_utils.py | 33 + aphrodite/quantization/utils/machete_utils.py | 30 + aphrodite/quantization/utils/marlin_utils.py | 29 +- aphrodite/quantization/utils/quant_utils.py | 55 +- kernels/cuda_utils.h | 10 + .../aphrodite_collective_builder.cuh | 43 + .../aphrodite_custom_types.cuh | 51 + .../aphrodite_cutlass_library_extension.py | 49 + .../aphrodite_numeric_conversion.cuh | 797 +++++++++ kernels/cutlass_extensions/cute_utils.cuh | 68 + kernels/cutlass_extensions/torch_utils.hpp | 160 ++ kernels/ops.h | 2 + kernels/permute_cols.cu | 88 + kernels/quantization/machete/generate.py | 530 ++++++ .../machete/machete_collective_builder.cuh | 33 + .../machete/machete_interleaving_utils.cuh | 35 + .../quantization/machete/machete_mainloop.cuh | 1473 +++++++++++++++++ .../machete/machete_mm_kernel.cuh | 238 +++ .../machete/machete_mm_launcher.cuh | 95 ++ .../machete/machete_prepack_kernel.cuh | 62 + .../machete/machete_prepack_launcher.cuh | 71 + .../machete/machete_prepacked_layout.cuh | 220 +++ .../quantization/machete/machete_pytorch.cu | 79 + kernels/quantization/quant_ops.h | 21 + kernels/torch_bindings.cpp | 19 + tests/benchmarks/kernels/benchmark_machete.py | 370 +++++ .../benchmarks/kernels/graph_machete_bench.py | 64 + tests/benchmarks/kernels/weight_shapes.py | 43 + tests/kernels/test_machete_gemm.py | 274 +++ tests/kernels/test_permute_cols.py | 13 + 44 files changed, 5691 insertions(+), 173 deletions(-) create mode 100644 aphrodite/quantization/kernels/MPLinearKernel.py create mode 100644 aphrodite/quantization/kernels/__init__.py create mode 100644 aphrodite/quantization/kernels/machete.py create mode 100644 aphrodite/quantization/kernels/marlin.py create mode 100644 aphrodite/quantization/utils/layer_utils.py create mode 100644 aphrodite/quantization/utils/machete_utils.py create mode 100644 kernels/cutlass_extensions/aphrodite_collective_builder.cuh create mode 100644 kernels/cutlass_extensions/aphrodite_custom_types.cuh create mode 100644 kernels/cutlass_extensions/aphrodite_cutlass_library_extension.py create mode 100644 kernels/cutlass_extensions/aphrodite_numeric_conversion.cuh create mode 100644 kernels/cutlass_extensions/cute_utils.cuh create mode 100644 kernels/cutlass_extensions/torch_utils.hpp create mode 100644 kernels/permute_cols.cu create mode 100644 kernels/quantization/machete/generate.py create mode 100644 kernels/quantization/machete/machete_collective_builder.cuh create mode 100644 kernels/quantization/machete/machete_interleaving_utils.cuh create mode 100644 kernels/quantization/machete/machete_mainloop.cuh create mode 100644 kernels/quantization/machete/machete_mm_kernel.cuh create mode 100644 kernels/quantization/machete/machete_mm_launcher.cuh create mode 100644 
kernels/quantization/machete/machete_prepack_kernel.cuh create mode 100644 kernels/quantization/machete/machete_prepack_launcher.cuh create mode 100644 kernels/quantization/machete/machete_prepacked_layout.cuh create mode 100644 kernels/quantization/machete/machete_pytorch.cu create mode 100644 tests/benchmarks/kernels/benchmark_machete.py create mode 100644 tests/benchmarks/kernels/graph_machete_bench.py create mode 100644 tests/benchmarks/kernels/weight_shapes.py create mode 100644 tests/kernels/test_machete_gemm.py create mode 100644 tests/kernels/test_permute_cols.py diff --git a/.gitignore b/.gitignore index bb77a46d5..ae8f8502f 100644 --- a/.gitignore +++ b/.gitignore @@ -207,3 +207,6 @@ images/ *.exp *.lib *.obj + +# generated files +**/generated/** \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index f5007585f..57ef0a549 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -219,6 +219,7 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA") "kernels/quantization/gptq_marlin/awq_marlin_repack.cu" "kernels/quantization/fp8/fp8_marlin.cu" "kernels/all_reduce/custom_all_reduce.cu" + "kernels/permute_cols.cu" "kernels/sampling/sampling.cu") # Add CUTLASS and GPTQ Marlin kernels if not MSVC @@ -250,6 +251,46 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA") endif() endif() + # + # For the Machete kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/kernels/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH + ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantization/machete/generate.py + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." 
+ " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + message(STATUS "Machete generation completed successfully.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "kernels/quantization/machete/generated/*.cu") + list(APPEND APHRODITE_EXT_SRC ${MACHETE_GEN_SOURCES}) + message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}") + + # See comment above for scaled_mm_c3x (same if condition) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + ${MACHETE_GEN_SOURCES} + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() + + # Add pytorch binding + list(APPEND APHRODITE_EXT_SRC + kernels/quantization/machete/machete_pytorch.cu) endif() define_gpu_extension_target( diff --git a/aphrodite/_custom_ops.py b/aphrodite/_custom_ops.py index 03fca03bc..c93a4a819 100644 --- a/aphrodite/_custom_ops.py +++ b/aphrodite/_custom_ops.py @@ -330,6 +330,35 @@ def gptq_marlin_gemm(a: torch.Tensor, is_zp_float) +# machete +def machete_supported_schedules(b_type: ScalarType) -> List[str]: + return torch.ops._C.machete_supported_schedules(b_type) + + +def machete_gemm( + a: torch.Tensor, + b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B + b_type: ScalarType, + b_scales: Optional[torch.Tensor] = None, + b_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + c: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros, + b_group_size, c, alpha, beta, schedule) + + +def machete_prepack_B(b_q_weight: torch.Tensor, + b_type: ScalarType) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, b_type) + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/aphrodite/modeling/parameter.py b/aphrodite/modeling/parameter.py index d637bacc8..271958b79 100644 --- a/aphrodite/modeling/parameter.py +++ b/aphrodite/modeling/parameter.py @@ -320,6 +320,63 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +def permute_param_layout_(param: BaseAphroditeParameter, input_dim: int, + output_dim: int, **kwargs) -> BaseAphroditeParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters when either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not 
None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size diff --git a/aphrodite/quantization/awq_marlin.py b/aphrodite/quantization/awq_marlin.py index ead6ef3c5..53d4c76ca 100644 --- a/aphrodite/quantization/awq_marlin.py +++ b/aphrodite/quantization/awq_marlin.py @@ -9,10 +9,11 @@ from aphrodite.modeling.parameter import (GroupQuantScaleParameter, PackedAphroditeParameter) from aphrodite.quantization.base_config import QuantizationConfig +from aphrodite.quantization.utils import replace_parameter from aphrodite.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from aphrodite.scalar_type import scalar_types @@ -227,7 +228,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) + replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. marlin_scales = marlin_permute_scales( @@ -235,7 +236,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. 
marlin_zp = awq_to_marlin_zero_points( @@ -243,7 +244,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.num_groups, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qzeros", marlin_zp) + replace_parameter(layer, "qzeros", marlin_zp) # Not-used layer.g_idx = marlin_make_empty_g_idx(device) diff --git a/aphrodite/quantization/compressed_tensors/compressed_tensors.py b/aphrodite/quantization/compressed_tensors/compressed_tensors.py index b62d7eea7..6239f062e 100644 --- a/aphrodite/quantization/compressed_tensors/compressed_tensors.py +++ b/aphrodite/quantization/compressed_tensors/compressed_tensors.py @@ -223,7 +223,8 @@ def _get_scheme_from_parts( return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, - group_size=weight_quant.group_size) + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) # Detect If Activation Quantization. # TODO @dsikka: clean-up conditions diff --git a/aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 6881458df..d89612e03 100644 --- a/aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,38 +1,43 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Set import torch +from loguru import logger -from aphrodite import _custom_ops as ops from aphrodite.modeling.parameter import (BaseAphroditeParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, - PackedAphroditeParameter) + PackedAphroditeParameter, + RowAphroditeParameter) from aphrodite.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from aphrodite.quantization.compressed_tensors.utils import ActivationOrdering +from aphrodite.quantization.kernels import (MPLinearLayerConfig, + choose_mp_linear_kernel) from aphrodite.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + marlin_repeat_scales_on_all_ranks) from aphrodite.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128, + 8: scalar_types.uint8b128 } WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() def __init__(self, strategy: str, num_bits: int, - group_size: Optional[int] = None): + group_size: Optional[int] = None, + actorder: Optional[ActivationOrdering] = None): self.pack_factor = 32 // num_bits self.strategy = strategy self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP if self.group_size == -1 and self.strategy != "channel": raise ValueError("Marlin kernels require group quantization or " @@ -46,36 +51,42 @@ def __init__(self, self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] - # Verify supported on platform. 
- verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): output_size_per_partition = sum(output_partition_sizes) + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsWNA16", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + # If group_size is -1, we are in channelwise case. - channelwise = (self.group_size == -1) group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) - # In the case of channelwise quantization, we need to replicate the - # scales across all gpus. - partition_scales = (row_parallel and not channelwise) - - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) scales_and_zp_size = input_size // group_size @@ -123,62 +134,27 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.group_size = group_size + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowAphroditeParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. + # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Act-order not supported in compressed-tensors yet, so set to empty. 
- layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - # Update for kernel - layer.weight_packed = torch.nn.Parameter( - layer.weight_packed.t().contiguous(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) - - # Repack weights from compressed-tensors format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) - - # Permute scales from compressed-tensors format to marlin format. - marlin_scales = marlin_permute_scales( - layer.weight_scale, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) + self.kernel.process_weights_after_loading(layer) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - - return apply_gptq_marlin_linear( - input=x, - weight=layer.weight_packed, - weight_scale=layer.weight_scale, - weight_zp=layer.weight_zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=True, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/aphrodite/quantization/compressed_tensors/utils.py b/aphrodite/quantization/compressed_tensors/utils.py index 744341a15..883d91309 100644 --- a/aphrodite/quantization/compressed_tensors/utils.py +++ b/aphrodite/quantization/compressed_tensors/utils.py @@ -1,8 +1,8 @@ import re from enum import Enum -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from torch.nn import Module from aphrodite.quantization.utils.quant_utils import FUSED_LAYER_NAME_MAPPING @@ -39,6 +39,17 @@ class QuantizationStrategy(str, Enum): TOKEN = "token" +class ActivationOrdering(str, Enum): + """ + Enum storing strategies for activation ordering + Group: reorder groups and weight\n + Weight: only reorder weight, not groups. Slightly lower latency and + accuracy compared to group actorder\n + """ + GROUP = "group" + WEIGHT = "weight" + + class QuantizationArgs(BaseModel): """ User facing arguments used to define a quantization config @@ -57,6 +68,8 @@ class QuantizationArgs(BaseModel): observed with every sample. Defaults to False for static quantization. Note that enabling dynamic quantization will change the default observer to a memoryless one + :param actorder: whether to apply group quantization in decreasing order of + activation. 
Defaults to None for arbitrary ordering """ num_bits: int = 8 @@ -66,6 +79,7 @@ class QuantizationArgs(BaseModel): strategy: Optional[QuantizationStrategy] = None block_structure: Optional[str] = None dynamic: bool = False + actorder: Union[ActivationOrdering, bool, None] = None observer: str = Field( default="minmax", description=("The class to use to compute the quantization param - " @@ -78,6 +92,14 @@ class QuantizationArgs(BaseModel): "Observers constructor excluding quantization range or symmetry"), ) + @field_validator("actorder", mode="before") + def validate_actorder(cls, value) -> Optional[ActivationOrdering]: + if isinstance(value, bool): + return ActivationOrdering.GROUP if value else None + if isinstance(value, str): + return ActivationOrdering(value.lower()) + return value + def is_activation_quantization_format(format: str) -> bool: _ACTIVATION_QUANTIZATION_FORMATS = [ diff --git a/aphrodite/quantization/gptq_marlin.py b/aphrodite/quantization/gptq_marlin.py index a277d421a..8c13d14f7 100644 --- a/aphrodite/quantization/gptq_marlin.py +++ b/aphrodite/quantization/gptq_marlin.py @@ -1,10 +1,9 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import torch from loguru import logger -from torch.nn import Parameter -from aphrodite import _custom_ops as ops +from aphrodite.common.utils import is_hip from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead from aphrodite.modeling.parameter import (ChannelQuantScaleParameter, @@ -13,13 +12,12 @@ PackedColumnParameter, RowAphroditeParameter) from aphrodite.quantization.base_config import QuantizationConfig +from aphrodite.quantization.kernels import (MPLinearLayerConfig, + choose_mp_linear_kernel) from aphrodite.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + check_marlin_supported, marlin_repeat_scales_on_all_ranks, + verify_marlin_supported) from aphrodite.scalar_type import scalar_types -from aphrodite.common.utils import is_hip class GPTQMarlinConfig(QuantizationConfig): @@ -151,6 +149,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): Args: quant_config: The GPTQ Marlin quantization config. 
""" + _kernel_backends_being_used: Set[str] = set() def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config @@ -166,23 +165,34 @@ def create_weights( **extra_weight_attrs, ) -> None: - del output_size output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info( + f"Using {kernel_type.__name__} for GPTQMarlinLinearMethod") + self._kernel_backends_being_used.add(kernel_type.__name__) + # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, self.quant_config.group_size, @@ -263,55 +273,15 @@ def create_weights( layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, - is_row_parallel) - - # Checkpoints are serialized in AutoGPTQ format, which is different from the - # marlin format. This function is called after the weights are loaded. - # Here, we handle the repacking, including the activation reordering case. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.qweight.device - - # required by torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) - layer.scales = Parameter(layer.scales.data, requires_grad=False) - # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") - # Handle sorting for activation reordering if needed. - if self.quant_config.desc_act: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "g_idx", g_idx) - else: - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.zp = marlin_make_empty_g_idx(device) - - # Repack weights from autogptq format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) - - # Permute scales from autogptq format to marlin format. 
- marlin_scales = marlin_permute_scales( - layer.scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -319,16 +289,4 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return apply_gptq_marlin_linear( - input=x, - weight=layer.qweight, - weight_scale=layer.scales, - weight_zp=layer.zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_config.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=layer.is_k_full, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/aphrodite/quantization/kernels/MPLinearKernel.py b/aphrodite/quantization/kernels/MPLinearKernel.py new file mode 100644 index 000000000..160083d66 --- /dev/null +++ b/aphrodite/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from aphrodite.quantization.utils import replace_parameter +from aphrodite.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/aphrodite/quantization/kernels/__init__.py 
b/aphrodite/quantization/kernels/__init__.py new file mode 100644 index 000000000..5bfeda49f --- /dev/null +++ b/aphrodite/quantization/kernels/__init__.py @@ -0,0 +1,67 @@ +import os +from typing import List, Optional, Type + +from aphrodite.platforms import current_platform +from aphrodite.quantization.kernels.machete import MacheteLinearKernel +from aphrodite.quantization.kernels.marlin import MarlinLinearKernel +from aphrodite.quantization.kernels.MPLinearKernel import (MPLinearKernel, + MPLinearLayerConfig) + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + Raises: + ValueError: If no kernel can implement the given config. + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. 
Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/aphrodite/quantization/kernels/machete.py b/aphrodite/quantization/kernels/machete.py new file mode 100644 index 000000000..a5b90e257 --- /dev/null +++ b/aphrodite/quantization/kernels/machete.py @@ -0,0 +1,118 @@ +from functools import partial +from typing import Optional, Tuple + +import torch + +from aphrodite import _custom_ops as ops +from aphrodite.modeling.parameter import (BaseAphroditeParameter, + permute_param_layout_) +from aphrodite.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from aphrodite.quantization.utils.quant_utils import ( + pack_weights_into_int32, unpack_weights_into_int32) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BaseAphroditeParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_weights_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_weights_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + self.config.weight_type) + return x + + def transform_w_s(x): + assert isinstance(x, BaseAphroditeParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> 
torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/aphrodite/quantization/kernels/marlin.py b/aphrodite/quantization/kernels/marlin.py new file mode 100644 index 000000000..8c4166aa6 --- /dev/null +++ b/aphrodite/quantization/kernels/marlin.py @@ -0,0 +1,132 @@ +from typing import Optional, Tuple + +import torch + +from aphrodite import _custom_ops as ops +from aphrodite.modeling.parameter import (BaseAphroditeParameter, + permute_param_layout_) +from aphrodite.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1], + c.full_weight_shape[1], + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. 
+ self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BaseAphroditeParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BaseAphroditeParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/aphrodite/quantization/utils/__init__.py b/aphrodite/quantization/utils/__init__.py index e69de29bb..6d18fa3b2 100644 --- a/aphrodite/quantization/utils/__init__.py +++ b/aphrodite/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/aphrodite/quantization/utils/layer_utils.py b/aphrodite/quantization/utils/layer_utils.py new file mode 100644 index 000000000..adb22b177 --- /dev/null +++ b/aphrodite/quantization/utils/layer_utils.py @@ -0,0 +1,33 @@ +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + 
del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by Aphrodite (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new) + mod.register_parameter(name, torch.nn.Parameter(new)) diff --git a/aphrodite/quantization/utils/machete_utils.py b/aphrodite/quantization/utils/machete_utils.py new file mode 100644 index 000000000..3f74c5032 --- /dev/null +++ b/aphrodite/quantization/utils/machete_utils.py @@ -0,0 +1,30 @@ +from typing import List, Optional, Tuple + +import torch + +from aphrodite.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/aphrodite/quantization/utils/marlin_utils.py b/aphrodite/quantization/utils/marlin_utils.py index dec5d653a..616b77b95 100644 --- a/aphrodite/quantization/utils/marlin_utils.py +++ b/aphrodite/quantization/utils/marlin_utils.py @@ -117,6 +117,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int, "with --quantization gptq.") +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // @@ -145,6 +158,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: requires_grad=False) +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + def marlin_sort_g_idx( g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) @@ -220,17 +238,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by Aphrodite (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: 
torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/aphrodite/quantization/utils/quant_utils.py b/aphrodite/quantization/utils/quant_utils.py index 5618d7970..9087714b0 100644 --- a/aphrodite/quantization/utils/quant_utils.py +++ b/aphrodite/quantization/utils/quant_utils.py @@ -19,6 +19,49 @@ } +def pack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj @@ -80,7 +123,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): def quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, - zero_points: bool = False): + zero_points: bool = False, + ref_zero_points_after_scales: bool = False): assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" @@ -124,8 +168,13 @@ def quantize_weights(w: torch.Tensor, w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) w_q = torch.clamp(w_q, min_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
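+    # i.e. the reference becomes w_q * w_s - maybe_w_zp * w_s rather than
+    # (w_q - maybe_w_zp) * w_s; the two are algebraically identical but can
+    # round differently once the intermediate products are kept in low
+    # precision.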
+ if ref_zero_points_after_scales and zero_points: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s if quant_type.has_bias(): w_q += quant_type.bias diff --git a/kernels/cuda_utils.h b/kernels/cuda_utils.h index e9a7a12cf..0a9bd4ca8 100644 --- a/kernels/cuda_utils.h +++ b/kernels/cuda_utils.h @@ -1,5 +1,15 @@ #pragma once +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ + #define DEVICE_INLINE __forceinline__ __device__ + #define HOST_INLINE __forceinline__ __host__ +#else + #define HOST_DEVICE_INLINE inline + #define DEVICE_INLINE inline + #define HOST_INLINE inline +#endif + int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); \ No newline at end of file diff --git a/kernels/cutlass_extensions/aphrodite_collective_builder.cuh b/kernels/cutlass_extensions/aphrodite_collective_builder.cuh new file mode 100644 index 000000000..f9508a2e8 --- /dev/null +++ b/kernels/cutlass_extensions/aphrodite_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// APHRODITECollectiveBuilder is a wrapper around CollectiveBuilder that allows +// for for custom kernel tags, allowing you to build custom collectives. Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct APHRODITECollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct APHRODITECollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/kernels/cutlass_extensions/aphrodite_custom_types.cuh b/kernels/cutlass_extensions/aphrodite_custom_types.cuh new file mode 100644 index 000000000..a75e86bd7 --- /dev/null +++ b/kernels/cutlass_extensions/aphrodite_custom_types.cuh @@ -0,0 +1,51 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct aphrodite_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + aphrodite_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit aphrodite_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit aphrodite_biased_integer_subbyte(unsigned value) + : Base(value) {} + 
CUTLASS_HOST_DEVICE explicit aphrodite_biased_integer_subbyte(double value) + : Base(value) {} +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// "GPTQ" types, i.e. symmetric quantization +using aphrodite_uint4b8_t = aphrodite_biased_integer_subbyte<4, 8>; // u4b8 +using aphrodite_uint8b128_t = + aphrodite_biased_integer_subbyte<8, 128>; // u8b128 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sizeof_bits> { + static constexpr int value = Bits; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass \ No newline at end of file diff --git a/kernels/cutlass_extensions/aphrodite_cutlass_library_extension.py b/kernels/cutlass_extensions/aphrodite_cutlass_library_extension.py new file mode 100644 index 000000000..37fcdc86a --- /dev/null +++ b/kernels/cutlass_extensions/aphrodite_cutlass_library_extension.py @@ -0,0 +1,49 @@ +import enum +from typing import Dict, Union + +from cutlass_library import * + +# +# Extend cutlass library with custom types, and missing values +# + + +class APHRODITEDataType(enum.Enum): + u4b8 = enum_auto() + u8b128 = enum_auto() + + +class MixedInputKernelScheduleType(enum.Enum): + TmaWarpSpecializedMixedInput = enum_auto() + TmaWarpSpecializedPingpongMixedInput = enum_auto() + TmaWarpSpecializedCooperativeMixedInput = enum_auto() + + +APHRODITEDataTypeNames: Dict[Union[APHRODITEDataType, DataType], str] = { + **DataTypeNames, # type: ignore + **{ + APHRODITEDataType.u4b8: "u4b8", + APHRODITEDataType.u8b128: "u8b128", + } +} + +APHRODITEDataTypeTag: Dict[Union[APHRODITEDataType, DataType], str] = { + **DataTypeTag, # type: ignore + **{ + APHRODITEDataType.u4b8: "cutlass::aphrodite_uint4b8_t", + APHRODITEDataType.u8b128: "cutlass::aphrodite_uint8b128_t", + } +} + +APHRODITEKernelScheduleTag: Dict[Union[ + MixedInputKernelScheduleType, KernelScheduleType], str] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + } + } diff --git a/kernels/cutlass_extensions/aphrodite_numeric_conversion.cuh b/kernels/cutlass_extensions/aphrodite_numeric_conversion.cuh new file mode 100644 index 000000000..b837fe238 --- /dev/null +++ b/kernels/cutlass_extensions/aphrodite_numeric_conversion.cuh @@ -0,0 +1,797 @@ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/aphrodite_custom_types.cuh" +#include "cutlass_extensions/cute_utils.cuh" + +// this file extends: +// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h +// with aphrodite specific type conversions, namely: aphrodite_uint4b8_t, +// aphrodite_uint8b128_t as well as adds interleaved numeric array converters +// for specific types. 
(interleaved numeric array converters can be more +// efficient for subbyte types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. +template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + CUTE_INVALID_CONTROL_PATH( + "InterleavedNumericArrayConverter not implemented\n"); + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// TODO (LucasWilkinson): Implement +// for Array <= Array + +// .... + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + if constexpr (sizeof(PackedSrc) == 1) { + return static_cast(reinterpret_cast(source)); + } else if constexpr (sizeof(PackedSrc) == 2) { + return static_cast(reinterpret_cast(source)); + } else { + static_assert(sizeof(PackedSrc) == 4); + return reinterpret_cast(source); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. 
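+  // e.g. for the uint4b8 -> fp16 converter below: a stored nibble n encodes
+  // x = n - 8. OR-ing n into the mantissa of 0x6400 (fp16 1024.0, where one
+  // mantissa ulp is exactly 1) yields the fp16 value 1024 + n = 1032 + x, so
+  // subtracting 1032 recovers x; the high nibble lands at 1024 + 16*n and is
+  // rescaled by 1/16 while 72 is subtracted.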
+ template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_reg(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. 
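+      // The lop3 immediate is the truth table of the desired boolean function
+      // evaluated on the canonical operands a=0xF0, b=0xCC, c=0xAA, so
+      // (0xf0 & 0xcc) ^ 0xaa == 0x6A selects f(a, b, c) = (a & b) ^ c.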
+ static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, aphrodite_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return 
ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. 
the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. 
We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, + aphrodite_uint4b8_t, N, Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + 
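+      // 0x4300 is bf16 128.0, where one mantissa ulp is exactly 1, so OR-ing
+      // a nibble n into the low mantissa bits yields 128 + n; the stored
+      // nibble is x + 8, hence the 136 (0x4308) bias subtracted further down.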
static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/kernels/cutlass_extensions/cute_utils.cuh b/kernels/cutlass_extensions/cute_utils.cuh new file mode 100644 index 000000000..c660b12d5 --- /dev/null +++ b/kernels/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the 
last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) + return true; + else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + +//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute \ No newline at end of file diff --git a/kernels/cutlass_extensions/torch_utils.hpp b/kernels/cutlass_extensions/torch_utils.hpp new file mode 100644 index 000000000..2c7857252 --- /dev/null +++ b/kernels/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,160 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. 
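+// For example, a contiguous 2-D (M, K) tensor paired with a Stride whose last
+// element is a static Int<1> yields the cute layout (M, K):(K, 1).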
+template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + c10::optional const& tensor, + std::string_view name = "tensor") { + using Layout = decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/kernels/ops.h b/kernels/ops.h index 86c7915a1..f9e758d2f 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -103,6 +103,8 @@ at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& final_states_out_, bool silu_activation); +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); + // Sampling kernels torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, diff --git a/kernels/permute_cols.cu b/kernels/permute_cols.cu new file mode 100644 index 000000000..f51fa7329 --- /dev/null +++ b/kernels/permute_cols.cu @@ -0,0 +1,88 @@ +#include + +#include +#include + +#include + +static constexpr int default_threads = 256; +static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
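+// i.e. out[m, k] = a[m, perm[k]] for every row m.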
+// Currently only supports 16bit types (since we permute half types) +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = std::max(finish_row - start_row, 0); + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +// More efficient version of A[..., perm] +// taken from gptq_marlin.cu +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto dev = A.get_device(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, + "Currently only 16bit types are supported"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(A.size(-1) % 8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = A.view({-1, A.size(-1)}); + + torch::Tensor D = torch::empty_like(A); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int block_rows = div_ceil(A_2d.size(0), sms); + permute_cols_kernel<<>>( + reinterpret_cast(A_2d.const_data_ptr()), + perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), + A_2d.size(0), A_2d.size(1), block_rows); + return D; +} \ No newline at end of file diff --git a/kernels/quantization/machete/generate.py b/kernels/quantization/machete/generate.py new file mode 100644 index 000000000..ae5c83c72 --- /dev/null +++ b/kernels/quantization/machete/generate.py @@ -0,0 +1,530 @@ +import itertools +import math +import os +import shutil +from collections.abc import Iterable +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import jinja2 +# yapf conflicts with isort for this block +# yapf: disable +from aphrodite_cutlass_library_extension import (APHRODITEDataType, + APHRODITEDataTypeNames, + APHRODITEDataTypeTag, + APHRODITEKernelScheduleTag, + DataType, EpilogueScheduleTag, + EpilogueScheduleType, + MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType) + +# yapf: enable + +# +# Generator templating +# + +DISPATCH_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { +using GemmDispatcher_ = GemmDispatcher< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + 
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints + +{% for s in schedules %}extern torch::Tensor +impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); +{% endfor %} +template <> +torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { + [[maybe_unused]] auto M = args.A.size(0); + [[maybe_unused]] auto N = args.B.size(1); + [[maybe_unused]] auto K = args.A.size(1); + + if (!args.schedule) { + {%- for cond, s in heuristic %} + {%if cond is not none%}if ({{cond}}) + {%- else %}else + {%- endif %} + return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %} + } + + {% for s in schedules %} + if (*args.schedule == "{{ gen_sch_name(s) }}") { + return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args); + } + {% endfor %} + TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " + "schedule = ", *args.schedule); +} + +template <> +std::vector GemmDispatcher_::supported_schedules() { + return { + {% for s in schedules -%} + "{{ gen_sch_name(s) }}"{{ ", + " if not loop.last }}{%- endfor %} + }; +} + +}; // namespace machete +""" + +IMPL_TEMPLATE = """ +#include "../machete_mm_launcher.cuh" + +namespace machete { +template +using Kernel = MacheteKernelTemplate< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Config, with_C, with_scales, with_zeropoints>; + +{% for sch in schedules %} +{% set schedule_name = gen_sch_name(sch) -%} +struct sch_{{schedule_name}} { + using TileShapeNM = Shape<{{ + to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; + using ClusterShape = Shape<{{ + to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>; + // TODO: Reimplement + // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; + using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}}; + using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +}; + +torch::Tensor +impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { + bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), + with_zeropoints = args.zeros.has_value(); + + {% for s in specializations %} + if (with_C == {{s.with_C|lower}} + && with_zeropoints == {{s.with_zeropoints|lower}} + && with_scales == {{s.with_scales|lower}}) { + return run_impl>(args); + }{% endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "for the sake of compile times and binary size machete_mm(..) 
is " + " not implemented for with_C=", with_C, ", with_scales=", with_scales, + ", with_zeropoints=", with_zeropoints, + " (for {{type_name}}_sch_{{schedule_name}})"); +} +{% endfor %} + +}; // namespace machete +""" + +PREPACK_TEMPLATE = """ +#include "../machete_prepack_launcher.cuh" + +namespace machete { +using PrepackBDispatcher_ = PrepackBDispatcher< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + {{DataTypeTag[type_config.element_b_scale]}}, // Scales + {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints + +using PrepackedLayoutB = PrepackedLayoutBTemplate< + {{DataTypeTag[type_config.element_a]}}, // ElementA + {{DataTypeTag[type_config.element_b]}}, // ElementB + {{DataTypeTag[type_config.element_d]}}, // ElementD + {{DataTypeTag[type_config.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>; + +template <> +torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) { + return prepack_impl(B); +} +}; // namespace machete +""" + +TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput +TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative + + +@dataclass(frozen=True) +class ScheduleConfig: + tile_shape_mn: Tuple[int, int] + cluster_shape_mnk: Tuple[int, int, int] + kernel_schedule: MixedInputKernelScheduleType + epilogue_schedule: EpilogueScheduleType + tile_scheduler: TileSchedulerType + + +@dataclass +class TypeConfig: + element_a: DataType + element_b: Union[DataType, APHRODITEDataType] + element_b_scale: DataType + element_b_zeropoint: DataType + element_d: DataType + accumulator: DataType + + +@dataclass +class Specialization: + with_C: bool + with_zeropoints: bool + with_scales: bool + + +@dataclass +class ImplConfig: + type_config: TypeConfig + schedule_configs: List[ScheduleConfig] + specializations: List[Specialization] + heuristic: List[Tuple[Optional[str], ScheduleConfig]] + + +def generate_schedule_name(schedule_config: ScheduleConfig) -> str: + tile_shape = ( + f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" + ) + cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}") + kernel_schedule = APHRODITEKernelScheduleTag[ + schedule_config.kernel_schedule]\ + .split("::")[-1] + epilogue_schedule = EpilogueScheduleTag[ + schedule_config.epilogue_schedule].split("::")[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ + .split("::")[-1] + + return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + f"_{epilogue_schedule}_{tile_scheduler}") + + +# mostly unique shorter schedule_name +def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str: + kernel_terse_names_replace = { + "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", + "TmaWarpSpecializedCooperative_": "TmaCoop_", + "StreamKScheduler": "streamK", + } + + schedule_name = generate_schedule_name(schedule_config) + for orig, terse in kernel_terse_names_replace.items(): + schedule_name = schedule_name.replace(orig, terse) + return schedule_name + + +# unique type_name +def generate_type_signature(kernel_type_config: TypeConfig): + element_a = APHRODITEDataTypeNames[kernel_type_config.element_a] + element_b = APHRODITEDataTypeNames[kernel_type_config.element_b] + 
element_d = APHRODITEDataTypeNames[kernel_type_config.element_d] + accumulator = APHRODITEDataTypeNames[kernel_type_config.accumulator] + element_scale = APHRODITEDataTypeNames[kernel_type_config.element_b_scale] + element_zeropoint = APHRODITEDataTypeNames[ + kernel_type_config.element_b_zeropoint] + + return (f"{element_a}{element_b}{element_d}" + f"{accumulator}{element_scale}{element_zeropoint}") + + +# non-unique shorter type_name +def generate_terse_type_signature(kernel_type_config: TypeConfig): + element_a = APHRODITEDataTypeNames[kernel_type_config.element_a] + element_b = APHRODITEDataTypeNames[kernel_type_config.element_b] + + return f"{element_a}{element_b}" + + +def is_power_of_two(n): + return (n != 0) and (n & (n - 1) == 0) + + +def to_cute_constant(value: List[int]): + + def _to_cute_constant(value: int): + if is_power_of_two(value): + return f"_{value}" + else: + return f"Int<{value}>" + + if isinstance(value, Iterable): + return [_to_cute_constant(value) for value in value] + else: + return _to_cute_constant(value) + + +template_globals = { + "DataTypeTag": APHRODITEDataTypeTag, + "KernelScheduleTag": APHRODITEKernelScheduleTag, + "EpilogueScheduleTag": EpilogueScheduleTag, + "TileSchedulerTag": TileSchedulerTag, + "to_cute_constant": to_cute_constant, + "gen_sch_name": generate_terse_schedule_name, +} + + +def create_template(template_str): + template = jinja2.Template(template_str) + template.globals.update(template_globals) + return template + + +mm_dispatch_template = create_template(DISPATCH_TEMPLATE) +mm_impl_template = create_template(IMPL_TEMPLATE) +prepack_dispatch_template = create_template(PREPACK_TEMPLATE) + + +def create_sources(impl_config: ImplConfig, num_impl_files=2): + sources = [] + + type_name = generate_type_signature(impl_config.type_config) + terse_type_name = generate_terse_type_signature(impl_config.type_config) + + sources.append(( + f"machete_mm_{terse_type_name}", + mm_dispatch_template.render(type_name=type_name, + type_config=impl_config.type_config, + schedules=impl_config.schedule_configs, + heuristic=impl_config.heuristic), + )) + + sources.append(( + f"machete_prepack_{terse_type_name}", + prepack_dispatch_template.render( + type_name=type_name, + type_config=impl_config.type_config, + ), + )) + + num_schedules = len(impl_config.schedule_configs) + schedules_per_file = math.ceil(num_schedules / num_impl_files) + for part, i in enumerate(range(0, num_schedules, schedules_per_file)): + file_schedules = impl_config.schedule_configs[i:i + schedules_per_file] + + sources.append(( + f"machete_mm_{terse_type_name}_impl_part{part}", + mm_impl_template.render( + type_name=type_name, + type_config=impl_config.type_config, + schedules=file_schedules, + specializations=impl_config.specializations, + ), + )) + return sources + + +def generate(): + # See csrc/quantization/machete/Readme.md, the Codegeneration for more info + # about how this works + SCRIPT_DIR = os.path.dirname(__file__) + + schedule_common_params = dict( + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + ) + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + default_heuristic = [ + #### M = 257+ + ( + "M > 256 && K <= 16384 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 256", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: 
ignore + )), + #### M = 129-256 + ( + "M > 128 && K <= 4096 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128 && K <= 8192 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 65-128 + ( + "M > 64 && K <= 4069 && N <= 4069", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K <= 4069 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K >= 8192 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 33-64 + ( + "M > 32 && K <= 6144 && N <= 6144", + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32 && K >= 16384 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 17-32 + ( + "M > 16 && K <= 12288 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 16", + ScheduleConfig( + tile_shape_mn=(256, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 1-16 + ( + "N >= 26624", + ScheduleConfig( + tile_shape_mn=(256, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + None, + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ] + + schedules = list(set([x[1] for x in default_heuristic])) + + impl_configs = [] + + GPTQ_kernel_type_configs = list( + (TypeConfig( + element_a=element_a, + element_b=element_b, + element_b_scale=element_a, + element_b_zeropoint=element_a, + element_d=element_a, + accumulator=DataType.f32, + ) for element_b in (APHRODITEDataType.u4b8, APHRODITEDataType.u8b128) + for element_a in (DataType.f16, DataType.bf16))) + + GPTQ_kernel_specializations = [ + Specialization(with_C=False, with_zeropoints=False, with_scales=True) + ] + + impl_configs += [ + ImplConfig(x[0], x[1], x[2], x[3]) + for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules), + itertools.repeat(GPTQ_kernel_specializations), + itertools.repeat(default_heuristic)) + ] + + AWQ_kernel_type_configs = list( + (TypeConfig( + element_a=element_a, + element_b=element_b, + element_b_scale=element_a, + element_b_zeropoint=element_a, + element_d=element_a, + accumulator=DataType.f32, + ) for element_b in (DataType.u4, DataType.u8) + for element_a in (DataType.f16, DataType.bf16))) + + AWQ_kernel_specializations = [ + Specialization(with_C=False, with_zeropoints=True, with_scales=True) + ] + + impl_configs += [ + ImplConfig(x[0], x[1], x[2], x[3]) + for x in zip(AWQ_kernel_type_configs, 
itertools.repeat(schedules), + itertools.repeat(AWQ_kernel_specializations), + itertools.repeat(default_heuristic)) + ] + + output_dir = os.path.join(SCRIPT_DIR, "generated") + + # Delete the "generated" directory if it exists + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + # Create the "generated" directory + os.makedirs(output_dir) + + # Render each group of configurations into separate files + for impl_config in impl_configs: + for filename, code in create_sources(impl_config): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") + + +if __name__ == "__main__": + generate() diff --git a/kernels/quantization/machete/machete_collective_builder.cuh b/kernels/quantization/machete/machete_collective_builder.cuh new file mode 100644 index 000000000..acb94ce6e --- /dev/null +++ b/kernels/quantization/machete/machete_collective_builder.cuh @@ -0,0 +1,33 @@ +#pragma once + +#include "cutlass_extensions/aphrodite_collective_builder.cuh" +#include "machete_mainloop.cuh" + +namespace cutlass::gemm::collective { +using namespace cute; + +struct MacheteKernelTag {}; + +template +struct APHRODITECollectiveBuilder< + MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_, + GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB, + ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, + KernelScheduleType, + cute::enable_if_t<( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v)>> { + using CollectiveOp = machete::MacheteCollectiveMma< + ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, + AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, + StageCountType, KernelScheduleType>; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/kernels/quantization/machete/machete_interleaving_utils.cuh b/kernels/quantization/machete/machete_interleaving_utils.cuh new file mode 100644 index 000000000..15713dfb9 --- /dev/null +++ b/kernels/quantization/machete/machete_interleaving_utils.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace machete { + +using namespace cute; + +// get an interleaved block layout where each element consecutive element has a +// stride of bit_stride and the block width is blk_bit_width, +// examples: +// size_bits = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1 +// size_bits = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1) +// size_bits = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1) +// size_bits = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1) +template +CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() { + static_assert(blk_bit_width % bit_stride == 0); + static_assert(bit_stride % cute::sizeof_bits_v == 0); + + constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v; + + if constexpr (cute::sizeof_bits_v == bit_stride) { + // identity layout + return Layout>>{}; + } else { + constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v; + constexpr auto num_strides = elems_per_blk / elems_per_stride; + return Layout, Int>, + Stride, Int<1>>>{}; + } +} + +}; // namespace machete \ No newline at end of file diff --git a/kernels/quantization/machete/machete_mainloop.cuh b/kernels/quantization/machete/machete_mainloop.cuh new file mode 100644 index 000000000..1e0ed9539 --- /dev/null +++ 
b/kernels/quantization/machete/machete_mainloop.cuh @@ -0,0 +1,1473 @@ +// +// Based off of: +// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Specifically: +// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Referred to as upstream from in the comments +// +// The main optimization machete implements compared to upstream is to prepack +// the weight matrix to more closely match the shape of the wgmma instructions +// allowing for wider (ideally 128bit) shared memory loads. For subbyte types +// this is done by packing values from multiple wgmma loads (for a single +// thread) into a single 128bit load. This is very similar to layout used in +// Marlin, although specific to the wgmma instructions. +// +// Since the wgmma instructions only support sourcing from registers for the A +// operand, and we want to upconvert/decompress the weight values/elements +// before feeding them into the tensor cores in registers, we need the weight +// matrix to be A. To achieve this we compute the transpose of Y = XW^t as +// Y^t = W^tX^t. This is mostly done outside of this file in +// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the +// quantized/narrow type and has the prepacked layout despite the API being: +// B_prepacked = machete_prepack_B(B) +// Y = machete_mm(A, B_prepacked) +// +#pragma once + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +#include "cutlass/detail/collective.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" + +namespace machete { + +using namespace cute; +using namespace cutlass; +using namespace cutlass::gemm; +using namespace cutlass::gemm::collective; +using namespace cutlass::gemm::collective::detail; + +template +struct MacheteCollectiveMma { + using Schedule = KernelScheduleType; + static_assert( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); + + public: + static constexpr bool ALayoutIsPrepacked = true; + + // Prepacked block shape (N is M in the transposed problem) + using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK; + // Prepacked blocks per dim for a single MMA tile + using PPBlocksPerTile_MK = decltype(make_shape( + size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}), + size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{}))); + + using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout; + + static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0, + "M in PPBlockShape_MK must evenly divide M TileShape_MNK"); + static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0, + "K in PPBlockShape_MK must evenly divide K TileShape_MNK"); + 
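+  // Illustrative sizing example (hypothetical values, not taken from any
+  // particular prepacked layout): with TileShape_MNK = (128, _, 64) and
+  // PPBlockShape_MK = (64, 32), PPBlocksPerTile_MK works out to
+  // (128 / 64, 64 / 32) = (2, 2), i.e. each MMA tile spans a 2x2 grid of
+  // prepacked blocks along M and K.
+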
+ using ArchTag = arch::Sm90; + using TileShape = TileShape_MNK; + using ClusterShape = ClusterShape_MNK; + using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>; + using StrideA = TagToStrideA_t; + using ElementB = ElementB_; + using StrideB = TagToStrideB_t; + using ElementAccumulator = ElementAccumulator_; + using ElementMma = ElementB; + using ElementATuple = + cute::conditional_t::value, + cute::tuple, ElementATuple_>; + + static constexpr cute::GMMA::Major GmmaMajorA = + gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + private: + // + // the setup section (until "section setup end") contains a combination of + // modified code from (used as a starting point): + // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl` + // `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp` + // (upstream) + // + // however in-order to simplify the code we combine a lot of the logic from + // `CollectiveMma` and `CollectiveBuilder` into this class, this also makes + // sense given that we have flexibility on layouts here. We also simplify the + // code by only supporting scales and zeros for A (in the transposed problem, + // B from an API perspective), also since we force A to be the narrow type + // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in + // the upstream also simplifying the code. This section includes new logic + // (compared ustream) for handling the prepacked-A layouts (in the transposed + // problem, B from an API perspective) + // + using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>; + using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>; + + static constexpr bool IsANarrow = cutlass::sizeof_bits::value < + cutlass::sizeof_bits::value; + static_assert(IsANarrow, + "A must be the narrow one since its the one that flows through " + "registers."); + + public: + static constexpr int PipelineStages = + compute_stage_count_or_override_single_affine_transformed_input< + sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale, + ElementZero, TileShape_MNK>(StageCountType{}); + + struct DispatchPolicy { + constexpr static int Stages = PipelineStages; + using ClusterShape = ClusterShape_MNK; + using Schedule = KernelScheduleType; + }; + + using GmemTiledCopyA = + decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = + decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + // ((T, V), (BlocksM, BlocksK), pipe) -> offset + using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + + using SmemLayoutAtomARowMajor = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomScale = Layout< + Shape(SmemLayoutAtomARowMajor{})), cute::Int<1>>>; + + using SmemLayoutAtomB = + decltype(rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomB = void; + + // + // Validity checks + // + static_assert(is_static::value); + 
static_assert(is_static::value); + static_assert(is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + private: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + public: + // + // Type Aliases + // + using KernelSchedule = KernelScheduleType; + + // For cases where we can't have a void type, we can use this to allow the + // code to compile when the scale / zero is void. + using NonVoidElementScale = + cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = + cute::conditional_t, float, ElementZero>; + + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the + // code to compile when the scale is void. + using NonVoidStrideScale = + cute::conditional_t, + cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((cutlass::gemm::detail::is_k_major()), + "The transformed matrix (A) must be K-major."); + + static_assert((sizeof(ElementB) == 2) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element (matrix B) must be 2 bytes OR both " + "inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major " + "if B is scaled]."); + + static_assert(std::is_same_v, + "TiledMma::ValTypeC must be the same as ElementAccumulator."); + + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemCopyAtomScale = Copy_Atom; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any + // rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = + cute::conditional_t>>; + using InternalElementB = + cute::conditional_t>>; + + using TransformA = cute::identity; + using TransformB = cute::identity; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = + cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), + shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, + "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
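+  // Each of the smem layouts below gets the pipeline-stage count appended as
+  // its last mode, so the backing tensors are shaped (BLK_*, BLK_K, PIPE) and
+  // every pipeline stage owns its own K-tile of A, B and (optionally) scales.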
+ using SmemLayoutACopy = decltype(tile_to_shape( + SmemLayoutAtomARowMajor{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), + Int{}))); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major + // only (e.g. tf32, fp32, fp8, int8). + static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, + layout::ColumnMajor> && + cute::is_same_v, + layout::RowMajor>; + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc " + "for this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + using GmmaSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + + // These two restrictions are related, so we place the assertions together. + // To relax them, we need to handle loading more than 1 row of scales for + // every main loop iteration. We must also handle updating the pipeline + // transaction bytes on the fly. NOTE: Deleting this assertion without + // required changes will cause the code to hang. 
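+  // (Put differently: the static TMA transaction-byte computation below
+  // assumes exactly one K column of scales/zeros is staged per pipeline
+  // stage.)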
+ static_assert(size<1>(SmemLayoutAtomScale{}) == 1, + "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in scale smem allocation."); + } + } + + // Same as upstream, should be kept the same when possible, not formatte for + // easier comparison + // clang-format off + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + // clang-format on + + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( + 
make_shape(int32_t(0), int32_t(0), int32_t(0))))); + + using ATensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + shape(GmemLayoutA::TVbNbKL_to_offset( + make_shape(int32_t(0), int32_t(0), int32_t(0)))), + PrepackedStrideA{})); + + using BTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(StrideB{}, int32_t(0)), StrideB{})); + using ScaleTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + using ZeroTensor = decltype(make_tensor( + get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{})); + + static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { + return make_tma_copy( + GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + shape(SmemLayoutA{}(_, _, cute::Int<0>{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_scale( + ScaleTensor tensor_scale = ScaleTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_scale, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_zero( + ZeroTensor tensor_zero = ZeroTensor{}) { + return make_tma_copy(GmemTiledCopyScale{}, tensor_zero, + SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + } + + static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) { + return make_tma_copy( + GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } + + public: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic + // clang-format off + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + // clang-format on + + // + // section setup end + // + + // Similar (but not idendtical) to upstream, should be kept 
the same when + // possible + // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to + // define the TMA types + // Device side kernel params + struct Params { + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A()); + using TMA_Scale = decltype(make_tma_copy_scale()); + using TMA_Zero = decltype(make_tma_copy_zero()); + using TMA_B = decltype(make_tma_copy_B()); + + // required by outer loop: i.e. + // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here + // to handle the prepacked layout + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) { + return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride)); + }; + + typename Params::TMA_A tma_load_a; + typename Params::TMA_B tma_load_b; + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + + auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + tma_load_a = make_tma_copy_A( + make_logical_tensor(ptr_A, shape(layout), stride(layout))); + + tma_load_b = make_tma_copy_B( + make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB)); + + if constexpr (ModeHasScales) { + tma_load_scale = make_tma_copy_scale(make_logical_tensor( + args.ptr_S, make_shape(M, args.group_size, L), args.dS)); + } + + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + tma_load_zero = make_tma_copy_zero(make_logical_tensor( + args.ptr_Z, make_shape(M, args.group_size, L), args.dS)); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0}; + } else if constexpr (ModeHasScales) { + auto scale_k = (K + args.group_size - 1) / args.group_size; + + return {tma_load_a, tma_load_b, tma_load_scale, + tma_load_zero, scale_k, args.group_size}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `SwapAB ? 
N : M -> M` since we dont support SwapAB + // clang-format off + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = M; + const int scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + 
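+      // Scales and zero-points each have their own TMA descriptor in this
+      // mode, so prefetch both.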
cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + + } + // clang-format off + + // Modified from upstream, should be kept close to that when possible + // the main difference is special handling for the prepacked A layout + // + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the + // contract Returned tuple must contain at least two elements, with the first + // two elements being: gA_mkl - The tma tensor, A after a local tile so it + // has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local + // tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be + // specified as needed by this collective. + // NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the + // values within a prepacked block. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params) const { + using X = Underscore; + auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL), + K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL); + + // (TILE_V,TILE_B,m,k,l) + auto make_gA_mkl = [&]() { + // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) + auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); + return local_tile(mA_mkl, + make_shape(size<0>(layout), PPBlocksPerTile_MK{}), + make_coord(0, make_coord(_, _))); + }; + + // (TILE_N,TILE_K,n,k,l) + auto make_gB_nkl = [&]() { + Tensor mB_nkl = + mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); + return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), + Step{}); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gS_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + // (TILE_M,TILE_Scale_K,m,scale_k,l) + auto make_gZ_mkl = [&]() { + auto scale_k = mainloop_params.scale_k; + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor( + make_shape(M, scale_k, L)); + return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScale) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl()); + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(), + make_gZ_mkl()); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in load_init."); + } + } + + // Similar to upstream, should be kept close to that when possible + // the main difference is in the layout comments + // clang-format off + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template < + class... 
Ts, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
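+      // blk_coord carries this CTA's (m, n, k, l) tile coordinates; the slices
+      // below reduce the full gmem tensors to per-block views iterated over k
+      // (A keeps its prepacked (TILE_V, TILE_B) modes, B keeps (TILE_N, TILE_K)).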
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. 
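+          // Illustrative example (hypothetical sizes): group_size = 128 with a
+          // threadblock tile K of 64 gives ReloadFactor = 2, so the same group
+          // of scales (and zeros) is re-issued for two consecutive k tiles
+          // before scale_load_k advances.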
+ copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + // clang-format on + + // Modified from upstream, should be kept close to that when possible + // the main differences are handling the prepacked A layout, and separating + // the loading of A from upcoverting A + // + // Perform a collective-scoped matrix multiply-accumulate + // Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for " + "RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + // ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset + auto constexpr smem_A = SmemLayoutA{}; + + // convert: + // ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset + // to: + // (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset + // which can be thought of as: + // (T, MMA, (MMA_M, MMA_K), pipe) -> offset + auto constexpr smem_A_mma_ = + make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A), + zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A)); + // flatten to: + // (T, MMA, MMA_M, MMA_K, pipe) -> offset + auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), + smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe) + Tensor sB = 
make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate fragments and descriptors + Tensor tCrA_load = make_tensor( + tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K) + Tensor tCrA_mma = make_fragment_like(tCrA_load); + + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + static constexpr int A_CPY_VEC = + decltype(max_common_vector(tCsA, tCrA_load)){}; + + static constexpr int COVERSION_WIDTH = + std::min(A_CPY_VEC, int(size<0>(tCrA_mma))); + + auto load_A_to_registers = [&](int read_stage) { + copy(create_auto_vectorizing_copy(), + tCsA(_, _, _, read_stage), tCrA_load(_, _, _)); + }; + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = + partition_extra_mma_info(thread_mma, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info( + tiled_mma, partitioned_extra_info, warp_group_thread_idx); + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + auto convert_A = [&, a_vec = Int{}](int k_block, + int read_stage) { + load_extra_info_to_registers(partitioned_extra_info, + copy_partitions_extra_info, k_block, + read_stage); + transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info, + k_block); + }; + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + load_A_to_registers(read_stage); + convert_A(0, read_stage); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, smem_pipe_read.index()); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to + // overwrite the A registers for the first mma. 
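+        // (i.e. block until at most K_BLOCK_MAX - 1 wgmma batches are still
+        // outstanding before A is reloaded for the next k tile.)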
+ warpgroup_wait(); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, + // so we can release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + load_A_to_registers(smem_pipe_read.index()); + convert_A(0, smem_pipe_read.index()); + } else { + convert_A(k_block + 1, read_stage); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_, _, k_block), + tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ + // on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 1) { + convert_A(k_block + 1, read_stage); + } + } + } + + warpgroup_fence_operand(accum); + } + + // Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it + ++smem_pipe_release; + } + } + + private: + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE + auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = 
make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + // clang-format off + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE + auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Returns the tiled copy and copy views for the extra inputs. 
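+  /// For the scaled modes this is the smem->RF tiled copy for scales (shared
+  /// with zero-points) plus register views retiled to match that copy atom;
+  /// for DirectConvert an empty tuple is returned.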
+ template + CUTLASS_DEVICE + auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // clang-format on + + // Similar to `copy_A_and_extra_info` upstream, should be kept the same when + // possible + // the main differences this only loads the extra info into registers and + // not A (since we now preload more of A in the main pipeline) + // Load scales and zeros into registers if required + template + CUTLASS_DEVICE void load_extra_info_to_registers( + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, + int read_stage) { + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), + tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), + tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } + } + + // Similar to upstream, should be kept the same when possible. + // the main differences are that `convert_tensor` supports interleaved + // layouts and bfloat16 has been optimized. `transform_internal_A` has also + // been inlined for code simplicity. + // Utilities to transform A. 
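+  // In pseudo-notation, the per-k-block transform below computes
+  //   A_mma = convert_to_mma_type(convert_to_scale_type(A_load) * scale [+ zero])
+  // where the scale/zero terms only apply in the ConvertAndScale /
+  // ConvertAndScaleWithZero modes.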
+ template + CUTLASS_DEVICE void transform_A_kblock( + TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, + int const k_block) { + auto in = tCrA_load(_, _, k_block); + auto out = tCrA_mma(_, _, k_block); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + convert_tensor(in, out, vec_A); + } else if constexpr (ModeHasScales) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto converted_inputs = + make_fragment_like(tCrA_mma)(_, _, k_block); + auto scales = tCrS(_, _, 0); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, vec_A); + // Apply scales and broadcast across inputs, store in converted_inputs + + // We need to cast to nv_bfloat16 for the multiply since + // `cutlass::bfloat16_t` has an overloaded operator* that upconverts to + // float, which nvcc will not optimize to using vectorized fma + // instructions (i.e. hfma.bf16_v2) + if constexpr (std::is_same_v) { + cute::transform( + recast(converted_inputs), recast(scales), + recast(converted_inputs), cute::multiplies{}); + } else { + cute::transform(converted_inputs, scales, converted_inputs, + cute::multiplies{}); + } + + // Apply zeros if required + if constexpr (KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto converted_zeros = make_fragment_like(tCrZ)(_, _, 0); + + convert_tensor(tCrZ(_, _, 0), converted_zeros); + if constexpr (std::is_same_v) { + cute::transform(recast(converted_inputs), + recast(converted_zeros), + recast(converted_inputs), cute::plus{}); + } else { + cute::transform(converted_inputs, converted_zeros, converted_inputs, + cute::plus{}); + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } else { + static_assert(cutlass::detail::dependent_false, + "No A data is loaded."); + } + } + + // Modified from upstream, should be kept the same when possible + // the main differences is that this version supports interleaved converts + // Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor( + Tensor const& in, + Tensor& out, + cute::Int width = {}) { + // This is an element-wise conversion where we expect both tensors to have + // the same layout. As a result, we can cast as a cutlass array to use the + // fast numeric converters without worrying about indexing into the layout. + constexpr int N = cosize_v; + + // The inputs must be backed by registers & be statically sized. 
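+  // The conversion is vectorized: ConversionVectorWidth elements are handled
+  // per converter call, so the layout cosize must be a multiple of it
+  // (hypothetical example: N = 32 with ConversionVectorWidth = 8 -> 4 calls).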
+ static_assert(is_rmem::value, + "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, + "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, + "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), + "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, + "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + + using Converter = cutlass::InterleavedNumericArrayConverter< + IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = + reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = + reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } +}; + +} // namespace machete \ No newline at end of file diff --git a/kernels/quantization/machete/machete_mm_kernel.cuh b/kernels/quantization/machete/machete_mm_kernel.cuh new file mode 100644 index 000000000..7238bdcaa --- /dev/null +++ b/kernels/quantization/machete/machete_mm_kernel.cuh @@ -0,0 +1,238 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/aphrodite_numeric_conversion.cuh" +#include "machete_collective_builder.cuh" +#include "machete_prepacked_layout.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +// NOTE This kernel computes D = alpha * A * B + beta * C by computing +// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma +// instructions only support sourcing from registers for the left-hand +// operand, we want to upconvert/decompress the quantized operand in +// register. Since the primary use case we want to support is Y = XW^t where +// W is quantized, in this situation or right-hand operand is quantized so +// we compute the transpose to move it to the left-hand side. 
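(Editor's note, not part of the patch: the transpose trick described in the NOTE above can be sanity-checked outside the kernel. The snippet below is a small PyTorch illustration of the identity it relies on, using plain float tensors and no quantization; it is a sketch for intuition, not part of the kernel.)

import torch

# D = alpha * A @ B + beta * C computed via the transposed problem
# Dt = alpha * Bt @ At + beta * Ct, which puts B on the left-hand side,
# i.e. the operand wgmma can source from registers.
alpha, beta = 1.0, 1.0
A = torch.randn(16, 64)
B = torch.randn(64, 32)
C = torch.randn(16, 32)

D_direct = alpha * (A @ B) + beta * C
D_via_transpose = (alpha * (B.t() @ A.t()) + beta * C.t()).t()

torch.testing.assert_close(D_direct, D_via_transpose)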
+template +struct MacheteKernelTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementC = cute::conditional_t; + using ElementZ = ZeroT; + using ElementS = ScaleT; + + using ElementAccumulator = + AccumulatorT; // Element type for internal accumulation + using ElementCompute = AccumulatorT; // For Epilogue + + using BTypeTuple = cute::conditional_t< + with_scales, + cute::conditional_t, + cute::tuple>, + ElementB>; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using LayoutScale = cutlass::layout::RowMajor; + // not actually used since B has the prepacked layout, but required by cutlass + using _LayoutB = cutlass::layout::ColumnMajor; + + // Interface strides expected by create_arguments (will get transposed) + using StrideA = cutlass::detail::TagToStrideA_t; + using StrideC = cutlass::detail::TagToStrideA_t; + using StrideD = cutlass::detail::TagToStrideA_t; + using StrideS = cutlass::detail::TagToStrideA_t; + using StrideZ = StrideS; + + using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutC_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; + + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using PrepackedLayoutB = + PrepackedLayoutBTemplate; + + static int constexpr TileShapeK = + 128 * 8 / cutlass::sizeof_bits::value; + static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentC = + (with_C) ? 128 / cutlass::sizeof_bits_v : 0; + static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v; + + using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{}, + cute::Int{})); + using ClusterShape = typename ScheduleConfig::ClusterShape; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using TileScheduler = typename ScheduleConfig::TileScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::APHRODITECollectiveBuilder< + cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass, + BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // stride_B is unused (since B is prepacked), but still required by cutlass + using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>; + + using Arguments = typename Gemm::Arguments; + using MainloopArguments = typename GemmKernel::MainloopArguments; + using EpilogueArguments = typename GemmKernel::EpilogueArguments; + + template + static 
Arguments create_arguments( + cudaStream_t stream, + ElementA const* A_ptr, // A is an MxK matrix + Layout const& layout_A, + ElementB const* B_ptr, // B is an KxN prepacked matrix + ElementD* D_ptr, // D is an MxN matrix + Layout const& layout_D, + ElementC const* C_ptr, // C is an MxN matrix + std::optional> const& layout_C, + ElementS const* S_ptr, // S is an scale_KxN matrix + std::optional> const& layout_S, + ElementZ const* Z_ptr, // Z is an scale_KxN matrix + std::optional> const& layout_Z, + ElementCompute alpha, ElementCompute beta, + std::optional maybe_group_size) { + static_assert(!with_zeropoints || with_scales); + + int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); + + int const group_size = + maybe_group_size == -1 ? K : maybe_group_size.value_or(K); + int const scale_k = (K + group_size - 1) / group_size; + + TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); + TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); + + if constexpr (with_C) { + TORCH_CHECK(C_ptr && layout_C); + } else { + TORCH_CHECK(!C_ptr, "C not supported"); + } + + if constexpr (with_scales) { + TORCH_CHECK(S_ptr && layout_S); + TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); + } else { + TORCH_CHECK(!S_ptr, "Scales not supported"); + } + + if constexpr (with_zeropoints) { + TORCH_CHECK(Z_ptr && layout_Z); + TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); + TORCH_CHECK(layout_S && *layout_Z == *layout_S, + "Scales and zeros must have the same layout"); + } else { + TORCH_CHECK(!Z_ptr, "Zeropoints not supported"); + } + + // Transpose A and D + // A doesn't need to be transposed since cutlass expects a NxK matrix + // for B (which is At) + auto stride_At = layout_A.stride(); + auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); + auto stride_Ct = stride_Dt; + if (layout_C) { + stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride(); + } + + MainloopArguments mainloop_arguments{}; + EpilogueArguments epilogue_arguments{ + {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; + + if constexpr (with_scales && with_zeropoints) { + auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_ptr, stride_S, group_size, Z_ptr}; + } else if constexpr (with_scales) { + auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + mainloop_arguments = MainloopArguments{ + B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; + } else { + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; + } + + return Arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, 1}, + mainloop_arguments, + epilogue_arguments}; + }; + + static size_t get_workspace_size(Arguments const& args) { + return Gemm::get_workspace_size(args); + } + + static bool can_implement(Arguments const& args) { + return Gemm::can_implement(args) == cutlass::Status::kSuccess; + } + + static void run(Arguments const& args, void* workspace, cudaStream_t stream) { + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(args, workspace, stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Machete kernel failed to initialize workspace"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed"); + } +}; + +}; // namespace machete \ No newline at end of file diff --git a/kernels/quantization/machete/machete_mm_launcher.cuh 
b/kernels/quantization/machete/machete_mm_launcher.cuh new file mode 100644 index 000000000..60a4ed605 --- /dev/null +++ b/kernels/quantization/machete/machete_mm_launcher.cuh @@ -0,0 +1,95 @@ +#pragma once + +#include +#include + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +struct PyTorchArguments { + torch::Tensor const& A; + torch::Tensor const& B; + c10::optional const& scales; + c10::optional const& zeros; + c10::optional group_size; + c10::optional const& C; + c10::optional alpha; + c10::optional beta; + c10::optional schedule; +}; + +template +torch::Tensor run_impl(PyTorchArguments args) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); + + auto device = args.A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + using EleA = typename MacheteKernel::ElementA; + using EleB = typename MacheteKernel::ElementB; + using EleC = typename MacheteKernel::ElementC; + using EleD = typename MacheteKernel::ElementD; + using EleScale = typename MacheteKernel::ElementS; + using EleZero = typename MacheteKernel::ElementZ; + + using StrideA = typename MacheteKernel::StrideA; + using StrideC = typename MacheteKernel::StrideC; + using StrideD = typename MacheteKernel::StrideD; + using StrideS = typename MacheteKernel::StrideS; + using StrideZ = typename MacheteKernel::StrideZ; + + int M = args.A.size(0); + int N = args.B.size(1); + int K = args.A.size(1); + + // Allocate output + torch::Tensor D = + torch::empty({M, N}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + + auto const &A = args.A, &B = args.B; + auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_C = maybe_make_cute_layout(C, "C"); + auto layout_S = maybe_make_cute_layout(scales, "scales"); + auto layout_Z = maybe_make_cute_layout(zeros, "zeros"); + + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto C_ptr = static_cast(C ? C->const_data_ptr() : nullptr); + auto S_ptr = + static_cast(scales ? scales->const_data_ptr() : nullptr); + auto Z_ptr = + static_cast(zeros ? 
zeros->const_data_ptr() : nullptr); + + auto arguments = MacheteKernel::create_arguments( + stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, + layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), + args.group_size); + TORCH_CHECK(MacheteKernel::can_implement(arguments), + "Machete kernel cannot be run with these arguments"); + + size_t workspace_size = MacheteKernel::get_workspace_size(arguments); + torch::Tensor workspace = torch::empty( + workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); + + MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream); + + return D; +}; + +template +struct GemmDispatcher { + static torch::Tensor dispatch(PyTorchArguments args); + static std::vector supported_schedules(); +}; + +}; // namespace machete \ No newline at end of file diff --git a/kernels/quantization/machete/machete_prepack_kernel.cuh b/kernels/quantization/machete/machete_prepack_kernel.cuh new file mode 100644 index 000000000..8e0210458 --- /dev/null +++ b/kernels/quantization/machete/machete_prepack_kernel.cuh @@ -0,0 +1,62 @@ +#pragma once + +#include "machete_mm_kernel.cuh" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + +template +static __global__ void prepack_B_kernel(BInTensor B_in, + BTiledOutTensor B_tiled_out) { + auto tB_in = local_tile(B_in, TileShapeNKL{}, + make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); + auto tB_out = B_tiled_out(make_coord(_, _), + make_coord(blockIdx.x, blockIdx.y), blockIdx.z); + + auto tiled_copy = make_tiled_copy(Copy_Atom{}, + Layout, Stride<_32, _1>>{}, + Layout>{}); + + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + + Tensor thr_tile_S = thr_copy.partition_S(tB_in); + Tensor thr_tile_D = thr_copy.partition_D(tB_out); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition + auto fragment = make_tensor(shape(thr_tile_D)); + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy(tiled_copy, thr_tile_S, fragment); + copy(Copy_Atom{}, fragment, thr_tile_D); +} + +template +static void prepack_B(cudaStream_t stream, + typename PrepackedLayoutB::ElementB const* B_in_ptr, + InLayout B_layout, + typename PrepackedLayoutB::ElementB* B_out_ptr) { + using TileShapeNKL = + decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{})); + auto ilvd_NKbNbKL_to_offset = + PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout)); + + TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); + TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0); + + auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); + auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); + + auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); + auto B_tiled_out = + make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset); + + prepack_B_kernel + <<>>(B_in, B_tiled_out); +} + +}; // namespace machete \ No newline at end of file diff --git a/kernels/quantization/machete/machete_prepack_launcher.cuh b/kernels/quantization/machete/machete_prepack_launcher.cuh new file mode 100644 index 000000000..df7831299 --- /dev/null +++ b/kernels/quantization/machete/machete_prepack_launcher.cuh @@ -0,0 +1,71 @@ +#pragma once + +#include "machete_prepack_kernel.cuh" +#include "cutlass_extensions/torch_utils.hpp" + +namespace machete { + 
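(Editor's note, not part of the patch: for context, the host-side path that produces the tensor consumed by `prepack_impl` below (quantize, pack rows into storage words, make the packed tensor column-major, then call the prepack op) mirrors the `machete_pack_weights` and `machete_quantize_and_pack` helpers added later in this patch. A minimal sketch follows; the wrapper name `prepack_for_machete` is illustrative, and uint4b8 with group size 128 is just one supported configuration.)

import torch

from aphrodite import _custom_ops as ops
from aphrodite.quantization.utils.quant_utils import (pack_rows,
                                                      quantize_weights)
from aphrodite.scalar_type import scalar_types

def prepack_for_machete(w: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    wtype = scalar_types.uint4b8
    # quantize_weights returns (w_ref, w_q, w_s, w_zp)
    _, w_q, _, _ = quantize_weights(w, wtype, group_size)
    # pack the sub-byte values along K into wider storage words
    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    # the prepack op expects the packed tensor in column-major order,
    # which is what prepack_impl's B.t() view assumes
    w_q = w_q.t().contiguous().t()
    return ops.machete_prepack_B(w_q, wtype)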
+template +torch::Tensor prepack_impl(torch::Tensor const B) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); + using ElementB = typename PrepackedLayoutB::ElementB; + using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK; + + auto device = B.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + auto B_ptr = static_cast(B.const_data_ptr()); + // elements per storage item for B + auto eles_per_storage = + (B.dtype().itemsize() * 8) / cute::sizeof_bits_v; + + // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to + // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L) + auto Bt_packed = B.t(); + + TORCH_CHECK( + (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0, + "B.shape[0] (in terms of unpacked elements) must be a multiple of ", + size<1>(PPBlockShape_NK{})); + TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0, + "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{})); + + using StrideB = cutlass::detail::TagToStrideB_t; + auto const l_Bt_packed = make_cute_layout(Bt_packed, "B"); + + // convert (N,packed_K,L) layout to (N,K,L) layout + // in effect we want to do: blocked_product(layout_Bt_packed, + // make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}), + // Step<_1, _0, _2>{})); + // but blocked_product does not support dynamic strides so we implement the + // equivalent manually, + // new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L) + // new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage) + // when s1 == 1 + TORCH_CHECK(stride<1>(l_Bt_packed) == 1); + // clang-format off + auto const layout_Bt = make_layout( + transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) { + return idx == 1 ? ele * eles_per_storage : ele; + }), + transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) { + return idx != 1 ? 
ele * eles_per_storage : ele; + })); + // clang-format on + + // Allocate output + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); + + prepack_B(stream, B_ptr, layout_Bt, + static_cast(D.mutable_data_ptr())); + + return D; +}; + +template +struct PrepackBDispatcher { + static torch::Tensor dispatch(torch::Tensor B); +}; + +}; // namespace machete \ No newline at end of file diff --git a/kernels/quantization/machete/machete_prepacked_layout.cuh b/kernels/quantization/machete/machete_prepacked_layout.cuh new file mode 100644 index 000000000..78e2cc5ee --- /dev/null +++ b/kernels/quantization/machete/machete_prepacked_layout.cuh @@ -0,0 +1,220 @@ +#pragma once + +#include +#include +#include + +// clang-format off +// The cutlass include order matters (annoyingly) + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#include "cutlass_extensions/cute_utils.cuh" +#include "machete_collective_builder.cuh" +#include "machete_interleaving_utils.cuh" + +namespace machete { + +using namespace cute; + +struct IlvBlkLayoutAuto {}; + +// This defines a prepacked layout for the B matrix, where the matrix is broken +// up into PPBlockShape_NK blocks. The data within each block is then compactly +// stored in memory such that when performing a TiledMMA operation with the same +// shape as prepacked block, all the data for a given thread is contiguous in +// memory. This allows us to use wider shared memory loads when loading B from +// shared memory. The values within a thread are also potentially interlaeved +// inorder to allow for more efficient upconverting. +// +// The contract here is that the `TiledMma` determined below matches the one +// ultimately used in the kernel. (this is also why the other element types are +// required along with the kernel schedule) +template +// clang-format on +struct PrepackedLayoutBTemplate { + using MmaType = ElementA_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementD = ElementD_; + using ElementAccumulator = + AccumulatorT; // Element type for internal accumulation + using ElementMma = MmaType; + + // Only use interleaved layouts for subbyte weights, prmt instructions makes + // non-interleaved layouts for 8bit+ weights efficient enough we don't need + // iterleaved layouts + using IlvdBlkLayout = std::conditional_t< + std::is_same_v, + std::conditional_t <= 4, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, + IlvBlkLayout_>; + + // TODO (LucasWilkinson): compare the performance for other sizes + // Prepacked block shape, smallest layout atom for loading into registers + // (can contain multiple wgmma instructions worth of data in one block) + // We ideally want this to be configured such that a thread can perform 128bit + // loads, i.e. 
the amount of data associated with each thread within a
+  //  prepacked block is a multiple of 128bits, when using a cooperative schedule
+  //  we have 256 threads working on a single block at a time, this means each
+  //  thread works on `sizeof_bits_v * (128*64) / 256` bits of data,
+  //  for a 4bit type this would be 128bits
+  using PPBlockShape_NK = Shape<_128, _64>;
+
+  // Create the shape of the tile anticipated to be used by the GEMM kernel,
+  //  when the kernel executes we will compute `Ct = Bt * At` since the
+  //  quantized weights (B) must be the lhs operand so they flow through
+  //  registers.
+  // The _128 here doesn't actually impact the shape of the stored tile directly
+  //  but may impact the op selected by rs_op_selector
+  using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
+                                            size<1>(PPBlockShape_NK{})));
+
+  static constexpr cute::GMMA::Major GmmaMajorB =
+      gmma_rs_tag_to_major_B();
+
+  // For coop schedules we have two warp groups cooperatively issuing wgmma
+  // instructions so we use 2 atoms along the M dim (one for each warpgroup)
+  using AtomLayoutMNK = cute::conditional_t<
+      cute::is_same_v,
+      Layout>, Layout>>;
+
+  using TiledMma = decltype(cute::make_tiled_mma(
+      cute::GMMA::rs_op_selector(),
+      AtomLayoutMNK{}));
+
+  // Prepacked block, (athrid, val) -> (N,K)
+  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
+  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
+    return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
+  }
+
+  // Prepacked block, (N,K) -> (athrid, val)
+  // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
+  CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
+    return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
+  }
+
+  // Prepacked block, (athrid, val) -> (storage_offset)
+  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
+  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
+    // Return interleaved layout
+    return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
+  }
+
+  // Prepacked block, (athrid, val) -> (storage_offset)
+  // i.e.
((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() { + auto layout_no_interleave = + make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{}); + + if constexpr (std::is_same_v) { + return layout_no_interleave; + } else { + // interleave by transforming FrgV into interleaved blocks where each + // block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is + // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4) + // if FrgV is {A, B, C, D, E, F, G, H} + // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} + auto frgV = get<1, 0>(layout_no_interleave); + auto ilvdBlk = IlvdBlkLayout{}; + static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4"); + auto ilvd_FrgV = make_layout( + make_shape(shape(ilvdBlk), Int{}), + make_stride(stride(ilvdBlk), size(ilvdBlk))); + + // Return iterleaved layout + return make_layout( + get<0>(layout_no_interleave), + make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave))); + } + } + + // Prepacked block, (M,K) -> (storage_offset) + CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() { + // do (M,K) -> (athrid, val) -> (storage_idx) + return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV()); + } + + // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_TV_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) + // => ((athrid, val), (BlocksN, BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( + Shape_NKL shape_mkl) { + constexpr auto block_layout = ppblock_ilvd_NK_to_offset(); + + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx) + auto result = make_layout( + block_layout, + make_layout(blocks_shape, + compact_col_major(blocks_shape, size(block_layout)))); + + // ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN, + // BlocksK), L) + return group<1, 3>(result(_, repeat(result)>(_))); + } + + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { + auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})), + make_layout(size<1>(PPBlockShape_NK{}))); + + // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L) + auto tiled_A = zipped_divide(make_layout(shape_mkl), tile); + return tiled_A.compose(ppblock_TV_to_NK(), _); + } + + // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L) + template + CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) { + auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl); + return blocked_product(ppblock_NK_to_TV(), + make_layout(shape<1>(TVbNbK_to_NKL_layout))); + } +}; + +}; // namespace machete \ No newline at end of file diff --git 
a/kernels/quantization/machete/machete_pytorch.cu b/kernels/quantization/machete/machete_pytorch.cu new file mode 100644 index 000000000..46445521a --- /dev/null +++ b/kernels/quantization/machete/machete_pytorch.cu @@ -0,0 +1,79 @@ +#include "machete_mm_launcher.cuh" +#include "machete_prepack_launcher.cuh" +#include "core/scalar_type.hpp" + +namespace machete { + +using namespace aphrodite; + +// +// Utils (type dispatching) +// + +template +static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { + if (type == aphrodite::kU4) { + return fn(cutlass::uint4b_t{}); + } else if (type == aphrodite::kU8) { + return fn(cutlass::uint8_t{}); + } else if (type == aphrodite::kU4B8) { + return fn(cutlass::aphrodite_uint4b8_t{}); + } else if (type == aphrodite::kU8B128) { + return fn(cutlass::aphrodite_uint8b128_t{}); + } else { + TORCH_CHECK(false, "Unsupported type ", type.str()); + } +} + +#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \ + AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__)) + +// +// Interface +// + +std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { + return scalar_type_dispatch(*btype, [&](auto BType) { + return GemmDispatcher::supported_schedules(); + }); +} + +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, + ScalarTypeTorchPtr const& btype, + c10::optional const& scales, + c10::optional const& zeros, + c10::optional group_size, + c10::optional const& C, + c10::optional alpha, c10::optional beta, + c10::optional schedule) { + auto args = PyTorchArguments{.A = A, + .B = B, + .scales = scales, + .zeros = zeros, + .group_size = group_size, + .C = C, + .alpha = alpha, + .beta = beta, + .schedule = schedule}; + + return scalar_type_dispatch(*btype, [&](auto BType) { + return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( + A.scalar_type(), "machete_gemm", [&] { + using ComputeType = equivalent_cutlass_type_t; + return GemmDispatcher::dispatch(args); + }); + }); +} + +torch::Tensor prepack_B(torch::Tensor const& B, + ScalarTypeTorchPtr const& btype) { + return scalar_type_dispatch(*btype, [&](auto BType) { + return PrepackBDispatcher::dispatch(B); + }); +} + +}; // namespace machete \ No newline at end of file diff --git a/kernels/quantization/quant_ops.h b/kernels/quantization/quant_ops.h index 39d619dbb..047b98e91 100644 --- a/kernels/quantization/quant_ops.h +++ b/kernels/quantization/quant_ops.h @@ -110,6 +110,7 @@ void decompress_e8p_origorder(torch::Tensor YIs, torch::Tensor CB, torch::Tensor& Y); #ifndef _WIN32 +// Cutlass Kernels bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, @@ -125,6 +126,26 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, c10::optional const& azp, c10::optional const& bias); +// Machete Kernels +namespace machete { + +std::vector supported_schedules( + aphrodite::ScalarTypeTorchPtr const& btype); + +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, + aphrodite::ScalarTypeTorchPtr const& btype, + c10::optional const& scales, + c10::optional const& zeros, + c10::optional group_size, + c10::optional const& C, + c10::optional alpha, c10::optional beta, + c10::optional schedule); + +torch::Tensor prepack_B(torch::Tensor const& B, + aphrodite::ScalarTypeTorchPtr const& btype); + +}; // namespace machete + torch::Tensor 
marlin_qqq_gemm(torch::Tensor const& a, torch::Tensor const& b_q_weight, torch::Tensor const& s_tok, diff --git a/kernels/torch_bindings.cpp b/kernels/torch_bindings.cpp index ecde79ed9..423e11d64 100644 --- a/kernels/torch_bindings.cpp +++ b/kernels/torch_bindings.cpp @@ -189,6 +189,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b_scales, Tensor azp_adj," " Tensor? azp, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp); + + // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. + ops.def("machete_supported_schedules", &machete::supported_schedules); + ops.def( + "machete_gemm(Tensor A, Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype," + " Tensor? scales, Tensor? zeros, int? group_size," + " Tensor? C, float? alpha, float? beta, str? schedule)" + "-> Tensor"); + ops.impl("machete_gemm", torch::kCUDA, &machete::gemm); + ops.def( + "machete_prepack_B(Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype)" + "-> Tensor"); + ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); + + ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); + ops.impl("permute_cols", torch::kCUDA, &permute_cols); + #endif // QuIP# GEMV diff --git a/tests/benchmarks/kernels/benchmark_machete.py b/tests/benchmarks/kernels/benchmark_machete.py new file mode 100644 index 000000000..0cbcc3c91 --- /dev/null +++ b/tests/benchmarks/kernels/benchmark_machete.py @@ -0,0 +1,370 @@ +import argparse +import copy +import itertools +import math +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from aphrodite import _custom_ops as ops +from aphrodite.common.utils import FlexibleArgumentParser +from aphrodite.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) +from aphrodite.quantization.utils.marlin_utils_test import MarlinWorkspace +from aphrodite.quantization.utils.quant_utils import (gptq_pack, pack_rows, + quantize_weights) +from aphrodite.scalar_type import ScalarType, scalar_types + +DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] +DEFAULT_TP_SIZES = [1] + + +def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # make col major + return ops.machete_prepack_B(w_q, wtype) + + +def make_bench_tensors( + atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, + k: int +) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, + torch.tensor]]]: + assert wtype.is_integer(), "TODO: support floating point weights" + + # we want to make sure that weights don't fit into L2 cache between runs so + # we construct enough weights to exceed L2 cache, which is 50mb on a H100 + # so we target total weight size > 2*50mb + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) + + a = torch.randn((m, k), device="cuda", dtype=atype) * 5 + weights = [ + torch.randn((k, n), device="cuda", dtype=atype) + for _ in range(num_weights) + ] + quanitized_weights = [ + quantize_weights(w, wtype, group_size) for w in weights + ] + + return a, quanitized_weights + + +# impl + + +# bench +def bench_fn(label: str, sub_label: str, 
description: str, + fn: Callable) -> TMeasurement: + + min_run_time = 1 + return TBenchmark.Timer( + stmt="fn()", + globals={ + "fn": fn + }, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def loop_over_weights( + a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, + torch.tensor, torch.tensor]], + fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], + None]): + for w_ref, w_q, w_s, _ in weights: + fn(a, w_ref, w_q, w_s) + + +def bench(atype: torch.dtype, + wtype: ScalarType, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + benchmark_marlinv1: bool = True, + sweep_schedules: bool = True) -> Iterable[TMeasurement]: + a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) + sub_label += f", L={len(weights)}" + + weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + timers = [] + # pytorch impl + timers.append( + bench_fn( + label, sub_label, "torch.matmul", lambda: loop_over_weights( + a, + weights, + lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), + ))) + + if benchmark_marlinv1: + w_ref = weights[0][0] + + w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) + sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) + g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) + + def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: + w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) + return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, + wtype.size_bits) + + def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: + return marlin_permute_scales(w_s, *w_ref.shape, group_size) + + weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), + marlinv1_permute_scales(w_s), w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + # marlinv1 + timers.append( + bench_fn( + label, sub_label, "marlin_orig", lambda: loop_over_weights( + a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. 
+ gptq_marlin_gemm(a, + w_q, + w_s, + w_zp_empty, + g_idx, + sort_indices, + workspace.scratch, + wtype, + size_m=a.shape[0], + size_n=w_ref.shape[1], + size_k=w_ref.shape[0], + is_k_full=True)))) + + # machete + timers.append( + bench_fn( + label, sub_label, "machete_heuristic", lambda: loop_over_weights( + a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( + a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + + if sweep_schedules: + print("Finding best schedule for machete") + best = None + best_schedule = None + schedules = ops.machete_supported_schedules(wtype) + for schedule in reversed(schedules): + + def run(a, _, w_q, w_s, schedule=schedule): + ops.machete_gemm(a, + w_q, + wtype, + w_s, + b_group_size=group_size, + schedule=schedule) + + res = bench_fn(label, sub_label, "machete_best", + lambda: loop_over_weights(a, weights_machete, run)) + + print(f" {res.median:5.5} ", schedule) + if not best or res.median < best.median: + best = res + best_schedule = schedule + print("Best schedule:", best_schedule) + timers.append(best) + + return timers + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, sweep_schedules: bool, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + + results = [] + for m, k, n in MKNs: + timers = bench(dtype, + scalar_types.uint4b8, + 128, + m, + k, + n, + f"{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=sweep_schedules) + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None, +): + + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, args.sweep_schedules, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp 
in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "bfloat16": + return torch.bfloat16 + if dt == "float16": + return torch.float16 + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Machete GEMM. + To run square GEMMs: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['bfloat16', 'float16']", + ) + parser.add_argument( + "--sweep-schedules", + action="store_true", + help="Run a sweep over all supported schedules", + ) + subparsers = parser.add_subparsers(dest="cmd", required=True) + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/tests/benchmarks/kernels/graph_machete_bench.py b/tests/benchmarks/kernels/graph_machete_bench.py new file mode 100644 index 000000000..471e0bcab --- /dev/null +++ b/tests/benchmarks/kernels/graph_machete_bench.py @@ -0,0 +1,64 @@ +import math +import pickle +import re +from collections import defaultdict +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from torch.utils.benchmark import Measurement as TMeasurement + +from 
aphrodite.common.utils import FlexibleArgumentParser + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') + parser.add_argument('filename', type=str) + + args = parser.parse_args() + + with open(args.filename, 'rb') as f: + data: List[TMeasurement] = pickle.load(f) + + results = defaultdict(lambda: list()) + for v in data: + result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) + if result is not None: + KN = result.group(1) + else: + raise Exception("MKN not found") + result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label) + if result is not None: + M = result.group(1) + else: + raise Exception("MKN not found") + + kernel = v.task_spec.description + results[KN].append({ + "kernel": kernel, + "batch_size": M, + "median": v.median + }) + + rows = int(math.ceil(len(results) / 2)) + fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) + axs = axs.flatten() + axs_idx = 0 + for shape, data in results.items(): + plt.sca(axs[axs_idx]) + df = pd.DataFrame(data) + sns.lineplot(data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2") + plt.title(f"Shape: {shape}") + plt.ylabel("time (median, s)") + axs_idx += 1 + plt.tight_layout() + plt.savefig("graph_machete_bench.pdf") diff --git a/tests/benchmarks/kernels/weight_shapes.py b/tests/benchmarks/kernels/weight_shapes.py new file mode 100644 index 000000000..25ec9d602 --- /dev/null +++ b/tests/benchmarks/kernels/weight_shapes.py @@ -0,0 +1,43 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py new file mode 100644 index 000000000..f99534764 --- /dev/null +++ b/tests/kernels/test_machete_gemm.py @@ -0,0 +1,274 @@ +"""Tests for the machete kernel. +Run `pytest tests/kernels/test_machete_gemm.py`. 
+""" + +import math +from typing import Optional, Tuple + +import pytest +import torch + +from aphrodite import _custom_ops as ops +from aphrodite.platforms import current_platform +from aphrodite.quantization.utils.quant_utils import (pack_rows, + quantize_weights) +from aphrodite.scalar_type import ScalarType, scalar_types + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (13, 8192, 4096), + (26, 4096, 8192), + (1, 4096, 4096), + (257, 128, 4096), + (257, 4224, 4160), + (257, 4096, 4096), + (64, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), +] + +ACT_TYPES = [torch.float16, torch.bfloat16] +WTYPE_ZEROPOINTS = [ + # GPTQ style + (scalar_types.uint4b8, False), + (scalar_types.uint8b128, False), + # AWQ style + (scalar_types.uint4, True), + (scalar_types.uint8, True), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + + +def rand_data(shape, dtype=torch.float16): + return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) + + +def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): + return zps if zps is None else -1 * s * (zps.to(s.dtype)) + + +def machete_quantize_and_pack(w: torch.Tensor, + wtype: ScalarType, + group_size: int, + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + w_q_machete = ops.machete_prepack_B(w_q, wtype) + + return w_ref, w_q_machete, w_s, w_zp + + +def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, + wtype: ScalarType, group_size: int, + zero_points: bool): + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + b, wtype, group_size, zero_points) + + output_ref = torch.matmul(a, w_ref) + + output = ops.machete_gemm( + a=a, + b_q=w_q_packed, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) +@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) +@pytest.mark.parametrize("group_size", [128, None]) +def test_machete_all_schedules(shape, atype: torch.dtype, + wtype_zeropoints: Tuple[ScalarType, bool], + group_size: Optional[int]): + m, n, k = shape + wtype, zero_points = wtype_zeropoints + + if group_size is not None 
and k % group_size != 0: + return + + print(f"MNK = {m} {n} {k}") + + # Normalize group_size + if group_size is None: + group_size = k + assert group_size <= k + + a = rand_data((m, k), atype) + w = rand_data((k, n), atype) + + w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( + w, wtype, group_size, zero_points) + + output_ref = torch.matmul(a, w_ref) + + for schedule in ops.machete_supported_schedules(wtype): + print(f"Testing schedule {schedule}") + output = ops.machete_gemm( + a, + b_q=w_q_machete, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + schedule=schedule, + ) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ + f"Schedule failed {schedule}" + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) +@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) +@pytest.mark.parametrize("group_size", [128, None]) +def test_machete_heuristic(shape, atype: torch.dtype, + wtype_zeropoints: Tuple[ScalarType, bool], + group_size: Optional[int]): + m, n, k = shape + wtype, zero_points = wtype_zeropoints + + if group_size is not None and k % group_size != 0: + return + + # Normalize group_size + if group_size is None: + group_size = k + assert group_size <= k + + a = rand_data((m, k), atype) + b = rand_data((k, n), atype) + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test working on other devices +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_machete_devices(device: str): + m, n, k = 512, 4096, 4096 + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + print(f"MNK = {m} {n} {k}, device = {device}") + + a = rand_data((m, k), torch.float16).to(device) + b = rand_data((k, n), torch.float16).to(device) + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test working with a subset of A and B +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + whole_a = rand_data((big_m, big_k), torch.float16) + whole_b = rand_data((big_k, big_n), torch.float16) + + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + + machete_gemm_test_helper(a, b, wtype, group_size, zero_points) + + +# Test to make sure cuda graphs work +class MacheteLayer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.machete_gemm(**self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = rand_data((m, k), torch.float16) + b = rand_data((k, n), torch.float16) + wtype = scalar_types.uint4b8 + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, w_zp = 
machete_quantize_and_pack( + b, wtype, group_size, zero_points) + + # Construct a trivial model with a single layer that calls a machete kernel + model = MacheteLayer( + a=a, + b_q=w_q_packed, + b_type=wtype, + b_scales=w_s, + b_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + output_ref = torch.matmul(a, w_ref) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + output.zero_() + g.replay() + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/kernels/test_permute_cols.py b/tests/kernels/test_permute_cols.py new file mode 100644 index 000000000..68c5f1b30 --- /dev/null +++ b/tests/kernels/test_permute_cols.py @@ -0,0 +1,13 @@ +import pytest +import torch + +from aphrodite._custom_ops import permute_cols + + +@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +def test_permute_cols(shape, dtype): + x = torch.randn(shape, dtype=dtype).cuda() + perm = torch.randperm(x.shape[1]).to(torch.int).cuda() + y = permute_cols(x, perm) + torch.testing.assert_close(y, x[:, perm]) From dfa34d1b241b9044068719e9e66f9dc1c21175e7 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 20:46:42 -0800 Subject: [PATCH 05/15] feat: add sampler_priorty (#837) * feat: add sampler_priorty * fix: sampler arg verification * more clean-up and remove min_tokens from the order * more cleaning up and logs * alias sampler_priority to sampler_order --- aphrodite/common/sampling_params.py | 44 ++++ aphrodite/endpoints/openai/protocol.py | 13 +- aphrodite/modeling/layers/sampler.py | 324 +++++++++++++++++++------ 3 files changed, 305 insertions(+), 76 deletions(-) diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py index 08ef8af8f..e59f16a4f 100644 --- a/aphrodite/common/sampling_params.py +++ b/aphrodite/common/sampling_params.py @@ -23,6 +23,25 @@ class SamplingType(IntEnum): RANDOM_SEED = 2 BEAM = 3 +class SamplerID(IntEnum): + # Mirror these in aphrodite/modeling/layers/sampler.py + # Values out of order to keep backwards compatibility + # with Koboldcpp values + DRY = 7 + PENALTIES = 6 + NO_REPEAT_NGRAM = 8 + TEMPERATURE = 5 + TOP_NSIGMA = 9 + TOP_P_TOP_K = 0 + TOP_A = 1 + MIN_P = 2 + TFS = 3 + ETA_CUTOFF = 10 + EPSILON_CUTOFF = 11 + TYPICAL_P = 4 + QUADRATIC = 12 + XTC = 13 + LogitsProcessorFunc = Union[Callable[[List[int], torch.Tensor], torch.Tensor], Callable[[List[int], List[int], torch.Tensor], @@ -175,6 +194,8 @@ class SamplingParams( Defaults to None. skew: Bias the token selection towards higher or lower probability tokens. Defaults to 0 (disabled). + sampler_priority: A list of integers to control the order in which + samplers are applied. """ n: int = 1 @@ -227,6 +248,7 @@ class SamplingParams( dry_allowed_length: int = 2 dry_sequence_breaker_ids: List[int] = [] skew: float = 0.0 + sampler_priority: Optional[List[int]] = [] # The below fields are not supposed to be used as an input. # They are set in post_init. 
output_text_buffer_length: int = 0 @@ -279,6 +301,7 @@ class SamplingParams( "dry_allowed_length": 2, "dry_sequence_breaker_ids": [], "skew": 0.0, + "sampler_priority": [], } def __post_init__(self) -> None: @@ -428,6 +451,27 @@ def _verify_args(self) -> None: raise ValueError( "skew must be non-negative, got " f"{self.skew}.") + + if self.sampler_priority is not None: + if not self.sampler_priority: + self.sampler_priority = None + return + + if not isinstance(self.sampler_priority, list): + raise ValueError("sampler_priority must be a list of integers") + try: + provided_samplers = { + SamplerID(x) for x in self.sampler_priority} + except ValueError as e: + raise ValueError( + f"Invalid sampler ID in priority list: {e}") from e + + required_samplers = set(SamplerID) + if not required_samplers.issubset(provided_samplers): + missing = required_samplers - provided_samplers + missing_names = [s.name for s in missing] + raise ValueError(f"Missing required samplers in priority list: " + f"{missing_names}") def _verify_beam_search(self) -> None: if self.best_of == 1: diff --git a/aphrodite/endpoints/openai/protocol.py b/aphrodite/endpoints/openai/protocol.py index b3b7f957f..530144bc6 100644 --- a/aphrodite/endpoints/openai/protocol.py +++ b/aphrodite/endpoints/openai/protocol.py @@ -5,7 +5,8 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import (AliasChoices, BaseModel, ConfigDict, Field, + model_validator) from transformers import PreTrainedTokenizer from typing_extensions import Annotated @@ -160,6 +161,10 @@ class ChatCompletionRequest(OpenAIBaseModel): nsigma: Optional[float] = 0.0 skew: Optional[float] = 0.0 custom_token_bans: Optional[List[int]] = None + sampler_priority: Optional[List[int]] = Field( + default=[], + validation_alias=AliasChoices("sampler_priority", + "sampler_order")) # doc: end-chat-completion-sampling-params # doc: begin-chat-completion-extra-params @@ -317,6 +322,7 @@ def to_sampling_params( nsigma=self.nsigma, skew=self.skew, custom_token_bans=self.custom_token_bans, + sampler_priority=self.sampler_priority, ) @model_validator(mode='before') @@ -436,6 +442,10 @@ class CompletionRequest(OpenAIBaseModel): nsigma: Optional[float] = 0.0 skew: Optional[float] = 0.0 custom_token_bans: Optional[List[int]] = None + sampler_priority: Optional[List[int]] = Field( + default=[], + validation_alias=AliasChoices("sampler_priority", + "sampler_order")) # doc: end-completion-sampling-params # doc: begin-completion-extra-params @@ -552,6 +562,7 @@ def to_sampling_params( nsigma=self.nsigma, skew=self.skew, custom_token_bans=self.custom_token_bans, + sampler_priority=self.sampler_priority, ) @model_validator(mode="before") diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py index 777e1c870..b1249a6ce 100644 --- a/aphrodite/modeling/layers/sampler.py +++ b/aphrodite/modeling/layers/sampler.py @@ -2,11 +2,13 @@ import itertools import os import warnings +from enum import IntEnum from math import inf from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn +from loguru import logger import aphrodite._custom_ops as ops from aphrodite.common.sampling_params import SamplingType @@ -36,6 +38,26 @@ os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0"))) +class SamplerID(IntEnum): + # Mirror these in aphrodite/common/sampling_params.py + # Values out of order to keep backwards compatibility + # with Koboldcpp values + DRY = 7 
+ PENALTIES = 6 + NO_REPEAT_NGRAM = 8 + TEMPERATURE = 5 + TOP_NSIGMA = 9 + TOP_P_TOP_K = 0 + TOP_A = 1 + MIN_P = 2 + TFS = 3 + ETA_CUTOFF = 10 + EPSILON_CUTOFF = 11 + TYPICAL_P = 4 + QUADRATIC = 12 + XTC = 13 + + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -151,90 +173,242 @@ def forward( do_temp_last = self._do_temp_last logits = _apply_min_tokens_penalty(logits, sampling_metadata) - - if do_dry: - logits = _apply_dry( - logits, - sampling_tensors.prompt_tokens, - sampling_tensors.dry_multipliers, - sampling_tensors.dry_bases, - sampling_tensors.dry_allowed_lengths, - sampling_tensors.dry_sequence_breaker_ids - ) - - # Apply presence and frequency penalties. - if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) - - if do_no_repeat_ngrams: - logits = _apply_no_repeat_ngram( - logits, - sampling_tensors.prompt_tokens, - sampling_tensors.no_repeat_ngram_sizes) - - # Apply temperature scaling if not doing temp_last. - if do_temperatures and not do_temp_last: - _apply_temperatures(logits, sampling_tensors.temperatures, - sampling_tensors.dynatemp_mins, - sampling_tensors.dynatemp_maxs, - sampling_tensors.dynatemp_exps) - - if do_nsigmas: - logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas) - - if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) - - if do_top_as: - logits = _apply_top_a(logits, sampling_tensors.top_as) - - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) - - if do_tfss: - logits = _apply_tfs(logits, sampling_tensors.tfss) - - if do_eta_cutoffs: - logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs) - - if do_epsilon_cutoffs: - logits = _apply_epsilon_cutoff(logits, - sampling_tensors.epsilon_cutoffs) - - if do_typical_ps: - logits = _apply_typical_sampling(logits, - sampling_tensors.typical_ps) - - if do_quadratic: - logits = _apply_quadratic_sampling( - logits, sampling_tensors.smoothing_factors, - sampling_tensors.smoothing_curves) - - if do_xtc: - logits = _apply_xtc_sampling( - logits, sampling_tensors.xtc_thresholds, - sampling_tensors.xtc_probabilities) - - if do_temperatures and do_temp_last: - _apply_temperatures(logits, sampling_tensors.temperatures, - sampling_tensors.dynatemp_mins, - sampling_tensors.dynatemp_maxs, - sampling_tensors.dynatemp_exps) - banned_tokens = _get_custom_token_bans(sampling_metadata) logits = _apply_token_bans(logits, banned_tokens) + sampler_order = None + if sampling_metadata.seq_groups: + sampler_order = sampling_metadata.seq_groups[ + 0].sampling_params.sampler_priority + + # Warn if both custom order and temp_last are specified + if sampler_order is not None and do_temp_last: + logger.warning( + "Both sampler_priority and temperature_last=True " + "were specified. 
Using custom sampler_priority order " + "and ignoring temperature_last.") + + if sampler_order is None: + default_order = [ + SamplerID.DRY, + SamplerID.PENALTIES, + SamplerID.NO_REPEAT_NGRAM, + SamplerID.TEMPERATURE, + SamplerID.TOP_NSIGMA, + SamplerID.TOP_P_TOP_K, + SamplerID.TOP_A, + SamplerID.MIN_P, + SamplerID.TFS, + SamplerID.ETA_CUTOFF, + SamplerID.EPSILON_CUTOFF, + SamplerID.TYPICAL_P, + SamplerID.QUADRATIC, + SamplerID.XTC, + ] + + sampler_order = [] + for sampler_id in default_order: + if sampler_id == SamplerID.TEMPERATURE and do_temp_last: + continue + sampler_order.append(sampler_id) + + if sampler_id == SamplerID.XTC and do_temp_last: + sampler_order.append(SamplerID.TEMPERATURE) + + if sampling_metadata.seq_groups and sampling_metadata.seq_groups[ + 0].is_prompt: + logger.debug("Sampler execution order: ") + for i, sampler_id in enumerate(sampler_order, 1): + logger.debug(f"{i}. {SamplerID(sampler_id).name}") + + enabled_samplers = [] + # ruff: noqa: E701 + if do_penalties: enabled_samplers.append("PENALTIES") + if do_no_repeat_ngrams: enabled_samplers.append("NO_REPEAT_NGRAM") + if do_temperatures: enabled_samplers.append("TEMPERATURE") + if do_top_p_top_k: enabled_samplers.append("TOP_P_TOP_K") + if do_top_as: enabled_samplers.append("TOP_A") + if do_min_p: enabled_samplers.append("MIN_P") + if do_tfss: enabled_samplers.append("TFS") + if do_eta_cutoffs: enabled_samplers.append("ETA_CUTOFF") + if do_epsilon_cutoffs: enabled_samplers.append("EPSILON_CUTOFF") + if do_typical_ps: enabled_samplers.append("TYPICAL_P") + if do_quadratic: enabled_samplers.append("QUADRATIC") + if do_xtc: enabled_samplers.append("XTC") + if do_nsigmas: enabled_samplers.append("TOP_NSIGMA") + if do_dry: enabled_samplers.append("DRY") + if do_skew: enabled_samplers.append("SKEW") + logger.debug(f"Enabled samplers: {', '.join(enabled_samplers)}") + + for sampler_id in sampler_order: + if sampler_id == SamplerID.DRY and do_dry: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + f"Applying DRY with dry_multiplier: " + f"{sampling_tensors.dry_multipliers}.") + logits = _apply_dry( + logits, + sampling_tensors.prompt_tokens, + sampling_tensors.dry_multipliers, + sampling_tensors.dry_bases, + sampling_tensors.dry_allowed_lengths, + sampling_tensors.dry_sequence_breaker_ids) + + elif sampler_id == SamplerID.PENALTIES and do_penalties: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying penalties with " + f"pres_pen: {sampling_tensors.presence_penalties}, " + f"freq_pen: {sampling_tensors.frequency_penalties}, " + f"rep_pen: {sampling_tensors.repetition_penalties}.") + logits = _apply_penalties( + logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) + + elif sampler_id == SamplerID.NO_REPEAT_NGRAM and \ + do_no_repeat_ngrams: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying no_repeat_ngram with no_repeat_ngram_size: " + f"{sampling_tensors.no_repeat_ngram_sizes}.") + logits = _apply_no_repeat_ngram( + logits, + sampling_tensors.prompt_tokens, + sampling_tensors.no_repeat_ngram_sizes) + + elif sampler_id == SamplerID.TEMPERATURE and do_temperatures: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying temperatures with temperature: " + 
f"{sampling_tensors.temperatures}, " + f"dynatemp_min: {sampling_tensors.dynatemp_mins}, " + f"dynatemp_max: {sampling_tensors.dynatemp_maxs}, " + f"dynamtep_exp: {sampling_tensors.dynatemp_exps}.") + _apply_temperatures( + logits, sampling_tensors.temperatures, + sampling_tensors.dynatemp_mins, + sampling_tensors.dynatemp_maxs, + sampling_tensors.dynatemp_exps) + + elif sampler_id == SamplerID.TOP_NSIGMA and do_nsigmas: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Top-Nsigma with nsigma: " + f"{sampling_tensors.nsigmas}") + logits = _apply_top_nsigma( + logits, sampling_tensors.nsigmas) + + elif sampler_id == SamplerID.TOP_P_TOP_K and do_top_p_top_k and \ + not APHRODITE_USE_SAMPLING_KERNELS: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Top-p and Top-k with top-p: " + f"{sampling_tensors.top_ps}, top_k: " + f"{sampling_tensors.top_ks}.") + logits = _apply_top_k_top_p( + logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + elif sampler_id == SamplerID.TOP_A and do_top_as: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Top-a with Top-a: " + f"{sampling_tensors.top_as}.") + logits = _apply_top_a( + logits, sampling_tensors.top_as) + + elif sampler_id == SamplerID.MIN_P and do_min_p: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Min-p with Min-p: " + f"{sampling_tensors.min_ps}.") + logits = _apply_min_p( + logits, sampling_tensors.min_ps) + + elif sampler_id == SamplerID.TFS and do_tfss: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Tail-Free Sampling with tfs: " + f"{sampling_tensors.tfss}.") + logits = _apply_tfs( + logits, sampling_tensors.tfss) + + elif sampler_id == SamplerID.ETA_CUTOFF and do_eta_cutoffs: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying ETA Cutoff with eta_cutoff: " + f"{sampling_tensors.eta_cutoffs}.") + logits = _apply_eta_cutoff( + logits, sampling_tensors.eta_cutoffs) + + elif sampler_id == SamplerID.EPSILON_CUTOFF and do_epsilon_cutoffs: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Epsilon Cutoff with epsilon_cutoff: " + f"{sampling_tensors.epsilon_cutoffs}.") + logits = _apply_epsilon_cutoff( + logits, sampling_tensors.epsilon_cutoffs) + + elif sampler_id == SamplerID.TYPICAL_P and do_typical_ps: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Locally Typical Sampling with typical_p: " + f"{sampling_tensors.typical_ps}.") + logits = _apply_typical_sampling( + logits, sampling_tensors.typical_ps) + + elif sampler_id == SamplerID.QUADRATIC and do_quadratic: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Quadratic and Cubic Sampling with " + "smoothing_factors: " + f"{sampling_tensors.smoothing_factors}," + f" smoothing_curves: " + f"{sampling_tensors.smoothing_curves}.") + logits = _apply_quadratic_sampling( + logits, sampling_tensors.smoothing_factors, + sampling_tensors.smoothing_curves) + + elif sampler_id == SamplerID.XTC and do_xtc: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Exclude Top 
Choices sampling with " + f"xtc_threshold: {sampling_tensors.xtc_thresholds}, " + "xtc_probability: " + f"{sampling_tensors.xtc_probabilities}.") + logits = _apply_xtc_sampling( + logits, sampling_tensors.xtc_thresholds, + sampling_tensors.xtc_probabilities) + + # We use float32 for probabilities and log probabilities. # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) # skew needs to be applied post-softmax if do_skew: + if (sampling_metadata.seq_groups and + sampling_metadata.seq_groups[0].is_prompt): + logger.debug( + "Applying Skew sampling with skew: " + f"{sampling_tensors.skews}.") # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd cum_probs = torch.cumsum(probs, dim=-1) cum_probs = torch.pow(cum_probs, torch.exp( From 483c9e6e592588136601a0a379e0cd0a39545e53 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 22:56:33 -0800 Subject: [PATCH 06/15] fix: disable awq_marlin override for awq models (#843) * fix: disable awq_marlin override for awq models * False->None * promote to warning --- aphrodite/common/config.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py index ffd4602b2..39882a0d0 100644 --- a/aphrodite/common/config.py +++ b/aphrodite/common/config.py @@ -316,8 +316,15 @@ def _verify_quantization(self) -> None: quantization_override = method.override_quantization_method( quant_cfg, self.quantization) if quantization_override: - quant_method = quantization_override - self.quantization = quantization_override + if quantization_override == "awq_marlin": + quant_method = quant_method + logger.warning( + "awq_marlin kernels are temporarily disabled, " + "they will be re-enabled with a future release. " + "Falling back to AWQ kernels.") + else: + quant_method = quantization_override + self.quantization = quantization_override break # Verify quantization configurations. 
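The sampler execution order introduced by the patches above can be exercised end-to-end through the OpenAI-compatible server. The sketch below is illustrative only and is not part of any patch in this series: the host, port, prompt, and model name are placeholder assumptions, and at this point in the series `sampler_priority` is a list of integer `SamplerID` values (string names arrive in a later patch).

```python
# Illustrative sketch, not part of the patch series. Assumes an Aphrodite
# OpenAI-compatible server is running locally; host/port/model are placeholders.
import requests

payload = {
    "model": "NousResearch/Meta-Llama-3.1-8B-Instruct",
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "temperature": 0.8,
    # Integer IDs mirror the SamplerID enum added to sampler.py. If the list
    # is non-empty, every sampler ID must be present or SamplingParams raises
    # ValueError. This list spells out the default order: DRY, PENALTIES,
    # NO_REPEAT_NGRAM, TEMPERATURE, TOP_NSIGMA, TOP_P_TOP_K, TOP_A, MIN_P,
    # TFS, ETA_CUTOFF, EPSILON_CUTOFF, TYPICAL_P, QUADRATIC, XTC.
    "sampler_priority": [7, 6, 8, 5, 9, 0, 1, 2, 3, 10, 11, 4, 12, 13],
    # "sampler_order" is accepted as an alias for the same field.
}

response = requests.post("http://localhost:2242/v1/completions", json=payload)
print(response.json()["choices"][0]["text"])
```

Requests that specify both a custom `sampler_priority` and `temperature_last` keep the custom order; the sampler logs a warning and ignores `temperature_last`, as shown in the forward pass above.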
From 538471f76e9058fa871b057b502a5873ecafba0d Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:01:20 -0800 Subject: [PATCH 07/15] chore: bump mistral_common to 1.5.0 (#844) --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index a3b111a30..3839db768 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -28,7 +28,7 @@ librosa # Required for audio processing soundfile # Required for audio processing gguf == 0.9.1 importlib_metadata -mistral_common >= 1.3.4 +mistral_common >= 1.5.0 protobuf pandas msgspec From d2971a68315eb12180c6a1be899670b5c7378b19 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:30:35 -0800 Subject: [PATCH 08/15] ci: bump version to 0.6.4 (#845) --- aphrodite/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aphrodite/version.py b/aphrodite/version.py index 8aad9d012..7ba0c1c11 100644 --- a/aphrodite/version.py +++ b/aphrodite/version.py @@ -11,4 +11,4 @@ __commit__ = "COMMIT_HASH_PLACEHOLDER" __short_commit__ = "SHORT_COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.3.post1" +__version__ = "0.6.4" From d486d7ac012dac21963faec4fe435b4057c80cb7 Mon Sep 17 00:00:00 2001 From: Luke Harold Miles <10591373+qpwo@users.noreply.github.com> Date: Sat, 30 Nov 2024 18:00:42 -0800 Subject: [PATCH 09/15] docs: add linux arm64/aarch64/GH200 installation tips (#851) * add linux arm64/aarch64/GH200 installation tips * mention spack in arm tips * Update docs/pages/installation/installation.md * Update docs/pages/installation/installation.md * Update docs/pages/installation/installation.md * Update docs/pages/installation/installation.md * Update docs/pages/installation/installation.md * Update docs/pages/installation/installation.md * Update docs/pages/installation/installation.md --------- Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com> --- docs/pages/installation/installation.md | 41 ++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/docs/pages/installation/installation.md b/docs/pages/installation/installation.md index 3e67e74df..ee2c6156f 100644 --- a/docs/pages/installation/installation.md +++ b/docs/pages/installation/installation.md @@ -65,12 +65,48 @@ Afterwards, prefix every Aphrodite-related command with `./runtime.sh`. e.g.: ./runtime.sh aphrodite run -h ``` +## Linux arm64/aarch64/GH200 tips + +The NVIDIA GH200 comes with an ARM CPU, so you might have to look around for the binaries you need. +As of November 2024, this produced a working aphrodite build: + +```sh +conda create -y -n 311 python=3.11; conda activate 311 +pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124 +conda install -y -c conda-forge cuda-nvcc_linux-aarch64=12.4 libstdcxx-ng=12 +conda install -y cmake sccache + +export CUDA_HOME=$CONDA_PREFIX +export PATH=$CUDA_HOME/bin:$PATH +python -c 'import torch; print(torch.tensor(5).cuda() + 1, "torch cuda ok")' + +cd aphrodite-engine + +pip install nvidia-ml-py==12.555.43 protobuf==3.20.2 ninja msgspec coloredlogs portalocker pytimeparse -r requirements-common.txt +pip install --no-clean --no-deps --no-build-isolation -v . + +# if you want flash attention: +cd .. +git clone https://github.com/AlpinDale/flash-attention +cd flash-attention +pip install --no-clean --no-deps --no-build-isolation -v . 
+``` + +A few places to look for aarch64 binaries if you're having trouble: + +- [conda aarch64 defaults channel](https://repo.anaconda.com/pkgs/main/linux-aarch64/) +- pytorch.org hosts wheels at https://download.pytorch.org/whl and https://download.pytorch.org/whl/cuXXX (eg https://download.pytorch.org/whl/cu124). Note that `/whl/cu124` is a separate index, not a folder in `/whl`. There is also https://download.pytorch.org/whl/nightly/. +- [nvidia's NGC docker containers](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) come with many tools and python packages bundled +- Sometimes a project will have ARM binaries in their github build artifacts before the official releases. [example](https://github.com/pytorch/pytorch/actions/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml) +- The spack package manager may be helpful for building especially tricky sources, like pytorch. + ## Installation with Docker We provide both a pre-built docker image, and a Dockerfile. ### Using the pre-built Docker image ```sh +# Run this with sudo if you run into permission issues. docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ #--env "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7" \ @@ -78,7 +114,7 @@ docker run --runtime nvidia --gpus all \ --ipc=host \ alpindale/aphrodite-openai:latest \ --model NousResearch/Meta-Llama-3.1-8B-Instruct \ - --tensor-parallel-size 8 \ + --tensor-parallel-size 1 \ --api-keys "sk-empty" ``` @@ -100,6 +136,3 @@ This Dockerfile will build for all CUDA arches, which may take hours. You can li ``` You can run your built image using the command in the previous section. - - - From 14ac216498d3c7e6e69a782b177b899fdf1a3463 Mon Sep 17 00:00:00 2001 From: Selali Date: Sat, 30 Nov 2024 19:34:02 -0800 Subject: [PATCH 10/15] sampler: add output_tokens to DRY sampler (#849) * Add Debug Statements * Test Token Fix * Remove Debug Statements * perform concat after checking for 0 multipliers --------- Co-authored-by: AlpinDale --- aphrodite/modeling/layers/sampler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py index b1249a6ce..b57583ca7 100644 --- a/aphrodite/modeling/layers/sampler.py +++ b/aphrodite/modeling/layers/sampler.py @@ -146,7 +146,7 @@ def forward( # Prepare sampling tensors with pinned memory to avoid blocking. if not sampling_metadata.reuse_sampling_tensors: self._init_sampling_tensors(logits, sampling_metadata) - elif self._do_penalties: + elif self._do_penalties or self._do_dry: # In this case, the sampling tensors logic depends on # "output_tokens" of a sequence. 
As a result, we cannot # reuse sampling tensors, since "output_tokens" changes @@ -250,6 +250,7 @@ def forward( logits = _apply_dry( logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, sampling_tensors.dry_multipliers, sampling_tensors.dry_bases, sampling_tensors.dry_allowed_lengths, @@ -616,7 +617,8 @@ def _apply_min_tokens_penalty( def _apply_dry( logits: torch.Tensor, - input_ids: torch.Tensor, + input_token_ids: torch.Tensor, + output_token_ids: torch.Tensor, multipliers: torch.Tensor, bases: torch.Tensor, allowed_lengths: torch.Tensor, @@ -630,7 +632,9 @@ def _apply_dry( # Don't apply dry penalties if multiplier is 0 if torch.all(multipliers == 0): return logits - + + # we need to apply dry to both input and output tokens + input_ids = torch.cat((input_token_ids, output_token_ids), dim=1) # Process each sequence in the batch for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)): multiplier = multipliers[i].item() From 72c505ad8497636f5c61f0d4b02b62b7060bd9ea Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 2 Dec 2024 02:12:30 -0800 Subject: [PATCH 11/15] sampler: fix dry concurrency issue (#852) --- aphrodite/modeling/layers/sampler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py index b57583ca7..a28c12eb8 100644 --- a/aphrodite/modeling/layers/sampler.py +++ b/aphrodite/modeling/layers/sampler.py @@ -625,7 +625,7 @@ def _apply_dry( sequence_breakers_ids: torch.Tensor ) -> torch.Tensor: """ - Apply Exclude Don't Repeat Yourself (DRY) sampling to the logits. + Apply Don't Repeat Yourself (DRY) sampling to the logits. Reference: https://github.com/oobabooga/text-generation-webui/pull/5677 """ @@ -635,6 +635,8 @@ def _apply_dry( # we need to apply dry to both input and output tokens input_ids = torch.cat((input_token_ids, output_token_ids), dim=1) + vocab_size = logits.size(-1) + # Process each sequence in the batch for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)): multiplier = multipliers[i].item() @@ -661,8 +663,8 @@ def _apply_dry( # Get the token that followed this match in the input next_token = input_ids_row[idx + 1].item() - # Skip if next token is a sequence breaker - if next_token in sequence_breakers_ids: + # Skip if next token is a sequence breaker or out of vocab range + if next_token in sequence_breakers_ids or next_token >= vocab_size: continue # We found last_token matches at this index, so match length starts @@ -700,7 +702,7 @@ def _apply_dry( base = bases[i] for token, match_length in match_lengths.items(): - if match_length >= allowed_length: + if match_length >= allowed_length and token < vocab_size: penalty = multiplier * (base ** (match_length - allowed_length)) logits_row[token] -= penalty From 2150bb501990fc63d77317894ee19c2526590813 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:16:10 -0800 Subject: [PATCH 12/15] sampler: add range parameter for DRY (#855) * sampler: add range parameter for DRY * openai: add the dry_range parameter to OpenAI server * openai: alias dry_range to dry_penalty_last_n * misc: more comments clean up --- aphrodite/common/sampling_params.py | 8 ++ aphrodite/endpoints/openai/protocol.py | 10 ++ aphrodite/modeling/layers/sampler.py | 42 ++++---- aphrodite/modeling/sampling_metadata.py | 15 ++- tests/samplers/test_sampler.py | 130 ++++++++++++++++++++++++ 5 
files changed, 180 insertions(+), 25 deletions(-) diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py index e59f16a4f..0b2f4a1ae 100644 --- a/aphrodite/common/sampling_params.py +++ b/aphrodite/common/sampling_params.py @@ -192,6 +192,8 @@ class SamplingParams( input into sections where repetition is evaluated separately. Common examples are newlines, quotes, and other structural tokens. Defaults to None. + dry_range: The range of tokens (input + output) to apply the DRY + sampler. skew: Bias the token selection towards higher or lower probability tokens. Defaults to 0 (disabled). sampler_priority: A list of integers to control the order in which @@ -247,6 +249,7 @@ class SamplingParams( dry_base: float = 1.75 dry_allowed_length: int = 2 dry_sequence_breaker_ids: List[int] = [] + dry_range: int = 0 skew: float = 0.0 sampler_priority: Optional[List[int]] = [] # The below fields are not supposed to be used as an input. @@ -300,6 +303,7 @@ class SamplingParams( "dry_base": 1.75, "dry_allowed_length": 2, "dry_sequence_breaker_ids": [], + "dry_range": 0, "skew": 0.0, "sampler_priority": [], } @@ -447,6 +451,10 @@ def _verify_args(self) -> None: raise ValueError( "dry_allowed_length must be non-negative, got " f"{self.dry_allowed_length}.") + if self.dry_range < 0: + raise ValueError( + "dry_range must be non-negative, got " + f"{self.dry_range}.") if self.skew < 0.0: raise ValueError( "skew must be non-negative, got " diff --git a/aphrodite/endpoints/openai/protocol.py b/aphrodite/endpoints/openai/protocol.py index 530144bc6..f866b411f 100644 --- a/aphrodite/endpoints/openai/protocol.py +++ b/aphrodite/endpoints/openai/protocol.py @@ -155,6 +155,10 @@ class ChatCompletionRequest(OpenAIBaseModel): dry_allowed_length: Optional[int] = 2 dry_sequence_breakers: Optional[List[str]] = Field( default=["\n", ":", "\"", "*"]) + dry_range: Optional[int] = Field( + default=0, + validation_alias=AliasChoices("dry_range", + "dry_penalty_last_n")) dynatemp_min: Optional[float] = 0.0 dynatemp_max: Optional[float] = 0.0 dynatemp_exponent: Optional[float] = 1.0 @@ -316,6 +320,7 @@ def to_sampling_params( dry_base=self.dry_base, dry_allowed_length=self.dry_allowed_length, dry_sequence_breaker_ids=dry_sequence_breaker_ids, + dry_range=self.dry_range, dynatemp_min=self.dynatemp_min, dynatemp_max=self.dynatemp_max, dynatemp_exponent=self.dynatemp_exponent, @@ -436,6 +441,10 @@ class CompletionRequest(OpenAIBaseModel): dry_allowed_length: Optional[int] = 2 dry_sequence_breakers: Optional[List[str]] = Field( default=["\n", ":", "\"", "*"]) + dry_range: Optional[int] = Field( + default=0, + validation_alias=AliasChoices("dry_range", + "dry_penalty_last_n")) dynatemp_min: Optional[float] = 0.0 dynatemp_max: Optional[float] = 0.0 dynatemp_exponent: Optional[float] = 1.0 @@ -556,6 +565,7 @@ def to_sampling_params( dry_base=self.dry_base, dry_allowed_length=self.dry_allowed_length, dry_sequence_breaker_ids=dry_sequence_breaker_ids, + dry_range=self.dry_range, dynatemp_min=self.dynatemp_min, dynatemp_max=self.dynatemp_max, dynatemp_exponent=self.dynatemp_exponent, diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py index a28c12eb8..7f3ecc996 100644 --- a/aphrodite/modeling/layers/sampler.py +++ b/aphrodite/modeling/layers/sampler.py @@ -254,7 +254,8 @@ def forward( sampling_tensors.dry_multipliers, sampling_tensors.dry_bases, sampling_tensors.dry_allowed_lengths, - sampling_tensors.dry_sequence_breaker_ids) + sampling_tensors.dry_sequence_breaker_ids, + 
sampling_tensors.dry_ranges) elif sampler_id == SamplerID.PENALTIES and do_penalties: if (sampling_metadata.seq_groups and @@ -622,61 +623,59 @@ def _apply_dry( multipliers: torch.Tensor, bases: torch.Tensor, allowed_lengths: torch.Tensor, - sequence_breakers_ids: torch.Tensor + sequence_breakers_ids: torch.Tensor, + ranges: torch.Tensor, ) -> torch.Tensor: """ Apply Don't Repeat Yourself (DRY) sampling to the logits. Reference: https://github.com/oobabooga/text-generation-webui/pull/5677 """ - # Don't apply dry penalties if multiplier is 0 if torch.all(multipliers == 0): return logits - # we need to apply dry to both input and output tokens + # DRY needs to be applied to both input AND output tokens input_ids = torch.cat((input_token_ids, output_token_ids), dim=1) vocab_size = logits.size(-1) - + # Process each sequence in the batch for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)): multiplier = multipliers[i].item() if multiplier == 0: - continue # Skip processing for this sequence - # Get the last token + continue + + range_limit = ranges[i].item() + if range_limit == 0: + search_start = 0 + else: + search_start = max(0, len(input_ids_row) - range_limit) + last_token = input_ids_row[-1].item() - # Skip if last token is a sequence breaker if last_token in sequence_breakers_ids: continue # Find matches of the last token, excluding the last position - match_indices = (input_ids_row[:-1] == last_token).nonzero() + # Only look within the specified range + match_indices = (input_ids_row[search_start:-1] == last_token).nonzero() + if len(match_indices) > 0: + match_indices = match_indices + search_start - # Track max matching sequence length for each potential next token match_lengths = {} - # Process each match for idx in match_indices: - # Convert to scalar idx = idx.item() - - # Get the token that followed this match in the input next_token = input_ids_row[idx + 1].item() - # Skip if next token is a sequence breaker or out of vocab range if next_token in sequence_breakers_ids or next_token >= vocab_size: continue - # We found last_token matches at this index, so match length starts - # at 1 match_length = 1 - # Try to extend match backwards while match_length < 50: j = idx - match_length k = len(input_ids_row) - match_length - 1 - if j < 0 or k < 0: - # Reached start of input + if j < search_start or j < 0 or k < 0: break if input_ids_row[j].item() != input_ids_row[k].item(): @@ -684,19 +683,16 @@ def _apply_dry( break if input_ids_row[k].item() in sequence_breakers_ids: - # Hit a sequence breaker break match_length += 1 - # Update max match length for this next token if next_token in match_lengths: match_lengths[next_token] = max( match_length, match_lengths[next_token]) else: match_lengths[next_token] = match_length - # Apply penalties based on match lengths allowed_length = allowed_lengths[i] multiplier = multipliers[i] base = bases[i] diff --git a/aphrodite/modeling/sampling_metadata.py b/aphrodite/modeling/sampling_metadata.py index f650427a3..947456264 100644 --- a/aphrodite/modeling/sampling_metadata.py +++ b/aphrodite/modeling/sampling_metadata.py @@ -393,6 +393,7 @@ class SamplingTensors: dry_bases: torch.Tensor dry_allowed_lengths: torch.Tensor dry_sequence_breaker_ids: torch.Tensor + dry_ranges: torch.Tensor skews: torch.Tensor sampling_seeds: torch.Tensor sample_indices: torch.Tensor @@ -447,6 +448,7 @@ def from_sampling_metadata( dry_bases: List[float] = [] dry_allowed_lengths: List[int] = [] dry_sequence_breaker_ids: List[List[int]] = [] + dry_ranges: 
List[int] = [] skews: List[float] = [] do_penalties = False @@ -552,6 +554,7 @@ def from_sampling_metadata( dry_allowed_lengths += [params.dry_allowed_length] * n_seqs dry_sequence_breaker_ids += ( [params.dry_sequence_breaker_ids] * n_seqs) + dry_ranges += [params.dry_range] * n_seqs skews += [params.skew] * n_seqs if _USE_TRITON_SAMPLER: @@ -601,7 +604,7 @@ def from_sampling_metadata( no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds, xtc_probabilities, nsigmas, dry_multipliers, dry_bases, - dry_allowed_lengths, dry_sequence_breaker_ids, skews, + dry_allowed_lengths, dry_sequence_breaker_ids, dry_ranges, skews, sampling_seeds, sample_indices, prompt_tokens, output_tokens, vocab_size, extra_seeds_to_generate, device, dtype) return (sampling_tensors, do_penalties, do_no_repeat_ngrams, @@ -626,7 +629,8 @@ def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float], dry_multipliers: List[float], dry_bases: List[float], dry_allowed_lengths: List[int], dry_sequence_breaker_ids: List[List[int]], - skews: List[float], sampling_seeds: List[List[int]], + dry_ranges: List[int], skews: List[float], + sampling_seeds: List[List[int]], sample_indices: List[int], prompt_tokens: List[array], output_tokens: List[array], vocab_size: int, extra_seeds_to_generate: int, device: torch.device, @@ -792,6 +796,12 @@ def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float], dtype=torch.long, pin_memory=pin_memory, ) + dry_ranges_t = torch.tensor( + dry_ranges, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) skews_t = torch.tensor( skews, device="cpu", @@ -865,6 +875,7 @@ def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float], non_blocking=True), dry_sequence_breaker_ids=dry_sequence_breakers_t.to(device=device, non_blocking=True), + dry_ranges=dry_ranges_t.to(device=device, non_blocking=True), skews=skews_t.to(device=device, non_blocking=True), typical_ps=typical_ps_t.to(device=device, non_blocking=True), prompt_tokens=prompt_t.to(device=device, non_blocking=True), diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 60e01837d..8ba10a08e 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -805,6 +805,136 @@ def test_sampler_no_repeat_ngram(seed: int, device: str): "No-repeat-ngram sampling is not deterministic with same seed" +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_dry(device: str): + vocab_size = 8 + + def test_sampling_params(sampling_params: List[SamplingParams]): + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + seq_lens: List[int] = [] + for i in range(2): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={ + 0: SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE, + [1, 2, 3, 1, 2])) + }, + sampling_params=sampling_params[i], + block_tables={0: [1]}, + )) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens=seq_lens, + device=device, + pin_memory=is_pin_memory_available()) + + fake_logits = torch.full((2, vocab_size), + 1e-2, + device=device, + dtype=torch.float16) + fake_logits[:, 3] = 1.0 + + sampler = MockLogitsSampler(fake_logits) + sampler_output = sampler(logits=fake_logits, + sampling_metadata=sampling_metadata) + + generated_tokens = [] + for output in sampler_output: + 
generated_tokens.append(output.samples[0].output_token) + + return generated_tokens + + # Test case 1: DRY disabled (multiplier = 0) + sampling_params_no_dry = SamplingParams( + temperature=0.0, + dry_multiplier=0.0, + ) + + # Test case 2: DRY enabled with full range + sampling_params_full_dry = SamplingParams( + temperature=0.0, + dry_multiplier=1.0, + dry_allowed_length=2, + dry_base=2.0, + dry_range=0, + ) + + sampling_params_limited_dry = SamplingParams( + temperature=0.0, + dry_multiplier=1.0, + dry_allowed_length=2, + dry_base=2.0, + dry_range=3, + ) + + tokens1 = test_sampling_params( + [sampling_params_no_dry, sampling_params_full_dry]) + + assert tokens1[0] == 3, "Without DRY, should choose highest logit token" + assert tokens1[1] != 3, "With full-range DRY, should avoid repeating pattern" # noqa: E501 + + tokens2 = test_sampling_params( + [sampling_params_full_dry, sampling_params_limited_dry]) + + assert tokens2[0] != 3, "Full-range DRY should detect full pattern" + assert tokens2[1] == 3, "Limited-range DRY should only consider recent tokens" # noqa: E501 + + tokens3 = test_sampling_params( + [sampling_params_full_dry, sampling_params_limited_dry]) + assert tokens2 == tokens3, "DRY sampling should be deterministic" + +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_dry_sequence_breakers(device: str): + """Test that DRY respects sequence breakers.""" + vocab_size = 8 + + # 7 is a sequence breaker + input_sequence = [1, 2, 7, 1, 2] + + seq_group_metadata = SequenceGroupMetadata( + request_id="test_0", + is_prompt=True, + seq_data={ + 0: SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE, + input_sequence)) + }, + sampling_params=SamplingParams( + temperature=0.0, + dry_multiplier=1.0, + dry_allowed_length=2, + dry_base=2.0, + dry_range=0, + dry_sequence_breaker_ids=[7], + ), + block_tables={0: [1]}, + ) + + sampling_metadata = SamplingMetadata.prepare( + [seq_group_metadata], + seq_lens=[len(input_sequence)], + query_lens=[len(input_sequence)], + device=device, + pin_memory=is_pin_memory_available()) + + fake_logits = torch.full((1, vocab_size), + 1e-2, + device=device, + dtype=torch.float16) + fake_logits[0, 3] = 1.0 + + sampler = MockLogitsSampler(fake_logits) + sampler_output = sampler(logits=fake_logits, + sampling_metadata=sampling_metadata) + + assert sampler_output[0].samples[0].output_token == 3, \ + "DRY should not detect patterns across sequence breakers" + + @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_sampler_nsigma(seed: int, device: str): From 0035dc42eda87044711eee68f06e8246d76ddbcd Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:17:32 -0800 Subject: [PATCH 13/15] sampler: optimize DRY performance using z-algorithm (#856) * sampler: optimize DRY performance using z-algorithm * misc: trailing whitespace * sampler: apply sequence breakers per-sequence in a batch * sampler: fix typos in DRY (#857) Fix typos --------- Co-authored-by: Mehdi Z. 
<98492916+gitzaidi@users.noreply.github.com> --- aphrodite/modeling/layers/sampler.py | 114 +++++++++++++++------------ 1 file changed, 65 insertions(+), 49 deletions(-) diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py index 7f3ecc996..80880285c 100644 --- a/aphrodite/modeling/layers/sampler.py +++ b/aphrodite/modeling/layers/sampler.py @@ -638,69 +638,85 @@ def _apply_dry( input_ids = torch.cat((input_token_ids, output_token_ids), dim=1) vocab_size = logits.size(-1) + def compute_z_array(s: List[int], end: int, search_start: int) -> List[int]: + """ + Compute Z array using two-pointer technique for linear time complexity + """ + z = [0] * len(s) + right = end - 1 + left = end - 1 + + while right >= search_start: + while left == right and left >= search_start: + if s[right] == s[end]: + break + right -= 1 + left -= 1 + + while left >= search_start and s[left] == s[end - (right - left)]: + z[right] += 1 + left -= 1 + + helper = right + while right > left: + right -= 1 + if left == right: + break + z[right] = min(z[end - (helper - right)], right - left) + if left >= search_start and right - z[right] <= left: + break + + return z + # Process each sequence in the batch for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)): multiplier = multipliers[i].item() if multiplier == 0: continue + seq_breakers = set(sequence_breakers_ids[i].tolist()) + input_ids_list = input_ids_row.tolist() + last_token = input_ids_list[-1] + + if last_token in seq_breakers: + continue + range_limit = ranges[i].item() if range_limit == 0: search_start = 0 else: - search_start = max(0, len(input_ids_row) - range_limit) - - last_token = input_ids_row[-1].item() - - if last_token in sequence_breakers_ids: - continue - - # Find matches of the last token, excluding the last position - # Only look within the specified range - match_indices = (input_ids_row[search_start:-1] == last_token).nonzero() - if len(match_indices) > 0: - match_indices = match_indices + search_start - - match_lengths = {} - - for idx in match_indices: - idx = idx.item() - next_token = input_ids_row[idx + 1].item() - - if next_token in sequence_breakers_ids or next_token >= vocab_size: - continue - - match_length = 1 - - while match_length < 50: - j = idx - match_length - k = len(input_ids_row) - match_length - 1 - if j < search_start or j < 0 or k < 0: - break - - if input_ids_row[j].item() != input_ids_row[k].item(): - # No more matches - break - - if input_ids_row[k].item() in sequence_breakers_ids: - break - - match_length += 1 - - if next_token in match_lengths: - match_lengths[next_token] = max( - match_length, match_lengths[next_token]) - else: - match_lengths[next_token] = match_length - + search_start = max(0, len(input_ids_list) - range_limit) + + # Find max match length based on sequence breakers + max_match_length = 0 + MAX_LENGTH = min(len(input_ids_list), 1000) # Prevent overflow + while (max_match_length < MAX_LENGTH and + input_ids_list[len(input_ids_list) - max_match_length - 1] + not in seq_breakers): + max_match_length += 1 + + z_array = compute_z_array( + input_ids_list, len(input_ids_list) - 1, search_start) + + z_array = [min(length, max_match_length) for length in z_array] + + penalties = {} allowed_length = allowed_lengths[i] - multiplier = multipliers[i] base = bases[i] - for token, match_length in match_lengths.items(): - if match_length >= allowed_length and token < vocab_size: + for idx, match_length in enumerate(z_array[:-1]): + if match_length >= allowed_length: + 
next_token = input_ids_list[idx + 1] + if (next_token >= vocab_size or next_token in + seq_breakers): + continue + penalty = multiplier * (base ** (match_length - allowed_length)) - logits_row[token] -= penalty + penalties[next_token] = max( + penalty, penalties.get(next_token, 0)) + + for token, penalty in penalties.items(): + logits_row[token] -= penalty return logits From 3392b81bf98e3783d03022776d438496326e4b20 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:25:37 -0800 Subject: [PATCH 14/15] sampler: allow parsing sampler order using strings (#858) --- aphrodite/common/sampling_params.py | 45 ++++++++++++++++++++++---- aphrodite/endpoints/openai/protocol.py | 4 +-- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py index 0b2f4a1ae..64fa1b63a 100644 --- a/aphrodite/common/sampling_params.py +++ b/aphrodite/common/sampling_params.py @@ -42,6 +42,30 @@ class SamplerID(IntEnum): QUADRATIC = 12 XTC = 13 + @classmethod + def from_str(cls, value: Union[str, int]) -> "SamplerID": + """Convert string or int to SamplerID enum. + + Args: + value: String name (case-insensitive) or integer value + + Returns: + SamplerID enum value + + Raises: + ValueError: If value cannot be converted to SamplerID + """ + if isinstance(value, int): + return cls(value) + + try: + return cls[value.upper()] + except KeyError as e: + valid_names = [x.name for x in cls] + raise ValueError( + f"Invalid sampler name '{value}'. Must be one of: {valid_names}" + ) from e + LogitsProcessorFunc = Union[Callable[[List[int], torch.Tensor], torch.Tensor], Callable[[List[int], List[int], torch.Tensor], @@ -291,7 +315,7 @@ class SamplingParams( "logprobs": None, "prompt_logprobs": None, "detokenize": True, - "custom_token_bans": [], + "custom_token_bans": None, "skip_special_tokens": True, "spaces_between_special_tokens": True, "include_stop_str_in_output": False, @@ -466,20 +490,27 @@ def _verify_args(self) -> None: return if not isinstance(self.sampler_priority, list): - raise ValueError("sampler_priority must be a list of integers") + raise ValueError( + "sampler_priority must be a list of integers or strings") + try: - provided_samplers = { - SamplerID(x) for x in self.sampler_priority} + self.sampler_priority = [ + SamplerID.from_str(x) for x in self.sampler_priority + ] + provided_samplers = set(self.sampler_priority) except ValueError as e: raise ValueError( - f"Invalid sampler ID in priority list: {e}") from e + f"Invalid sampler ID in priority list: {e}" + ) from e required_samplers = set(SamplerID) if not required_samplers.issubset(provided_samplers): missing = required_samplers - provided_samplers missing_names = [s.name for s in missing] - raise ValueError(f"Missing required samplers in priority list: " - f"{missing_names}") + raise ValueError( + "Missing required samplers in priority list: " + f"{missing_names}" + ) def _verify_beam_search(self) -> None: if self.best_of == 1: diff --git a/aphrodite/endpoints/openai/protocol.py b/aphrodite/endpoints/openai/protocol.py index f866b411f..620d97d42 100644 --- a/aphrodite/endpoints/openai/protocol.py +++ b/aphrodite/endpoints/openai/protocol.py @@ -165,7 +165,7 @@ class ChatCompletionRequest(OpenAIBaseModel): nsigma: Optional[float] = 0.0 skew: Optional[float] = 0.0 custom_token_bans: Optional[List[int]] = None - sampler_priority: Optional[List[int]] = Field( + sampler_priority: Optional[Union[List[int], List[str]]] = Field( 
default=[], validation_alias=AliasChoices("sampler_priority", "sampler_order")) @@ -451,7 +451,7 @@ class CompletionRequest(OpenAIBaseModel): nsigma: Optional[float] = 0.0 skew: Optional[float] = 0.0 custom_token_bans: Optional[List[int]] = None - sampler_priority: Optional[List[int]] = Field( + sampler_priority: Optional[Union[List[int], List[str]]] = Field( default=[], validation_alias=AliasChoices("sampler_priority", "sampler_order")) From 8b8d2ce7e24f3ef3d949f497032b0e6a43cbd7e2 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:51:09 -0800 Subject: [PATCH 15/15] ci: bump aphrodite version to 0.6.4.post1 (#859) --- aphrodite/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aphrodite/version.py b/aphrodite/version.py index 7ba0c1c11..29955888a 100644 --- a/aphrodite/version.py +++ b/aphrodite/version.py @@ -11,4 +11,4 @@ __commit__ = "COMMIT_HASH_PLACEHOLDER" __short_commit__ = "SHORT_COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.4" +__version__ = "0.6.4.post1"
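For completeness, the string-based `sampler_priority` parsing and the `dry_range` parameter introduced above can be combined when constructing `SamplingParams` directly, in the same way the unit tests in this series do. The sketch below is illustrative only; the concrete values are arbitrary assumptions, not recommendations.

```python
# Illustrative sketch, not part of the patch series. Values are arbitrary.
from aphrodite.common.sampling_params import SamplerID, SamplingParams

params = SamplingParams(
    temperature=0.8,
    dry_multiplier=0.8,   # a non-zero multiplier enables DRY
    dry_base=1.75,
    dry_allowed_length=2,
    dry_range=2048,       # only the last 2048 input+output tokens are searched
    # Case-insensitive names are resolved through SamplerID.from_str(); every
    # sampler must be present in the list or _verify_args() raises ValueError.
    sampler_priority=[
        "dry", "penalties", "no_repeat_ngram", "temperature",
        "top_nsigma", "top_p_top_k", "top_a", "min_p", "tfs",
        "eta_cutoff", "epsilon_cutoff", "typical_p", "quadratic", "xtc",
    ],
)

# After validation the list holds SamplerID enum members rather than strings.
print(params.sampler_priority[:3])
```

Over the OpenAI-compatible API, the same request can use `sampler_order` as an alias for `sampler_priority` and `dry_penalty_last_n` as an alias for `dry_range`.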