review comments and format
Signed-off-by: Sage Moore <[email protected]>
SageMoore committed Dec 19, 2024
1 parent bfdac35 commit 8514b0e
Showing 3 changed files with 10 additions and 11 deletions.
14 changes: 9 additions & 5 deletions csrc/quantization/activation_kernels.cu
@@ -29,18 +29,22 @@ __global__ void act_and_mul_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const float* scale, const int d) {
const int64_t token_idx = blockIdx.x;
const int32_t blocks_per_token = gridDim.y;

const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

// We don't expect the hidden dimension to exceed 32 bits so int32 should
// be safe here.
const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
const int32_t elems_per_block =
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
const int32_t block_start = blockIdx.y * elems_per_block;
int32_t block_end = block_start + elems_per_block;
block_end = block_end > d ? d : block_end;

// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
// is very large
const int64_t token_idx = blockIdx.x;
const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
FP8_TYPE* __restrict__ out_ptr = out + token_idx * d;
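
The partitioning in the hunk above relies on two integer helpers, div_ceil and round_to_next_multiple_of, that are referenced but not shown in this diff. A minimal sketch of how such helpers are conventionally defined, plus a worked example of the per-block range computation, is given below; these definitions are assumptions for illustration, not the repository's actual code.

#include <cstdint>

// Assumed shapes of the helpers used above; the real definitions live
// elsewhere in csrc and may differ in detail.
__host__ __device__ inline int32_t div_ceil(int32_t a, int32_t b) {
  return (a + b - 1) / b;
}

__host__ __device__ inline int32_t round_to_next_multiple_of(int32_t a,
                                                             int32_t b) {
  return div_ceil(a, b) * b;
}

// Worked example: d = 4096, blocks_per_token (gridDim.y) = 4, bf16 input.
//   elems_per_128bit_load = (128 / 8) / sizeof(scalar_t)      = 16 / 2 = 8
//   tgt_elems_per_block   = div_ceil(4096, 4)                  = 1024
//   elems_per_block       = round_to_next_multiple_of(1024, 8) = 1024
// Each blockIdx.y then covers [block_start, min(block_start + 1024, d)),
// so every block's slice starts on a 128-bit boundary.
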
@@ -105,12 +109,12 @@ __global__ void act_and_mul_quant_kernel(
scale.data_ptr<float>(), d); \
});

void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input,
torch::Tensor& scale) // [..., 2 * d]
{
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
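
Functionally, the fused op quantizes silu(x) * y to FP8, where x and y are the two halves of the last input dimension (matching x_ptr and y_ptr in the kernel). A scalar host-side reference sketch follows; the direction of the scale (dividing rather than multiplying) follows the usual static-scale convention and is an assumption, as is the fake_quant_e4m3 clamp standing in for the real FP8 conversion.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

namespace {

float silu(float v) { return v / (1.0f + std::exp(-v)); }

// Stand-in for a real FP8 E4M3 conversion: clamp to the finite range
// (+/-448) and treat the rounding as exact.
float fake_quant_e4m3(float v) {
  return std::max(-448.0f, std::min(448.0f, v));
}

}  // namespace

// input is [num_tokens, 2 * d] flattened; the result is [num_tokens, d].
std::vector<float> silu_and_mul_quant_ref(const std::vector<float>& input,
                                          std::size_t d, float scale) {
  const std::size_t num_tokens = input.size() / (2 * d);
  std::vector<float> out(num_tokens * d);
  for (std::size_t t = 0; t < num_tokens; ++t) {
    const float* x = input.data() + t * 2 * d;  // activation half
    const float* y = x + d;                     // gate half
    for (std::size_t i = 0; i < d; ++i) {
      out[t * d + i] = fake_quant_e4m3(silu(x[i]) * y[i] / scale);
    }
  }
  return out;
}
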
5 changes: 0 additions & 5 deletions tests/compile/test_functionalization.py
@@ -49,8 +49,6 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,

config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_reshape=True)
# compilation_config = CompilationConfig(level=3,
# custom_ops=["+silu_and_mul"])
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
act_quant_fusion_pass = ActivationQuantFusionPass.instance(config)
@@ -85,9 +83,6 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,

gen_no_func = llm.generate(prompts, sampling_params)

# print(
# backend_func.graph_pre_pass.python_code(root_module="self",
# verbose=True).src)
for output_func, output_no_func in zip(gen_func, gen_no_func):
assert output_func.outputs[0].text == output_no_func.outputs[0].text

2 changes: 1 addition & 1 deletion tests/kernels/test_fused_quant_activation.py
@@ -8,7 +8,7 @@
DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [torch.float8_e4m3fn]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [32, 64, 128, 2048, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
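
The updated HIDDEN_SIZES list is notable mainly because 1562 is not a multiple of the 8-element 128-bit vector width for fp16/bf16, while the old list was entirely 8-aligned, so it presumably exercises the kernel's non-vectorized remainder handling; that path is elided from this diff, so this is an inference rather than something the hunk shows. A quick standalone alignment check, for illustration only:

#include <cstdio>

int main() {
  // 128-bit loads of a 2-byte dtype (fp16/bf16) move 8 elements at a time.
  const int elems_per_128bit_load = (128 / 8) / 2;
  const int hidden_sizes[] = {16, 48, 128, 1562, 4096};
  for (int d : hidden_sizes) {
    std::printf("d=%4d -> %s\n", d,
                d % elems_per_128bit_load == 0 ? "128-bit aligned"
                                               : "has a scalar remainder");
  }
  return 0;
}
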
