Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[kernel] switch from PYBIND11 to TORCH_LIBRARY #617

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions exllamav2/exllamav2_ext/config.h

This file was deleted.

30 changes: 16 additions & 14 deletions exllamav2/exllamav2_ext/ext_cache.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
Expand All @@ -15,7 +15,8 @@

#include "cpp/util.h"

void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width)
void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor,
int64_t batch_size, int64_t offset, int64_t width)
{
TORCH_CHECK_DTYPE(in_tensor, kHalf);
TORCH_CHECK_DTYPE(out_tensor, kUInt8);
Expand Down Expand Up @@ -46,7 +47,8 @@ void fp16_to_fp8(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_si
);
}

void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor, int batch_size, int offset, int width)
void fp8_to_fp16(torch::Tensor in_tensor, torch::Tensor out_tensor,
int64_t batch_size, int64_t offset, int64_t width)
{
TORCH_CHECK_DTYPE(in_tensor, kUInt8);
TORCH_CHECK_DTYPE(out_tensor, kHalf);
Expand Down Expand Up @@ -85,15 +87,15 @@ void fp16_to_q_kv
torch::Tensor v_in,
torch::Tensor v_out,
torch::Tensor v_scales,
int batch_size,
int offset,
int width,
int page_size,
int64_t batch_size,
int64_t offset,
int64_t width,
int64_t page_size,
torch::Tensor cache_seqlens,
torch::Tensor block_table,
torch::Tensor cal_k,
torch::Tensor cal_v,
int wbits
int64_t wbits
)
{
TORCH_CHECK_DTYPE(k_in, kHalf);
Expand Down Expand Up @@ -193,15 +195,15 @@ void q_to_fp16_kv
torch::Tensor v_in,
torch::Tensor v_out,
torch::Tensor v_scales,
int batch_size,
int offset,
int width,
int page_size,
int64_t batch_size,
int64_t offset,
int64_t width,
int64_t page_size,
torch::Tensor cache_seqlens,
torch::Tensor block_table,
torch::Tensor cal_k,
torch::Tensor cal_v,
int wbits
int64_t wbits
)
{
TORCH_CHECK_DTYPE(k_in, kUInt8);
Expand Down Expand Up @@ -310,7 +312,7 @@ int count_match
(
torch::Tensor a,
torch::Tensor b,
int max_a
int64_t max_a
)
{
uint64_t* pa = (uint64_t*) a.data_ptr();
Expand Down
54 changes: 0 additions & 54 deletions exllamav2/exllamav2_ext/ext_cache.h

This file was deleted.

4 changes: 2 additions & 2 deletions exllamav2/exllamav2_ext/ext_element.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
Expand All @@ -17,7 +17,7 @@
void softcap_
(
torch::Tensor x,
float scale
double scale
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
Expand Down
6 changes: 0 additions & 6 deletions exllamav2/exllamav2_ext/ext_element.h

This file was deleted.

6 changes: 3 additions & 3 deletions exllamav2/exllamav2_ext/ext_gemm.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
Expand All @@ -20,8 +20,8 @@ void gemm_half_half_half
torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
const float alpha,
const float beta,
const double alpha,
const double beta,
bool force_cublas
)
{
Expand Down
10 changes: 0 additions & 10 deletions exllamav2/exllamav2_ext/ext_gemm.h

This file was deleted.

2 changes: 1 addition & 1 deletion exllamav2/exllamav2_ext/ext_hadamard.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <cstdint>
#include <cstdio>
#include <pybind11/pybind11.h>
Expand Down
10 changes: 0 additions & 10 deletions exllamav2/exllamav2_ext/ext_hadamard.h

This file was deleted.

16 changes: 8 additions & 8 deletions exllamav2/exllamav2_ext/ext_norm.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
Expand All @@ -25,7 +25,7 @@ void rms_norm
torch::Tensor x,
torch::Tensor w,
torch::Tensor y,
float epsilon
double epsilon
)
{
bool input_fp32 = x.dtype() == torch::kFloat;
Expand Down Expand Up @@ -61,7 +61,7 @@ void rms_norm_tp
std::vector<torch::Tensor> x,
std::vector<torch::Tensor> w,
std::vector<torch::Tensor> y,
float epsilon,
double epsilon,
uintptr_t tp_context
)
{
Expand Down Expand Up @@ -96,7 +96,7 @@ void rms_norm_
(
torch::Tensor x,
torch::Tensor w,
float epsilon
double epsilon
)
{
rms_norm(x, w, x, epsilon);
Expand All @@ -111,7 +111,7 @@ void layer_norm
torch::Tensor w,
torch::Tensor b,
torch::Tensor y,
float epsilon
double epsilon
)
{
TORCH_CHECK_DTYPE(x, kHalf);
Expand Down Expand Up @@ -147,7 +147,7 @@ void layer_norm_
torch::Tensor x,
torch::Tensor w,
torch::Tensor b,
float epsilon
double epsilon
)
{
layer_norm(x, w, b, x, epsilon);
Expand All @@ -162,7 +162,7 @@ void head_norm
torch::Tensor w,
torch::Tensor b,
torch::Tensor y,
float epsilon
double epsilon
)
{
TORCH_CHECK_DTYPE(x, kHalf);
Expand Down Expand Up @@ -202,7 +202,7 @@ void head_norm_
torch::Tensor x,
torch::Tensor w,
torch::Tensor b,
float epsilon
double epsilon
)
{
head_norm(x, w, b, x, epsilon);
Expand Down
61 changes: 0 additions & 61 deletions exllamav2/exllamav2_ext/ext_norm.h

This file was deleted.

Loading