From 59d1426c29be0b91ffa6cbb0fb881923ceae8798 Mon Sep 17 00:00:00 2001 From: koubaa Date: Mon, 9 Dec 2024 13:13:34 -0600 Subject: [PATCH 1/2] fix algorithm method binding Signed-off-by: koubaa --- python/src/main.cpp | 351 +++++++++++++++++++++----------------------- 1 file changed, 171 insertions(+), 180 deletions(-) diff --git a/python/src/main.cpp b/python/src/main.cpp index 210fbe04..5b272838 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -49,6 +49,83 @@ opAlgoDispatchPyInit(std::shared_ptr& algorithm, } } +namespace pv { + std::vector to_vector(const py::bytes& spirv) { + py::buffer_info info(py::buffer(spirv).request()); + const char* data = reinterpret_cast(info.ptr); + size_t length = static_cast(info.size); + std::vector spirvVec((uint32_t*)data, + (uint32_t*)(data + length)); + return spirvVec; + } + + template + std::shared_ptr algorithm( + kp::Manager& mgr, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const py::array& spec_consts, + const py::array& push_consts) { + std::vector spirvVec = pv::to_vector(spirv); + const py::buffer_info pushInfo = push_consts.request(); + const py::buffer_info specInfo = spec_consts.request(); + std::vector specConstsVec( + (spec_T*)specInfo.ptr, ((spec_T*)specInfo.ptr) + specInfo.size); + std::vector pushConstsVec( + (push_T*)pushInfo.ptr, ((push_T*)pushInfo.ptr) + pushInfo.size); + return mgr.algorithm( + tensors, + spirvVec, + workgroup, + specConstsVec, + pushConstsVec + ); + } + + template + std::shared_ptr algorithm( + kp::Manager& mgr, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const std::vector& specConsts, + const py::array& push_consts) { + std::vector spirvVec = pv::to_vector(spirv); + const py::buffer_info pushInfo = push_consts.request(); + std::vector pushConstsVec( + (push_T*)pushInfo.ptr, ((push_T*)pushInfo.ptr) + pushInfo.size); + return mgr.algorithm( + tensors, + spirvVec, + workgroup, + specConsts, + pushConstsVec + ); + } + + template + std::shared_ptr algorithm( + kp::Manager& mgr, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const py::array& specConsts, + const std::vector& pushConsts) { + std::vector spirvVec = pv::to_vector(spirv); + const py::buffer_info specInfo = specConsts.request(); + std::vector specConstsVec( + (spec_T*)specInfo.ptr, ((spec_T*)specInfo.ptr) + specInfo.size); + return mgr.algorithm( + tensors, + spirvVec, + workgroup, + specConstsVec, + pushConsts + ); + } +} + PYBIND11_MODULE(kp, m) { @@ -455,11 +532,8 @@ PYBIND11_MODULE(kp, m) const kp::Workgroup& workgroup, const std::vector& spec_consts, const std::vector& push_consts) { - py::buffer_info info(py::buffer(spirv).request()); - const char* data = reinterpret_cast(info.ptr); - size_t length = static_cast(info.size); - std::vector spirvVec((uint32_t*)data, - (uint32_t*)(data + length)); + std::vector spirvVec = pv::to_vector(spirv); + KP_LOG_DEBUG("Kompute Python Manager creating Algorithm."); return self.algorithm( tensors, spirvVec, workgroup, spec_consts, push_consts); }, @@ -469,6 +543,66 @@ PYBIND11_MODULE(kp, m) py::arg("workgroup") = kp::Workgroup(), py::arg("spec_consts") = std::vector(), py::arg("push_consts") = std::vector()) + .def( + "algorithm", + [](kp::Manager& self, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const py::array& spec_consts, + const std::vector& push_consts) { + KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with " + "spec consts data size {} dtype {}", + spec_consts.size(), + std::string(py::str(spec_consts.dtype()))); + if (spec_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (spec_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (spec_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (spec_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } + // If reach then no valid dtype supported + throw std::runtime_error("Kompute Python no valid dtype supported"); + }, + DOC(kp, Manager, algorithm), + py::arg("tensors"), + py::arg("spirv"), + py::arg("workgroup") = kp::Workgroup(), + py::arg("spec_consts") = std::vector(), + py::arg("push_consts") = std::vector()) + .def( + "algorithm", + [](kp::Manager& self, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const std::vector& spec_consts, + const py::array& push_consts) { + KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with " + "push consts data size {} dtype {}", + push_consts.size(), + std::string(py::str(push_consts.dtype()))); + if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } + // If reach then no valid dtype supported + throw std::runtime_error("Kompute Python no valid dtype supported"); + }, + DOC(kp, Manager, algorithm), + py::arg("tensors"), + py::arg("spirv"), + py::arg("workgroup") = kp::Workgroup(), + py::arg("spec_consts") = std::vector(), + py::arg("push_consts") = std::vector()) .def( "algorithm", [np](kp::Manager& self, @@ -477,14 +611,6 @@ PYBIND11_MODULE(kp, m) const kp::Workgroup& workgroup, const py::array& spec_consts, const py::array& push_consts) { - py::buffer_info info(py::buffer(spirv).request()); - const char* data = reinterpret_cast(info.ptr); - size_t length = static_cast(info.size); - std::vector spirvVec((uint32_t*)data, - (uint32_t*)(data + length)); - - const py::buffer_info pushInfo = push_consts.request(); - const py::buffer_info specInfo = spec_consts.request(); KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with " "push consts data size {} dtype {} and spec const " @@ -497,179 +623,44 @@ PYBIND11_MODULE(kp, m) // We have to iterate across a combination of parameters due to the // lack of support for templating if (spec_consts.dtype().is(py::dtype::of())) { - std::vector specConstsVec( - (float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype().is(py::dtype::of())) { - std::vector pushConstsVec((float*)pushInfo.ptr, - ((float*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specConstsVec, - pushConstsVec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushConstsVec( - (int32_t*)pushInfo.ptr, - ((int32_t*)pushInfo.ptr) + pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specConstsVec, - pushConstsVec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushConstsVec( - (uint32_t*)pushInfo.ptr, - ((uint32_t*)pushInfo.ptr) + pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specConstsVec, - pushConstsVec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushConstsVec((double*)pushInfo.ptr, - ((double*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specConstsVec, - pushConstsVec); + if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); } } else if (spec_consts.dtype().is(py::dtype::of())) { - std::vector specconstsvec((int32_t*)specInfo.ptr, - ((int32_t*)specInfo.ptr) + - specInfo.size); - if (spec_consts.dtype().is(py::dtype::of())) { - std::vector pushconstsvec((float*)pushInfo.ptr, - ((float*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec( - (int32_t*)pushInfo.ptr, - ((int32_t*)pushInfo.ptr) + pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec( - (uint32_t*)pushInfo.ptr, - ((uint32_t*)pushInfo.ptr) + pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec((double*)pushInfo.ptr, - ((double*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); + if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); } } else if (spec_consts.dtype().is(py::dtype::of())) { - std::vector specconstsvec((uint32_t*)specInfo.ptr, - ((uint32_t*)specInfo.ptr) + - specInfo.size); - if (spec_consts.dtype().is(py::dtype::of())) { - std::vector pushconstsvec((float*)pushInfo.ptr, - ((float*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec( - (int32_t*)pushInfo.ptr, - ((int32_t*)pushInfo.ptr) + pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec( - (uint32_t*)pushInfo.ptr, - ((uint32_t*)pushInfo.ptr) + pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec((double*)pushInfo.ptr, - ((double*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); + if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); } } else if (spec_consts.dtype().is(py::dtype::of())) { - std::vector specconstsvec((double*)specInfo.ptr, - ((double*)specInfo.ptr) + - specInfo.size); - if (spec_consts.dtype().is(py::dtype::of())) { - std::vector pushconstsvec((float*)pushInfo.ptr, - ((float*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec((int32_t*)pushInfo.ptr, - ((int32_t*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec((uint32_t*)pushInfo.ptr, - ((uint32_t*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of())) { - std::vector pushconstsvec((double*)pushInfo.ptr, - ((double*)pushInfo.ptr) + - pushInfo.size); - return self.algorithm(tensors, - spirvVec, - workgroup, - specconstsvec, - pushconstsvec); + if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + } else if (push_consts.dtype().is(py::dtype::of())) { + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); } } // If reach then no valid dtype supported From 8e0289b2e671ff7506844297c305aa3dfbb910e0 Mon Sep 17 00:00:00 2001 From: koubaa Date: Tue, 10 Dec 2024 11:09:47 -0600 Subject: [PATCH 2/2] noconvert pyarray arguments Signed-off-by: koubaa --- python/src/main.cpp | 46 ++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/python/src/main.cpp b/python/src/main.cpp index 086ce7e2..953dd789 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -563,19 +563,21 @@ PYBIND11_MODULE(kp, m) const std::vector>& tensors, const py::bytes& spirv, const kp::Workgroup& workgroup, - const std::vector& spec_consts, - const std::vector& push_consts) { + const py::list& spec_consts, + const py::list& push_consts) { std::vector spirvVec = pv::to_vector(spirv); KP_LOG_DEBUG("Kompute Python Manager creating Algorithm."); + auto pushConstsVec = push_consts.cast>(); + auto specConstsVec = spec_consts.cast>(); return self.algorithm( - tensors, spirvVec, workgroup, spec_consts, push_consts); + tensors, spirvVec, workgroup, specConstsVec, pushConstsVec); }, DOC(kp, Manager, algorithm), py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), - py::arg("spec_consts") = std::vector(), - py::arg("push_consts") = std::vector()) + py::arg("spec_consts") = py::list(), + py::arg("push_consts") = py::list()) .def( "algorithm", [](kp::Manager& self, @@ -583,19 +585,20 @@ PYBIND11_MODULE(kp, m) const py::bytes& spirv, const kp::Workgroup& workgroup, const py::array& spec_consts, - const std::vector& push_consts) { + const py::list& push_consts) { KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with " "spec consts data size {} dtype {}", spec_consts.size(), std::string(py::str(spec_consts.dtype()))); + auto pushConstsVec = push_consts.cast>(); if (spec_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, pushConstsVec); } else if (spec_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, pushConstsVec); } else if (spec_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, pushConstsVec); } else if (spec_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, pushConstsVec); } // If reach then no valid dtype supported throw std::runtime_error("Kompute Python no valid dtype supported"); @@ -604,28 +607,29 @@ PYBIND11_MODULE(kp, m) py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), - py::arg("spec_consts") = std::vector(), - py::arg("push_consts") = std::vector()) + py::arg("spec_consts").noconvert(true) = py::array(), + py::arg("push_consts") = py::list()) .def( "algorithm", [](kp::Manager& self, const std::vector>& tensors, const py::bytes& spirv, const kp::Workgroup& workgroup, - const std::vector& spec_consts, + const py::list& spec_consts, const py::array& push_consts) { KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with " "push consts data size {} dtype {}", push_consts.size(), std::string(py::str(push_consts.dtype()))); + auto specConstsVec = spec_consts.cast>(); if (push_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, specConstsVec, push_consts); } else if (push_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, specConstsVec, push_consts); } else if (push_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, specConstsVec, push_consts); } else if (push_consts.dtype().is(py::dtype::of())) { - return pv::algorithm(self, tensors, spirv, workgroup, spec_consts, push_consts); + return pv::algorithm(self, tensors, spirv, workgroup, specConstsVec, push_consts); } // If reach then no valid dtype supported throw std::runtime_error("Kompute Python no valid dtype supported"); @@ -634,8 +638,8 @@ PYBIND11_MODULE(kp, m) py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), - py::arg("spec_consts") = std::vector(), - py::arg("push_consts") = std::vector()) + py::arg("spec_consts") = py::list(), + py::arg("push_consts").noconvert(true) = py::array()) .def( "algorithm", [np](kp::Manager& self, @@ -703,8 +707,8 @@ PYBIND11_MODULE(kp, m) py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), - py::arg("spec_consts") = std::vector(), - py::arg("push_consts") = std::vector()) + py::arg("spec_consts").noconvert(true) = py::array(), + py::arg("push_consts").noconvert(true) = py::array()) .def( "list_devices", [](kp::Manager& self) {