Skip to content

Commit

Permalink
Enable Python/C++ interop via exposed JIT functionality (#2214)
Browse files Browse the repository at this point in the history
* Enable C++ interop with user Python kernels.

Signed-off-by: Alex McCaskey <[email protected]>
  • Loading branch information
amccaskey authored Oct 17, 2024
1 parent 6a370c9 commit 8bc1349
Show file tree
Hide file tree
Showing 26 changed files with 1,025 additions and 41 deletions.
1 change: 1 addition & 0 deletions cmake/Modules/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(CONFIG_FILES
CUDAQConfig.cmake
CUDAQEnsmallenConfig.cmake
CUDAQPlatformDefaultConfig.cmake
CUDAQPythonInteropConfig.cmake
)
set(LANG_FILES
CMakeCUDAQCompiler.cmake.in
Expand Down
3 changes: 3 additions & 0 deletions cmake/Modules/CUDAQConfig.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ find_dependency(CUDAQNlopt REQUIRED)
set (CUDAQEnsmallen_DIR "${CUDAQ_CMAKE_DIR}")
find_dependency(CUDAQEnsmallen REQUIRED)

set (CUDAQPythonInterop_DIR "${CUDAQ_CMAKE_DIR}")
find_dependency(CUDAQPythonInterop REQUIRED)

get_filename_component(PARENT_DIRECTORY ${CUDAQ_CMAKE_DIR} DIRECTORY)
get_filename_component(CUDAQ_LIBRARY_DIR ${PARENT_DIRECTORY} DIRECTORY)
get_filename_component(CUDAQ_INSTALL_DIR ${CUDAQ_LIBRARY_DIR} DIRECTORY)
Expand Down
13 changes: 13 additions & 0 deletions cmake/Modules/CUDAQPythonInteropConfig.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

get_filename_component(CUDAQ_PYTHONINTEROP_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)

if(NOT TARGET cudaq::cudaq-python-interop)
include("${CUDAQ_PYTHONINTEROP_CMAKE_DIR}/CUDAQPythonInteropTargets.cmake")
endif()
5 changes: 3 additions & 2 deletions include/cudaq/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ std::unique_ptr<mlir::Pass> createRaiseToAffinePass();
std::unique_ptr<mlir::Pass> createUnwindLoweringPass();

std::unique_ptr<mlir::Pass>
createPySynthCallableBlockArgs(const std::vector<std::string> &);
createPySynthCallableBlockArgs(const llvm::SmallVector<llvm::StringRef> &,
bool removeBlockArg = false);
inline std::unique_ptr<mlir::Pass> createPySynthCallableBlockArgs() {
return createPySynthCallableBlockArgs({});
return createPySynthCallableBlockArgs({}, false);
}

/// Helper function to build an argument synthesis pass. The names of the
Expand Down
100 changes: 73 additions & 27 deletions lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Runtime.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
Expand All @@ -22,12 +23,13 @@ namespace {

class ReplaceCallIndirect : public OpConversionPattern<func::CallIndirectOp> {
public:
const std::vector<std::string> &names;
const std::map<std::size_t, std::size_t> &blockArgToNameMap;
const SmallVector<StringRef> &names;
// const llvm::DenseMap<std::size_t, std::size_t>& blockArgToNameMap;
llvm::DenseMap<std::size_t, std::size_t> &blockArgToNameMap;

ReplaceCallIndirect(MLIRContext *ctx,
const std::vector<std::string> &functionNames,
const std::map<std::size_t, std::size_t> &map)
const SmallVector<StringRef> &functionNames,
llvm::DenseMap<std::size_t, std::size_t> &map)
: OpConversionPattern<func::CallIndirectOp>(ctx), names(functionNames),
blockArgToNameMap(map) {}

Expand All @@ -41,13 +43,11 @@ class ReplaceCallIndirect : public OpConversionPattern<func::CallIndirectOp> {
if (auto blockArg =
dyn_cast<BlockArgument>(ccCallableFunc.getOperand())) {
auto argIdx = blockArg.getArgNumber();
auto replacementName = names[blockArgToNameMap.at(argIdx)];
auto replacementName = names[blockArgToNameMap[argIdx]];
auto replacement = module.lookupSymbol<func::FuncOp>(
"__nvqpp__mlirgen__" + replacementName);
if (!replacement) {
op.emitError("Invalid replacement function " + replacementName);
cudaq::runtime::cudaqGenPrefixName + replacementName.str());
if (!replacement)
return failure();
}

rewriter.replaceOpWithNewOp<func::CallOp>(op, replacement,
adaptor.getCalleeOperands());
Expand All @@ -59,13 +59,46 @@ class ReplaceCallIndirect : public OpConversionPattern<func::CallIndirectOp> {
}
};

class ReplaceCallCallable
: public OpConversionPattern<cudaq::cc::CallCallableOp> {
public:
const SmallVector<StringRef> &names;
llvm::DenseMap<std::size_t, std::size_t> &blockArgToNameMap;

ReplaceCallCallable(MLIRContext *ctx,
const SmallVector<StringRef> &functionNames,
llvm::DenseMap<std::size_t, std::size_t> &map)
: OpConversionPattern<cudaq::cc::CallCallableOp>(ctx),
names(functionNames), blockArgToNameMap(map) {}

LogicalResult
matchAndRewrite(cudaq::cc::CallCallableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto callableOperand = adaptor.getCallee();
auto module = op->getParentOp()->getParentOfType<ModuleOp>();
if (auto blockArg = dyn_cast<BlockArgument>(callableOperand)) {
auto argIdx = blockArg.getArgNumber();
auto replacementName = names[blockArgToNameMap[argIdx]];
auto replacement = module.lookupSymbol<func::FuncOp>(
cudaq::runtime::cudaqGenPrefixName + replacementName.str());
if (!replacement)
return failure();

rewriter.replaceOpWithNewOp<func::CallOp>(op, replacement,
adaptor.getArgs());
return success();
}
return failure();
}
};

class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
public:
const std::vector<std::string> &names;
const std::map<std::size_t, std::size_t> &blockArgToNameMap;
const SmallVector<StringRef> &names;
llvm::DenseMap<std::size_t, std::size_t> &blockArgToNameMap;
UpdateQuakeApplyOp(MLIRContext *ctx,
const std::vector<std::string> &functionNames,
const std::map<std::size_t, std::size_t> &map)
const SmallVector<StringRef> &functionNames,
llvm::DenseMap<std::size_t, std::size_t> &map)
: OpConversionPattern<quake::ApplyOp>(ctx), names(functionNames),
blockArgToNameMap(map) {}

Expand All @@ -77,13 +110,11 @@ class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
auto ctx = op.getContext();
if (auto blockArg = dyn_cast<BlockArgument>(callableOperand)) {
auto argIdx = blockArg.getArgNumber();
auto replacementName = names[blockArgToNameMap.at(argIdx)];
auto replacementName = names[blockArgToNameMap[argIdx]];
auto replacement = module.lookupSymbol<func::FuncOp>(
"__nvqpp__mlirgen__" + replacementName);
if (!replacement) {
op.emitError("Invalid replacement function " + replacementName);
cudaq::runtime::cudaqGenPrefixName + replacementName.str());
if (!replacement)
return failure();
}

rewriter.replaceOpWithNewOp<quake::ApplyOp>(
op, TypeRange{}, FlatSymbolRefAttr::get(ctx, replacement.getName()),
Expand All @@ -97,10 +128,13 @@ class UpdateQuakeApplyOp : public OpConversionPattern<quake::ApplyOp> {
class PySynthCallableBlockArgs
: public cudaq::opt::PySynthCallableBlockArgsBase<
PySynthCallableBlockArgs> {
private:
bool removeBlockArg = false;

public:
std::vector<std::string> names;
PySynthCallableBlockArgs(const std::vector<std::string> &_names)
: names(_names) {}
SmallVector<StringRef> names;
PySynthCallableBlockArgs(const SmallVector<StringRef> &_names, bool remove)
: removeBlockArg(remove), names(_names) {}

void runOnOperation() override {
auto op = getOperation();
Expand All @@ -109,7 +143,7 @@ class PySynthCallableBlockArgs

std::size_t numCallableBlockArgs = 0;
// need to map blockArgIdx -> counter(0,1,2,...)
std::map<std::size_t, std::size_t> blockArgToNamesMap;
llvm::DenseMap<std::size_t, std::size_t> blockArgToNamesMap;
for (std::size_t i = 0, k = 0; auto ty : op.getFunctionType().getInputs()) {
if (isa<cudaq::cc::CallableType>(ty)) {
numCallableBlockArgs++;
Expand All @@ -129,8 +163,9 @@ class PySynthCallableBlockArgs
return;
}

patterns.insert<ReplaceCallIndirect, UpdateQuakeApplyOp>(
ctx, names, blockArgToNamesMap);
patterns
.insert<ReplaceCallIndirect, ReplaceCallCallable, UpdateQuakeApplyOp>(
ctx, names, blockArgToNamesMap);
ConversionTarget target(*ctx);
// We should remove these operations
target.addIllegalOp<func::CallIndirectOp>();
Expand All @@ -148,11 +183,22 @@ class PySynthCallableBlockArgs
"error synthesizing callable functions for python.\n");
signalPassFailure();
}

if (removeBlockArg) {
auto numArgs = op.getNumArguments();
BitVector argsToErase(numArgs);
for (std::size_t argIndex = 0; argIndex < numArgs; ++argIndex)
if (isa<cudaq::cc::CallableType>(op.getArgument(argIndex).getType()))
argsToErase.set(argIndex);

op.eraseArguments(argsToErase);
}
}
};
} // namespace

std::unique_ptr<Pass> cudaq::opt::createPySynthCallableBlockArgs(
const std::vector<std::string> &names) {
return std::make_unique<PySynthCallableBlockArgs>(names);
std::unique_ptr<Pass>
cudaq::opt::createPySynthCallableBlockArgs(const SmallVector<StringRef> &names,
bool removeBlockArg) {
return std::make_unique<PySynthCallableBlockArgs>(names, removeBlockArg);
}
2 changes: 2 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,5 @@ if(CUDAQ_BUILD_TESTS)
message(FATAL_ERROR "CUDA Quantum Python Warning - CUDAQ_BUILD_TESTS=TRUE but can't find numpy or pytest modules required for testing.")
endif()
endif()

add_subdirectory(runtime/interop)
15 changes: 13 additions & 2 deletions python/cudaq/kernel/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .utils import globalAstRegistry, globalKernelRegistry, mlirTypeFromAnnotation
from ..mlir.dialects import cc
from ..mlir.ir import *
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime


class MidCircuitMeasurementAnalyzer(ast.NodeVisitor):
Expand Down Expand Up @@ -161,13 +162,23 @@ def visit_Call(self, node):

if len(moduleNames):
moduleNames.reverse()
if cudaq_runtime.isRegisteredDeviceModule(
'.'.join(moduleNames)):
return

# This will throw if the function / module is invalid
m = importlib.import_module('.'.join(moduleNames))
try:
m = importlib.import_module('.'.join(moduleNames))
except:
return

getattr(m, node.func.attr)
name = node.func.attr

if name not in globalAstRegistry:
raise RuntimeError(
f"{name} is not a valid kernel to call.")
f"{name} is not a valid kernel to call ({'.'.join(moduleNames)})."
)

self.depKernels[name] = globalAstRegistry[name]

Expand Down
54 changes: 54 additions & 0 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,59 @@ def visit_Call(self, node):
# FindDepKernels has found something like this, loaded it, and now we just
# want to get the function name and call it.

# First let's check for registered C++ kernels
cppDevModNames = []
value = node.func.value
if isinstance(value, ast.Name) and value.id != 'cudaq':
cppDevModNames = [node.func.attr, value.id]
else:
while isinstance(value, ast.Attribute):
cppDevModNames.append(value.attr)
value = value.value
if isinstance(value, ast.Name):
cppDevModNames.append(value.id)
break

devKey = '.'.join(cppDevModNames[::-1])

def get_full_module_path(partial_path):
parts = partial_path.split('.')
for module_name, module in sys.modules.items():
if module_name.endswith(parts[0]):
try:
obj = module
for part in parts[1:]:
obj = getattr(obj, part)
return f"{module_name}.{'.'.join(parts[1:])}"
except AttributeError:
continue
return partial_path

devKey = get_full_module_path(devKey)
if cudaq_runtime.isRegisteredDeviceModule(devKey):
maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel(
self.module, devKey + '.' + node.func.attr)
if maybeKernelName == None:
maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel(
self.module, devKey)
if maybeKernelName != None:
otherKernel = SymbolTable(
self.module.operation)[maybeKernelName]
fType = otherKernel.type
if len(fType.inputs) != len(node.args):
funcName = node.func.id if hasattr(
node.func, 'id') else node.func.attr
self.emitFatalError(
f"invalid number of arguments passed to callable {funcName} ({len(node.args)} vs required {len(fType.inputs)})",
node)

[self.visit(arg) for arg in node.args]
values = [self.popValue() for _ in node.args]
values.reverse()
values = [self.ifPointerThenLoad(v) for v in values]
func.CallOp(otherKernel, values)
return

# Start by seeing if we have mod1.mod2.mod3...
moduleNames = []
value = node.func.value
Expand Down Expand Up @@ -1816,6 +1869,7 @@ def bodyBuilder(iterVal):

values = [self.popValue() for _ in node.args]
values.reverse()
values = [self.ifPointerThenLoad(v) for v in values]
func.CallOp(otherKernel, values)
return

Expand Down
45 changes: 44 additions & 1 deletion python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Callable
from ..mlir.ir import *
from ..mlir.passmanager import *
from ..mlir.dialects import quake, cc
from ..mlir.dialects import quake, cc, func
from .ast_bridge import compile_to_mlir, PyASTBridge
from .utils import mlirTypeFromPyType, nvqppPrefix, mlirTypeToPyType, globalAstRegistry, emitFatalError, emitErrorIfInvalidPauli, globalRegisteredTypes
from .analysis import MidCircuitMeasurementAnalyzer, HasReturnNodeVisitor
Expand Down Expand Up @@ -220,6 +220,49 @@ def compile(self):
self.dependentCaptures = extraMetadata[
'dependent_captures'] if 'dependent_captures' in extraMetadata else None

def merge_kernel(self, otherMod):
"""
Merge the kernel in this PyKernelDecorator (the ModuleOp) with
the provided ModuleOp.
"""
self.compile()
if not isinstance(otherMod, str):
otherMod = str(otherMod)
newMod = cudaq_runtime.mergeExternalMLIR(self.module, otherMod)
# Get the name of the kernel entry point
name = self.name
for op in newMod.body:
if isinstance(op, func.FuncOp):
for attr in op.attributes:
if 'cudaq-entrypoint' == attr.name:
name = op.name.value.replace(nvqppPrefix, '')
break

return PyKernelDecorator(None, kernelName=name, module=newMod)

def synthesize_callable_arguments(self, funcNames):
"""
Given this Kernel has callable block arguments, synthesize away these
callable arguments with the in-module FuncOps with given names. The
name at index 0 in the list corresponds to the first callable block
argument, index 1 to the second callable block argument, etc.
"""
self.compile()
cudaq_runtime.synthPyCallable(self.module, funcNames)
# Reset the argument types by removing the Callable
self.argTypes = [
a for a in self.argTypes if not cc.CallableType.isinstance(a)
]

def extract_c_function_pointer(self, name=None):
"""
Return the C function pointer for the function with given name, or
with the name of this kernel if not provided.
"""
self.compile()
return cudaq_runtime.jitAndGetFunctionPointer(
self.module, nvqppPrefix + self.name if name is None else name)

def __str__(self):
"""
Return the MLIR Module string representation for this kernel.
Expand Down
6 changes: 6 additions & 0 deletions python/cudaq/kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,12 @@ def mlirTypeToPyType(argType):
if F32Type.isinstance(argType):
return np.float32

if quake.VeqType.isinstance(argType):
return qvector

if cc.CallableType.isinstance(argType):
return Callable

if ComplexType.isinstance(argType):
if F64Type.isinstance(ComplexType(argType).element_type):
return complex
Expand Down
Loading

0 comments on commit 8bc1349

Please sign in to comment.