# bak_tf2trt_v1.py
#from helper import ModelOptimizer
import tensorrt as trt
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert
from time import perf_counter
#
# dotw: 2021-01-12
# - FPS increased greatly
# - accuracy untested
# - model file size larger
# - very slow startup time
# - lots of TF/TRT warnings
#
print(f"tensorflow version={tf.__version__}")
print(f"tensorrt version={trt.__version__}")
PRECISION = "FP16"
GPU_RAM_4G = 4000000000
GPU_RAM_6G = 6000000000
GPU_RAM_8G = 8000000000
MPL = "/home/aisg/src/ongtw/PeekingDuck/peekingduck_weights/movenet/multipose_lightning"
SPL = "/home/aisg/src/ongtw/PeekingDuck/peekingduck_weights/movenet/singlepose_lightning"
SPT = "/home/aisg/src/ongtw/PeekingDuck/peekingduck_weights/movenet/singlepose_thunder"
model_dir = SPL
model_out_dir = model_dir + "_fp16"
# dotw: tried the ModelOptimizer helper below, but it failed: helper module not found
#opt_model = ModelOptimizer(model_dir)
#model_fp16 = opt_model.convert(model_dir + "_fp16", precision=PRECISION)
# dotw: error, create_inference_graph() (legacy TF 1.x TF-TRT API) missing 2 required
#       positional arguments: 'input_graph_def' and 'outputs'
#trt_convert.create_inference_graph(
# input_saved_model_dir = model_dir,
# output_saved_model_dir = model_out_dir
#)
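# configure TF-TRT conversion: FP16 precision, 4 GB GPU workspace for building TRT engines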
conv_parms = trt_convert.TrtConversionParams(
precision_mode = trt_convert.TrtPrecisionMode.FP16,
max_workspace_size_bytes = GPU_RAM_4G,
)
converter = trt_convert.TrtGraphConverterV2(
input_saved_model_dir = model_dir,
conversion_params = conv_parms
)
print(f"generating {model_out_dir}")
print("converting original model...")
st0 = perf_counter()
converter.convert()
#converter.build(input_fn = my_input_fn)  # engines not pre-built here; see the sketch below
et0 = perf_counter()
print(f"conversion time = {et0 - st0:.2f} sec")
print("saving generated model...")
st1 = perf_counter()
converter.save(model_out_dir)
et1 = perf_counter()
print(f"save time = {et1 - st1:.2f} sec")
print(f"Total conversion time = {et1 - st0:.2f} sec")