
Commit

address pr comments
Signed-off-by: Alex McCaskey <[email protected]>
amccaskey committed Oct 16, 2024
1 parent 59f884c commit a390b7c
Showing 9 changed files with 58 additions and 62 deletions.

include/cudaq/Optimizer/Transforms/Passes.h (2 changes: 1 addition & 1 deletion)
@@ -46,7 +46,7 @@ std::unique_ptr<mlir::Pass> createRaiseToAffinePass();
 std::unique_ptr<mlir::Pass> createUnwindLoweringPass();
 
 std::unique_ptr<mlir::Pass>
-createPySynthCallableBlockArgs(const std::vector<std::string> &,
+createPySynthCallableBlockArgs(const llvm::SmallVector<llvm::StringRef> &,
                                bool removeBlockArg = false);
 inline std::unique_ptr<mlir::Pass> createPySynthCallableBlockArgs() {
   return createPySynthCallableBlockArgs({}, false);
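
Note for callers: the factory now takes an llvm::SmallVector of llvm::StringRef rather than a std::vector of std::string, so call sites that collect names as std::string must wrap them first. A minimal calling sketch (the kernel names and the buildPipeline function are placeholders, not code from this commit):

    #include "cudaq/Optimizer/Transforms/Passes.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ADT/StringRef.h"

    #include <string>
    #include <vector>

    void buildPipeline() {
      // Names as typically collected from Python-side callable arguments.
      std::vector<std::string> callableNames = {"statePrep", "oracle"};
      // StringRef is a non-owning view; callableNames must stay alive for
      // as long as the pass can run.
      llvm::SmallVector<llvm::StringRef> refs(callableNames.begin(),
                                              callableNames.end());
      auto pass = cudaq::opt::createPySynthCallableBlockArgs(refs);
      (void)pass;
    }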

lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp (61 changes: 29 additions & 32 deletions)
@@ -7,6 +7,7 @@
  ******************************************************************************/
 
 #include "PassDetails.h"
+#include "cudaq/Optimizer/Builder/Runtime.h"
 #include "cudaq/Optimizer/Dialect/CC/CCOps.h"
 #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
 #include "cudaq/Optimizer/Transforms/Passes.h"
@@ -22,12 +23,13 @@ namespace {
 
 class ReplaceCallIndirect : public OpConversionPattern<func::CallIndirectOp> {
 public:
-  const std::vector<std::string> &names;
-  const std::map<std::size_t, std::size_t> &blockArgToNameMap;
+  const SmallVector<StringRef> &names;
+  // const llvm::DenseMap<std::size_t, std::size_t>& blockArgToNameMap;
+  llvm::DenseMap<std::size_t, std::size_t> &blockArgToNameMap;
 
   ReplaceCallIndirect(MLIRContext *ctx,
-                      const std::vector<std::string> &functionNames,
-                      const std::map<std::size_t, std::size_t> &map)
+                      const SmallVector<StringRef> &functionNames,
+                      llvm::DenseMap<std::size_t, std::size_t> &map)
       : OpConversionPattern<func::CallIndirectOp>(ctx), names(functionNames),
         blockArgToNameMap(map) {}
@@ -41,13 +43,11 @@ class ReplaceCallIndirect : public OpConversionPattern<func::CallIndirectOp> {
     if (auto blockArg =
             dyn_cast<BlockArgument>(ccCallableFunc.getOperand())) {
       auto argIdx = blockArg.getArgNumber();
-      auto replacementName = names[blockArgToNameMap.at(argIdx)];
+      auto replacementName = names[blockArgToNameMap[argIdx]];
       auto replacement = module.lookupSymbol<func::FuncOp>(
-          "__nvqpp__mlirgen__" + replacementName);
-      if (!replacement) {
-        op.emitError("Invalid replacement function " + replacementName);
+          cudaq::runtime::cudaqGenPrefixName + replacementName.str());
+      if (!replacement)
         return failure();
-      }
 
       rewriter.replaceOpWithNewOp<func::CallOp>(op, replacement,
                                                 adaptor.getCalleeOperands());
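
One behavioral consequence of the map swap in this hunk: std::map::at throws on a missing key, while llvm::DenseMap::operator[] default-constructs a value (here, 0) for a key that is absent, which is also why the member reference is no longer const. A self-contained sketch of the difference (map contents are illustrative):

    #include "llvm/ADT/DenseMap.h"
    #include <cassert>
    #include <cstddef>

    int main() {
      llvm::DenseMap<std::size_t, std::size_t> m;
      m[2] = 0; // block argument 2 maps to the first callable name

      // lookup() is non-mutating: value-initialized result, no insertion.
      assert(m.lookup(7) == 0 && m.count(7) == 0);

      // operator[] inserts a default-constructed value for a missing key.
      std::size_t v = m[7];
      assert(v == 0 && m.count(7) == 1);
      return 0;
    }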
@@ -62,12 +62,12 @@
 
 class ReplaceCallCallable
     : public OpConversionPattern<cudaq::cc::CallCallableOp> {
 public:
-  const std::vector<std::string> &names;
-  const std::map<std::size_t, std::size_t> &blockArgToNameMap;
+  const SmallVector<StringRef> &names;
+  llvm::DenseMap<std::size_t, std::size_t> &blockArgToNameMap;
 
   ReplaceCallCallable(MLIRContext *ctx,
-                      const std::vector<std::string> &functionNames,
-                      const std::map<std::size_t, std::size_t> &map)
+                      const SmallVector<StringRef> &functionNames,
+                      llvm::DenseMap<std::size_t, std::size_t> &map)
       : OpConversionPattern<cudaq::cc::CallCallableOp>(ctx),
         names(functionNames), blockArgToNameMap(map) {}
@@ -78,13 +78,11 @@ class ReplaceCallCallable
     auto module = op->getParentOp()->getParentOfType<ModuleOp>();
     if (auto blockArg = dyn_cast<BlockArgument>(callableOperand)) {
       auto argIdx = blockArg.getArgNumber();
-      auto replacementName = names[blockArgToNameMap.at(argIdx)];
+      auto replacementName = names[blockArgToNameMap[argIdx]];
       auto replacement = module.lookupSymbol<func::FuncOp>(
-          "__nvqpp__mlirgen__" + replacementName);
-      if (!replacement) {
-        op.emitError("Invalid replacement function " + replacementName);
+          cudaq::runtime::cudaqGenPrefixName + replacementName.str());
+      if (!replacement)
         return failure();
-      }
 
       rewriter.replaceOpWithNewOp<func::CallOp>(op, replacement,
                                                 adaptor.getArgs());
@@ -96,11 +94,11 @@
 
 class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
 public:
-  const std::vector<std::string> &names;
-  const std::map<std::size_t, std::size_t> &blockArgToNameMap;
+  const SmallVector<StringRef> &names;
+  llvm::DenseMap<std::size_t, std::size_t> &blockArgToNameMap;
   UpdateQuakeApplyOp(MLIRContext *ctx,
-                     const std::vector<std::string> &functionNames,
-                     const std::map<std::size_t, std::size_t> &map)
+                     const SmallVector<StringRef> &functionNames,
+                     llvm::DenseMap<std::size_t, std::size_t> &map)
       : OpConversionPattern<quake::ApplyOp>(ctx), names(functionNames),
         blockArgToNameMap(map) {}
@@ -112,13 +110,11 @@ class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
     auto ctx = op.getContext();
     if (auto blockArg = dyn_cast<BlockArgument>(callableOperand)) {
       auto argIdx = blockArg.getArgNumber();
-      auto replacementName = names[blockArgToNameMap.at(argIdx)];
+      auto replacementName = names[blockArgToNameMap[argIdx]];
       auto replacement = module.lookupSymbol<func::FuncOp>(
-          "__nvqpp__mlirgen__" + replacementName);
-      if (!replacement) {
-        op.emitError("Invalid replacement function " + replacementName);
+          "__nvqpp__mlirgen__" + replacementName.str());
+      if (!replacement)
         return failure();
-      }
 
       rewriter.replaceOpWithNewOp<quake::ApplyOp>(
           op, TypeRange{}, FlatSymbolRefAttr::get(ctx, replacement.getName()),
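
The .str() calls added across these hunks are required, not cosmetic: names now holds llvm::StringRef, and with no implicit StringRef-to-std::string conversion, the literal-plus-name concatenation must materialize an owned std::string explicitly. A compilable sketch:

    #include "llvm/ADT/StringRef.h"
    #include <cassert>
    #include <string>

    int main() {
      llvm::StringRef kernel = "statePrep";
      // "__nvqpp__mlirgen__" + kernel does not build a std::string;
      // .str() makes the owned copy that operator+ can consume.
      std::string symbol = "__nvqpp__mlirgen__" + kernel.str();
      assert(symbol == "__nvqpp__mlirgen__statePrep");
      return 0;
    }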
@@ -136,8 +132,8 @@ class PySynthCallableBlockArgs
   bool removeBlockArg = false;
 
 public:
-  std::vector<std::string> names;
-  PySynthCallableBlockArgs(const std::vector<std::string> &_names, bool remove)
+  SmallVector<StringRef> names;
+  PySynthCallableBlockArgs(const SmallVector<StringRef> &_names, bool remove)
       : removeBlockArg(remove), names(_names) {}
 
   void runOnOperation() override {
@@ -147,7 +143,7 @@
 
     std::size_t numCallableBlockArgs = 0;
     // need to map blockArgIdx -> counter(0,1,2,...)
-    std::map<std::size_t, std::size_t> blockArgToNamesMap;
+    llvm::DenseMap<std::size_t, std::size_t> blockArgToNamesMap;
     for (std::size_t i = 0, k = 0; auto ty : op.getFunctionType().getInputs()) {
       if (isa<cudaq::cc::CallableType>(ty)) {
         numCallableBlockArgs++;
@@ -201,7 +197,8 @@
 };
 } // namespace
 
-std::unique_ptr<Pass> cudaq::opt::createPySynthCallableBlockArgs(
-    const std::vector<std::string> &names, bool removeBlockArg) {
+std::unique_ptr<Pass>
+cudaq::opt::createPySynthCallableBlockArgs(const SmallVector<StringRef> &names,
+                                           bool removeBlockArg) {
   return std::make_unique<PySynthCallableBlockArgs>(names, removeBlockArg);
 }

python/CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -65,4 +65,4 @@ if(CUDAQ_BUILD_TESTS)
   endif()
 endif()
 
-add_subdirectory(runtime/interop)
\ No newline at end of file
+add_subdirectory(runtime/interop)

python/extension/CUDAQuantumExtension.cpp (4 changes: 2 additions & 2 deletions)
@@ -250,7 +250,7 @@ PYBIND11_MODULE(_quakeDialects, m) {
   cudaqRuntime.def(
       "isRegisteredDeviceModule",
       [](const std::string &name) {
-        return cudaq::isRegisteredDeviceModule(name);
+        return cudaq::python::isRegisteredDeviceModule(name);
       },
       "Return true if the input name (mod1.mod2...) is a registered C++ device "
       "module.");
@@ -261,7 +261,7 @@
           const std::string &moduleName) -> std::optional<std::string> {
         std::tuple<std::string, std::string> ret;
         try {
-          ret = cudaq::getDeviceKernel(moduleName);
+          ret = cudaq::python::getDeviceKernel(moduleName);
         } catch (...) {
           return std::nullopt;
         }

python/runtime/cudaq/platform/py_alt_launch_kernel.cpp (8 changes: 5 additions & 3 deletions)
@@ -96,8 +96,8 @@ jitAndCreateArgs(const std::string &name, MlirModule module,
   auto cloned = mod.clone();
   auto context = cloned.getContext();
   PassManager pm(context);
-  pm.addNestedPass<func::FuncOp>(
-      cudaq::opt::createPySynthCallableBlockArgs(names));
+  pm.addNestedPass<func::FuncOp>(cudaq::opt::createPySynthCallableBlockArgs(
+      SmallVector<StringRef>(names.begin(), names.end())));
   pm.addPass(cudaq::opt::createGenerateDeviceCodeLoader({.jitTime = true}));
   pm.addPass(cudaq::opt::createGenerateKernelExecution(
       {.startingArgIdx = startingArgIdx}));
@@ -771,7 +771,9 @@ void bindAltLaunchKernel(py::module &mod) {
     auto context = m.getContext();
     PassManager pm(context);
     pm.addNestedPass<func::FuncOp>(
-        cudaq::opt::createPySynthCallableBlockArgs(funcNames, true));
+        cudaq::opt::createPySynthCallableBlockArgs(
+            SmallVector<StringRef>(funcNames.begin(), funcNames.end()),
+            true));
     if (failed(pm.run(m)))
       throw std::runtime_error(
           "cudaq::jit failed to remove callable block arguments.");
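
A lifetime note on the SmallVector<StringRef>(...) temporaries introduced at these call sites: the pass copies the vector, but each StringRef element is a view into a std::string owned by the original container, so names/funcNames must outlive pm.run(m), which holds here. A sketch of what is and is not copied (container contents are placeholders):

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ADT/StringRef.h"
    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> owners = {"kernelA", "kernelB"};
      llvm::SmallVector<llvm::StringRef> refs(owners.begin(), owners.end());
      // Copying the vector copies the views, not the character buffers.
      llvm::SmallVector<llvm::StringRef> copy = refs;
      assert(copy[0].data() == owners[0].data());
      return 0;
    }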

python/runtime/interop/PythonCppInterop.cpp (14 changes: 7 additions & 7 deletions)
@@ -8,7 +8,7 @@
 #include "PythonCppInterop.h"
 #include "cudaq.h"
 
-namespace cudaq {
+namespace cudaq::python {
 
 std::string getKernelName(std::string &input) {
   size_t pos = 0;
@@ -60,11 +60,11 @@ std::tuple<std::string, std::string>
 getMLIRCodeAndName(const std::string &name, const std::string mangledArgs) {
   auto cppMLIRCode =
       cudaq::get_quake(std::remove_cvref_t<decltype(name)>(name), mangledArgs);
-  auto kernelName = cudaq::getKernelName(cppMLIRCode);
-  cppMLIRCode = "module {\nfunc.func @" + kernelName +
-                cudaq::extractSubstring(cppMLIRCode, "func.func @" + kernelName,
-                                        "func.func") +
-                "\n}";
+  auto kernelName = cudaq::python::getKernelName(cppMLIRCode);
+  cppMLIRCode =
+      "module {\nfunc.func @" + kernelName +
+      extractSubstring(cppMLIRCode, "func.func @" + kernelName, "func.func") +
+      "\n}";
   return std::make_tuple(kernelName, cppMLIRCode);
 }

@@ -97,4 +97,4 @@ getDeviceKernel(const std::string &compositeName) {
   return iter->second;
 }
 
-} // namespace cudaq
+} // namespace cudaq::python
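
Downstream callers of the interop helpers now qualify them with cudaq::python::. A hedged sketch of an updated call site (assumes the CUDA-Q interop headers and a registered "entryPoint" kernel, as in the test module below):

    #include "PythonCppInterop.h"

    void example() {
      // Was cudaq::getMLIRCodeAndName(...) before this commit.
      auto [kernelName, mlirCode] =
          cudaq::python::getMLIRCodeAndName("entryPoint");
      (void)kernelName;
      (void)mlirCode;
    }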

python/runtime/interop/PythonCppInterop.h (8 changes: 4 additions & 4 deletions)
@@ -11,7 +11,7 @@
 
 namespace py = pybind11;
 
-namespace cudaq {
+namespace cudaq::python {
 
 /// @class CppPyKernelDecorator
 /// @brief A C++ wrapper for a Python object representing a CUDA-Q kernel.
@@ -162,8 +162,8 @@ void addDeviceKernelInterop(py::module_ &m, const std::string &modName,
 
   sub.def(
       kernelName.c_str(), [](Signature...) {}, docstring.c_str());
-  cudaq::registerDeviceKernel(sub.attr("__name__").cast<std::string>(),
-                              kernelName, mangledArgs);
+  cudaq::python::registerDeviceKernel(sub.attr("__name__").cast<std::string>(),
+                                      kernelName, mangledArgs);
   return;
 }
-} // namespace cudaq
+} // namespace cudaq::python
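
For module authors, registration keeps the same shape under the new namespace. A minimal hypothetical module (the names my_cpp_kernels, mylib, and rotateAll are placeholders, not part of this commit):

    #include "PythonCppInterop.h"
    #include "cudaq.h"

    PYBIND11_MODULE(my_cpp_kernels, m) {
      // Expose a C++ device kernel taking a qubit view and an angle.
      cudaq::python::addDeviceKernelInterop<cudaq::qview<>, double>(
          m, "mylib", "rotateAll", "Apply a rotation to every qubit.");
    }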

python/runtime/utils/PyRemoteRESTQPU.cpp (4 changes: 2 additions & 2 deletions)
@@ -80,8 +80,8 @@ class PyRemoteRESTQPU : public cudaq::BaseRemoteRESTQPU {
   // specific to python before the rest of the RemoteRESTQPU workflow
   auto cloned = m_module.clone();
   PassManager pm(cloned.getContext());
-  pm.addNestedPass<func::FuncOp>(
-      cudaq::opt::createPySynthCallableBlockArgs(callableNames));
+  pm.addNestedPass<func::FuncOp>(cudaq::opt::createPySynthCallableBlockArgs(
+      SmallVector<StringRef>(callableNames.begin(), callableNames.end())));
   cudaq::opt::addAggressiveEarlyInlining(pm);
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::func::FuncOp>(

python/tests/interop/test_cpp_quantum_algorithm_module.cpp (17 changes: 7 additions & 10 deletions)
@@ -20,12 +20,13 @@ PYBIND11_MODULE(cudaq_test_cpp_algo, m) {
   m.def("test_cpp_qalgo", [](py::object statePrepIn) {
     // Wrap the kernel and compile, will throw
     // if not a valid kernel
-    cudaq::CppPyKernelDecorator statePrep(statePrepIn);
+    cudaq::python::CppPyKernelDecorator statePrep(statePrepIn);
     statePrep.compile();
 
     // Our library exposes an "entryPoint" kernel, get its
     // mangled name and MLIR code
-    auto [kernelName, cppMLIRCode] = cudaq::getMLIRCodeAndName("entryPoint");
+    auto [kernelName, cppMLIRCode] =
+        cudaq::python::getMLIRCodeAndName("entryPoint");
 
     // Merge the entryPoint kernel into the input stateprep kernel
     auto merged = statePrep.merge_kernel(cppMLIRCode);
@@ -40,16 +41,12 @@
     return cudaq::sample(entryPointPtr);
   });
 
-  // // Demo / Test overloaded kernel functions.
-  // cudaq::addDeviceKernelInterop<cudaq::qview<>, const std::vector<double> &,
-  //                               std::size_t>(m, "qstd", "qft", "");
-
   // Example of how to expose C++ kernels.
-  cudaq::addDeviceKernelInterop<cudaq::qview<>>(
+  cudaq::python::addDeviceKernelInterop<cudaq::qview<>>(
       m, "qstd", "qft", "(Fake) Quantum Fourier Transform.");
-  cudaq::addDeviceKernelInterop<cudaq::qview<>, std::size_t>(
+  cudaq::python::addDeviceKernelInterop<cudaq::qview<>, std::size_t>(
       m, "qstd", "another", "Demonstrate we can have multiple ones.");
 
-  cudaq::addDeviceKernelInterop<cudaq::qview<>, std::size_t>(m, "qstd", "uccsd",
-                                                             "");
+  cudaq::python::addDeviceKernelInterop<cudaq::qview<>, std::size_t>(
+      m, "qstd", "uccsd", "");
 }
