Skip to content

Commit

Permalink
[tfrt:jitrt] Add rt.status support to JitRt API intrinsics
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 445052118
  • Loading branch information
ezhulenev authored and copybara-github committed Apr 28, 2022
1 parent eaf4cbb commit aae8198
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 26 deletions.
14 changes: 13 additions & 1 deletion backends/jitrt/include/tfrt/jitrt/opdefs/rt_base.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,23 @@ class RT_Type<string name, string typeMnemonic> : TypeDef<RuntimeDialect,
let mnemonic = typeMnemonic;
}

// -------------------------------------------------------------------------- //
// Types for integrating JitRt kernels with the runtime.
// -------------------------------------------------------------------------- //

// This is an opaque handle to tfrt::jitrt::KernelContextType.
def KernelContextType : RT_Type<"KernelContext", "kernel_context"> {
let summary = "Kernel Context type";
let description = [{
Opaque handle used for interacting with the TFRT run-time.
Opaque handle used for interacting with the JitRt runtime.
}];
}

// This is an opaque handle to tfrt::jitrt::StatusType.
def StatusType : RT_Type<"Status", "status"> {
let summary = "Status type";
let description = [{
A status type returned from the JitRt runtime API intrinsics.
}];
}

Expand Down
1 change: 1 addition & 0 deletions backends/jitrt/include/tfrt/jitrt/opdefs/rt_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "tfrt/jitrt/opdefs/rt_dialect.h.inc"

Expand Down
30 changes: 26 additions & 4 deletions backends/jitrt/include/tfrt/jitrt/opdefs/rt_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,20 @@ def SetErrorOp : RT_Op<"set_error"> {
let assemblyFormat = "$ctx `,` $error attr-dict";
}

//===----------------------------------------------------------------------===//
// IsOkOp
//===----------------------------------------------------------------------===//

def IsOkOp : RT_Op<"is_ok"> {
let summary = "returns true if status is ok";
let description = "Checks if the runtime status is ok.";

let arguments = (ins StatusType:$status);
let results = (outs I1:$ok);

let assemblyFormat = "$status attr-dict";
}

//===----------------------------------------------------------------------===//
// CustomCallOp
//===----------------------------------------------------------------------===//
Expand All @@ -130,13 +144,18 @@ def CustomCallOp : RT_Op<"custom_call"> {
on top of the JitRt, for example this can be used as an extension mechanism
to register vendor specific kernels (e.g. call oneDNN convolution).

Returns `!rt.status` value which can be checked to see if the custom call
was successful.

Example:

```mlir
func @compute(%ctx: !rt.kernel_context, %arg0: memref<?xf32>,
%arg1: memref<?xf32>) {
%0 = rt.custom_call "one_dnn.some_operation"(%arg0, %arg1)
: (memref<?xf32>, memref<?xf32>) -> !one_dnn.status
%status = rt.custom_call "one_dnn.some_operation"(%arg0, %arg1)
: (memref<?xf32>, memref<?xf32>) -> ()
%0 = rt.is_ok %status
cf.assert %0, "failed to call one_dnn custom call"
return
}
```
Expand All @@ -152,10 +171,13 @@ def CustomCallOp : RT_Op<"custom_call"> {
Variadic<AnyType>:$operands
);

let results = (outs Variadic<AnyType>);
let results = (outs
StatusType:$status,
Variadic<AnyType>:$results
);

let assemblyFormat = [{
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
$callee `(` $operands `)` attr-dict `:` functional-type($operands, $results)
}];
}

Expand Down
5 changes: 3 additions & 2 deletions backends/jitrt/include/tfrt/jitrt/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ extern "C" void *runtimeGetResultStorage(KernelContext *, int64_t);
// Sets kernel context to an error state.
extern "C" void runtimeSetError(KernelContext *, const char *);

// Calls the custom call function registered with the runtime.
extern "C" void runtimeCustomCall(const char *, void **args);
// Calls the custom call function registered with the runtime. Returns true
// if the custom call was successful.
extern "C" bool runtimeCustomCall(const char *, void **args);

} // namespace runtime
} // namespace jitrt
Expand Down
33 changes: 29 additions & 4 deletions backends/jitrt/lib/conversion/rt_to_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ struct RuntimeAPI {
static FunctionType CustomCallFunctionType(MLIRContext *ctx) {
auto callee = OpaquePointerType(ctx);
auto args = CustomCallArgumentsType(ctx);
return FunctionType::get(ctx, {callee, args}, {});
auto i1 = IntegerType::get(ctx, 1);
return FunctionType::get(ctx, {callee, args}, {i1});
}
};

