Commit

Merge branch 'develop' into fix_max_seq
gshtras authored Nov 27, 2024
2 parents 7f1b5e3 + 529cefe commit aa3f760
Showing 3 changed files with 18 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/clang-format.yml
@@ -6,6 +6,7 @@ on:
   push:
     branches:
       - main
+      - develop
     paths:
       - '**/*.h'
       - '**/*.cpp'
@@ -15,6 +16,7 @@ on:
   pull_request:
     branches:
       - main
+      - develop
     paths:
       - '**/*.h'
       - '**/*.cpp'

29 changes: 15 additions & 14 deletions csrc/activation_kernels.cu
@@ -93,20 +93,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 
 // Launch activation and gating kernel.
 #ifdef USE_ROCM
-#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
-  int d = input.size(-1) / 2; \
-  int64_t num_tokens = input.numel() / input.size(-1); \
-  dim3 grid(num_tokens); \
-  dim3 block(std::min(d, 1024)); \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-      input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
-        vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
-            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
-                                         input.data_ptr<scalar_t>(), d, \
-                                         1.0 / (*scale.data_ptr<float>())); \
-      });
+#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
+  int d = input.size(-1) / 2; \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  dim3 grid(num_tokens); \
+  dim3 block(std::min(d, 1024)); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  VLLM_DISPATCH_FLOATING_TYPES( \
+      input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
+        vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+            <<<grid, block, 0, stream>>>( \
+                out.data_ptr<c10::Float8_e4m3fnuz>(), \
+                input.data_ptr<scalar_t>(), d, \
+                1.0 / (*scale.data_ptr<float>())); \
+      });
 #endif
 
 void silu_and_mul(torch::Tensor& out,  // [..., d]
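
Note: the activation_kernels.cu hunk above is a formatting-only change; the launch geometry it encodes is unchanged: one block per token, up to 1024 threads covering d = input.size(-1) / 2 gated elements, with the output written as FP8 scaled by 1.0 / (*scale). The sketch below is a simplified, hypothetical illustration of a kernel with that shape, not the vLLM implementation; the kernel name, float-only input, and use of __nv_fp8_e4m3 are assumptions for the example (the ROCm build in this diff uses c10::Float8_e4m3fnuz and dispatches over scalar types).

#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <math.h>
#include <stdint.h>

// SiLU (sigmoid-linear unit), a typical activation for a gated MLP.
__device__ __forceinline__ float silu(float x) {
  return x / (1.0f + expf(-x));
}

// Hypothetical kernel: each input row is laid out as [activation | gate],
// each half of width d; the output row has width d and is stored as FP8
// after multiplying by inv_scale (mirroring 1.0 / (*scale) in the macro).
__global__ void scaled_silu_and_mul_sketch(__nv_fp8_e4m3* __restrict__ out,
                                           const float* __restrict__ input,
                                           int d, float inv_scale) {
  const int64_t token = blockIdx.x;          // grid = one block per token
  const float* row = input + token * 2 * d;  // 2*d values per input row
  for (int i = threadIdx.x; i < d; i += blockDim.x) {
    const float act = silu(row[i]);          // first half: activation input
    const float gate = row[d + i];           // second half: gate
    out[token * d + i] = __nv_fp8_e4m3(act * gate * inv_scale);
  }
}

// Launched with the same geometry the macro encodes:
//   dim3 grid(num_tokens);
//   dim3 block(std::min(d, 1024));
//   scaled_silu_and_mul_sketch<<<grid, block, 0, stream>>>(out, in, d, inv_scale);
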
2 changes: 1 addition & 1 deletion csrc/layernorm_kernels.cu
@@ -247,7 +247,7 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
     LAUNCH_RMS_NORM(0);
   }
 #else
-  LAUNCH_RMS_NORM(0);
+  LAUNCH_RMS_NORM(0);
 #endif
 }
