forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Gradlib torch extension cmake (#282)
* Converted gradlib into a cmake project whilke using TORCH_LIBRARY binding rather than pybind11 * Made gradlib a vllm _gradlib_C module * Reusing binding includes from core vllm * The extension is created by the wrapper * Remove gradlib mentions from the dockerfile
- Loading branch information
Showing
17 changed files
with
363 additions
and
1,163 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#pragma once | ||
|
||
#include <torch/all.h> | ||
|
||
void hipb_create_extension(); | ||
void hipb_destroy_extension(); | ||
torch::Tensor hipb_mm(const torch::Tensor& mat1, const torch::Tensor& mat2, | ||
const int64_t solution_index, | ||
at::optional<torch::Tensor> bias = at::nullopt, | ||
at::optional<c10::ScalarType> out_dtype = at::nullopt, | ||
at::optional<torch::Tensor> scale1 = at::nullopt, | ||
at::optional<torch::Tensor> scale2 = at::nullopt, | ||
at::optional<torch::Tensor> scaleOut = at::nullopt); | ||
|
||
std::vector<int64_t> hipb_findallsols(const torch::Tensor& mat1, | ||
const torch::Tensor& mat2, | ||
at::optional<torch::Tensor> bias, | ||
at::optional<c10::ScalarType> out_dtype); | ||
|
||
void rocb_create_extension(); | ||
void rocb_destroy_extension(); | ||
torch::Tensor RocSolIdxBlas(const torch::Tensor& mat1, | ||
const torch::Tensor& mat2, | ||
const int64_t solution_index); | ||
|
||
std::vector<int64_t> RocFindAllSolIdxBlas(const torch::Tensor& mat1, | ||
const torch::Tensor& mat2); |
Oops, something went wrong.