Expand All @@ -139,11 +140,16 @@ class RuntimeTypeConverter : public TypeConverter {
RuntimeTypeConverter() {
addConversion([](Type type) { return type; });
addConversion(ConvertKernelContextType);
addConversion(ConvertStatusType);
}

static llvm::Optional<Type> ConvertKernelContextType(KernelContextType type) {
return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
}

static llvm::Optional<Type> ConvertStatusType(StatusType type) {
return IntegerType::get(type.getContext(), 1);
}
};

// -------------------------------------------------------------------------- //
Expand Down Expand Up @@ -253,6 +259,23 @@ class SetErrorOpLowering : public OpConversionPattern<SetErrorOp> {
}
};

//===----------------------------------------------------------------------===//
// Convert rt.is_ok to the corresponding runtime API call.
//===----------------------------------------------------------------------===//

class IsOkOpLowering : public OpConversionPattern<IsOkOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
IsOkOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Just pass through the converted operand.
rewriter.replaceOp(op, adaptor.status());
return success();
}
};

//===----------------------------------------------------------------------===//
// Convert rt.custom_call to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -444,7 +467,8 @@ class CustomCallOpLowering : public OpConversionPattern<CustomCallOp> {
if (failed(args)) return op.emitOpError() << "failed to encode arguments";

