Skip to content

Commit

Permalink
[Compiler] Add force-entrypoints-return-allocs option to TRT task (#494)
Browse files Browse the repository at this point in the history
Add force-entrypoints-return-allocs option to TRT task, also update the
integration test to use non-DPS style calling convention.
  • Loading branch information
yizhuoz004 authored Feb 12, 2025
1 parent cdfe3a9 commit e841bee
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ struct TensorRTToExecutableOptions

Option<std::string> entrypoint{this, "entrypoint", llvm::cl::init("main"),
llvm::cl::desc("entrypoint function name")};

/// Forces entrypoint functions to return allocations corresponding to the
/// original tensor results. Otherwise, entrypoints will be lowered to use
/// destination passing style whenever possible, but some results may still
/// lower to returned allocations (because the output shape may not be
/// computable from the inputs). In either case, the user should verify the
/// final calling convention of the compiled function(s) by inspecting the
/// compiled function signature metadata.
Option<bool> forceEntrypointsReturnAllocs{
this, "force-entrypoints-return-allocs", llvm::cl::init(false),
llvm::cl::desc(
"Require entrypoint functions to return allocations corresponding to"
" the original tensor results, otherwise they are transformed"
" into destination arguments whenever possible.")};
};

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ void TensorRTToExecutableTask::buildPostClusteringPipeline(
nullptr, options.get<TensorRTOptions>().options));

pm.addPass(createMemRefCastEliminationPass());
plan::PlanAllocTensorsPassOptions allocTensorOpts{};
allocTensorOpts.forceEntrypointsReturnAllocs =
options.forceEntrypointsReturnAllocs;
pm.addPass(plan::createPlanAllocTensorsPass());
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def compile(client, op):
"--tensorrt-builder-opt-level=0",
"--tensorrt-strongly-typed=true",
"--tensorrt-workspace-memory-pool-limit=1024kB",
"--force-entrypoints-return-allocs",
],
)
task.run(op)
Expand Down Expand Up @@ -53,14 +54,11 @@ def tensorrt_add():
device=devices[0],
stream=stream,
)
arg1 = client.create_memref(
np.zeros(shape=(2, 3, 4), dtype=np.float32).data,
device=devices[0],
stream=stream,
results = session.execute_function(
"main", in_args=[arg0], stream=stream, client=client
)
session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)

data = np.asarray(client.copy_to_host(arg1, stream=stream))
data = np.asarray(client.copy_to_host(results[0], stream=stream))
stream.sync()

print(data)
Expand Down

0 comments on commit e841bee

Please sign in to comment.