forked from pytorch/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompiler.h
40 lines (31 loc) · 1.3 KB
/
compiler.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#pragma once
#include <cuda_runtime.h>
#include <vector>
#include "core/conversion/conversion.h"
#include "core/ir/ir.h"
#include "core/lowering/lowering.h"
#include "core/partitioning/partitioning.h"
#include "core/runtime/runtime.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/ir/ir.h"
namespace torch_tensorrt {
namespace core {
// Aggregates everything needed to compile a TorchScript module to TensorRT:
// the graph input specification plus per-phase settings for conversion,
// lowering, and partitioning. Default-constructed phase settings are used
// unless the caller overrides them after construction.
struct CompileSpec {
  // Construct from a flat list of static input specs.
  // `inputs` is a sink parameter: taken by value and moved into
  // `graph_inputs` to avoid copying the vector a second time.
  CompileSpec(std::vector<ir::Input> inputs) : graph_inputs(std::move(inputs)) {}
  // Construct from a full input signature IValue (presumably used for
  // nested/collection-typed inputs — confirm against ir::GraphInputs).
  CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {}
  ir::GraphInputs graph_inputs;                     // input shapes/dtypes (or signature) for the graph
  conversion::ConversionInfo convert_info;          // TensorRT conversion-phase settings
  lowering::LowerInfo lower_info;                   // graph lowering-phase settings
  partitioning::PartitioningInfo partitioning_info; // partitioning / fallback settings
};
/// @brief Reports whether `method_name` of `mod` can be handled by the
/// converter library (NOTE(review): inferred from the name — confirm the
/// exact semantics in the corresponding .cpp).
/// @return true if the method is supported, false otherwise.
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
/// @brief Converts one method of `mod` into a serialized TensorRT engine
/// according to `cfg`.
/// @return the serialized engine as a byte string.
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
/// @brief Compiles `module` per `cfg` and returns a new TorchScript module
/// (presumably with TensorRT engines embedded — verify in the implementation).
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
/// @brief Wraps an already-serialized TensorRT `engine` in a fresh
/// TorchScript module targeting `cuda_device`.
/// @param input_binding_names  engine input binding names, in input order.
/// @param output_binding_names engine output binding names, in output order.
torch::jit::script::Module EmbedEngineInNewModule(
const std::string& engine,
runtime::RTDevice cuda_device,
const std::vector<std::string>& input_binding_names,
const std::vector<std::string>& output_binding_names);
/// @brief Selects the active CUDA device by ordinal `gpu_id`.
void set_device(const int gpu_id);
} // namespace core
} // namespace torch_tensorrt