// Call runtime API to call the custom call target.
rewriter.replaceOpWithNewOp<CallOp>(op, kCustomCall, TypeRange(),
auto i1 = rewriter.getI1Type();
rewriter.replaceOpWithNewOp<CallOp>(op, kCustomCall, TypeRange(i1),
ValueRange({callee, *args}));

return success();
Expand All @@ -471,10 +495,11 @@ void ConvertRuntimeToLLVMPass::runOnOperation() {
// We use conversion to LLVM type to lower all runtime operands to LLVM types.
LLVMTypeConverter llvm_converter(ctx);
llvm_converter.addConversion(RuntimeTypeConverter::ConvertKernelContextType);
llvm_converter.addConversion(RuntimeTypeConverter::ConvertStatusType);

// Lower from the runtime operations to the runtime API function calls.
patterns.add<SetOutputOpLowering, SetErrorOpLowering, CustomCallOpLowering>(
llvm_converter, ctx);
patterns.add<SetOutputOpLowering, SetErrorOpLowering, IsOkOpLowering,
CustomCallOpLowering>(llvm_converter, ctx);

// Convert function signatures and call sites.
mlir::populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
Expand Down
11 changes: 5 additions & 6 deletions backends/jitrt/lib/jitrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,7 @@ extern "C" void runtimeSetError(KernelContext* ctx, const char* error) {
ctx->call_frame->error = {error};
}

extern "C" void runtimeCustomCall(const char* callee, void** args) {
extern "C" bool runtimeCustomCall(const char* callee, void** args) {
assert(callee && "callee must be not null");

// Default custom calls registry for the JitRt kernels.
Expand All @@ -1267,14 +1267,13 @@ extern "C" void runtimeCustomCall(const char* callee, void** args) {
return registry;
}();

// TODO(ezhulenev): Return failure if custom call is not registered.
auto* custom_call = registry->Find(callee);
assert(custom_call && "unknown custom call");
if (custom_call == nullptr) return false;

// TODO(ezhulenev): Handle failures in custom calls.
auto result = custom_call->call(args);
assert(mlir::succeeded(result) && "failed custom call");
(void)result;
if (mlir::failed(result)) return false;

return true;
}

llvm::orc::SymbolMap RuntimeApiSymbolMap(llvm::orc::MangleAndInterner mangle) {
Expand Down
7 changes: 4 additions & 3 deletions backends/jitrt/mlir_tests/jitrt/compile.assert.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ module @kernels attributes { tfrt.compiled } {
}

// CHECK: --- Running 'runtime_error'
func.func @runtime_error() -> !tfrt.chain {
func.func @runtime_error() -> !t.tensor {
%ch0 = tfrt.new.chain

// Allocate and initialize input tensor.
Expand All @@ -40,9 +40,10 @@ func.func @runtime_error() -> !tfrt.chain {

%executable = jitrt.compile { kernel = @kernels::@main }

// expected-error @+1 {{Dimension 0 must have size 0}}
%output = jitrt.execute %executable[%input_ready](%input)
: (!t.tensor) -> (!t.tensor)

tfrt.return %ch0 : !tfrt.chain
// CHECK: returned <<error: compiled kernel run time error:
// CHECK-SAME: Dimension 0 must have size 0>>
tfrt.return %output : !t.tensor
}
29 changes: 26 additions & 3 deletions backends/jitrt/mlir_tests/jitrt/compile.custom_call.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ module @kernels attributes { tfrt.compiled } {
%c1 = arith.constant 1 : index
%0 = memref.dim %input, %c0 : memref<?x?xf32>
%1 = memref.dim %input, %c1 : memref<?x?xf32>
%output = memref.alloc(%0, %1) : memref<?x?xf32>

rt.custom_call "testlib.times_two"(%input, %output)
// Reverse dimension order to test invalid custom call arguments below.
%output = memref.alloc(%1, %0) : memref<?x?xf32>

%status = rt.custom_call "testlib.times_two"(%input, %output)
: (memref<?x?xf32>, memref<?x?xf32>) -> ()
%ok = rt.is_ok %status
cf.assert %ok, "failed to call custom call 'testlib.times_two'"

func.return %output : memref<?x?xf32>
}
Expand Down Expand Up @@ -55,4 +59,23 @@ func.func @compiled_custom_call() -> !tfrt.chain {
%printed = tfrt.print.i1 %cmp, %cmp_ch

tfrt.return %printed : !tfrt.chain
}
}

// CHECK: --- Running 'compiled_custom_call_error'
func.func @compiled_custom_call_error() -> !t.tensor {
%ch0 = tfrt.new.chain

// Allocate and initialize input tensor.
%input = tfrt_dht.create_uninitialized_tensor.f32.2 [16 : i64, 4 : i64]
%ch1 = tfrt_dht.fill_tensor_with_constant.f32 %input, %ch0 1.0 : f32

// Compile a kernel with a custom call.
%executable = jitrt.compile { kernel = @kernels::@main }

// Execute compiled kernel with tensor operands.
%output = jitrt.execute %executable[%ch1](%input) : (!t.tensor) -> !t.tensor

// CHECK: returned <<error: compiled kernel run time error:
// CHECK-SAME: failed to call custom call 'testlib.times_two'>>
tfrt.return %output : !t.tensor
}
4 changes: 3 additions & 1 deletion backends/jitrt/mlir_tests/rt/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ func.func @set_error(%arg0: !rt.kernel_context) {
// CHECK: %[[MEMREF:.*]]: memref<?xf32>
func.func @custom_call(%arg0: !rt.kernel_context, %arg1: memref<?xf32>) -> f32 {
// CHECK: rt.custom_call "f32_reduce"(%[[MEMREF]]) : (memref<?xf32>) -> f32
%0 = rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> f32
%status, %0 = rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> f32
%ok = rt.is_ok %status
cf.assert %ok, "failed to call custom call"
func.return %0 : f32
}
7 changes: 5 additions & 2 deletions backends/jitrt/mlir_tests/rt/rt_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ func.func @custom_call(%arg0: !rt.kernel_context, %arg1: memref<?xf32>) {
// CHECK: %[[C3:.*]] = arith.constant 3 : i32
// CHECK: %[[ARGS:.*]] = llvm.alloca %[[C3]] x !llvm.ptr<i8>

// CHECK: call @runtimeCustomCall(%[[CALLEE]], %[[ARGS]])
rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> ()
// CHECK: %[[STATUS:.*]] = call @runtimeCustomCall(%[[CALLEE]], %[[ARGS]])
// CHECK: cf.assert %[[STATUS]], "oops"
%status = rt.custom_call "f32_reduce"(%arg1) : (memref<?xf32>) -> ()
%ok = rt.is_ok %status
cf.assert %ok, "oops"

func.return
}

0 comments on commit aae8198

Please sign in to comment.