Skip to content

Commit

Permalink
[RN50/Paddle] Remove export script and add INT8 feature (QAT + inference)
Browse files Browse the repository at this point in the history
  • Loading branch information
leo0519 authored and nv-kkudrynski committed Feb 20, 2024
1 parent 9dd9fcb commit 38934f9
Show file tree
Hide file tree
Showing 13 changed files with 373 additions and 263 deletions.
2 changes: 1 addition & 1 deletion PaddlePaddle/Classification/RN50v1.5/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.09-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.12-py3
FROM ${FROM_IMAGE_NAME}

ADD requirements.txt /workspace/
Expand Down
350 changes: 235 additions & 115 deletions PaddlePaddle/Classification/RN50v1.5/README.md

Large diffs are not rendered by default.

75 changes: 0 additions & 75 deletions PaddlePaddle/Classification/RN50v1.5/export_model.py

This file was deleted.

27 changes: 16 additions & 11 deletions PaddlePaddle/Classification/RN50v1.5/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


def init_predictor(args):
infer_dir = args.trt_inference_dir
infer_dir = args.inference_dir
assert os.path.isdir(
infer_dir), f'inference_dir = "{infer_dir}" is not a directory'
pdiparams_path = glob.glob(os.path.join(infer_dir, '*.pdiparams'))
Expand All @@ -41,7 +41,7 @@ def init_predictor(args):
predictor_config = Config(pdmodel_path[0], pdiparams_path[0])
predictor_config.enable_memory_optim()
predictor_config.enable_use_gpu(0, args.device)
precision = args.trt_precision
precision = args.precision
max_batch_size = args.batch_size
assert precision in ['FP32', 'FP16', 'INT8'], \
'precision should be FP32/FP16/INT8'
Expand All @@ -54,12 +54,17 @@ def init_predictor(args):
else:
raise NotImplementedError
predictor_config.enable_tensorrt_engine(
workspace_size=args.trt_workspace_size,
workspace_size=args.workspace_size,
max_batch_size=max_batch_size,
min_subgraph_size=args.trt_min_subgraph_size,
min_subgraph_size=args.min_subgraph_size,
precision_mode=precision_mode,
use_static=args.trt_use_static,
use_calib_mode=args.trt_use_calib_mode)
use_static=args.use_static,
use_calib_mode=args.use_calib_mode)
predictor_config.set_trt_dynamic_shape_info(
{"data": (1,) + tuple(args.image_shape)},
{"data": (args.batch_size,) + tuple(args.image_shape)},
{"data": (args.batch_size,) + tuple(args.image_shape)},
)
predictor = create_predictor(predictor_config)
return predictor

Expand Down Expand Up @@ -140,7 +145,7 @@ def benchmark_dataset(args):
quantile = np.quantile(latency, [0.9, 0.95, 0.99])

statistics = {
'precision': args.trt_precision,
'precision': args.precision,
'batch_size': batch_size,
'throughput': total_images / (end - start),
'accuracy': correct_predict / total_images,
Expand Down Expand Up @@ -189,7 +194,7 @@ def benchmark_synthetic(args):
quantile = np.quantile(latency, [0.9, 0.95, 0.99])

statistics = {
'precision': args.trt_precision,
'precision': args.precision,
'batch_size': batch_size,
'throughput': args.benchmark_steps * batch_size / (end - start),
'eval_latency_avg': np.mean(latency),
Expand All @@ -200,11 +205,11 @@ def benchmark_synthetic(args):
return statistics

def main(args):
setup_dllogger(args.trt_log_path)
setup_dllogger(args.report_file)
if args.show_config:
print_args(args)

if args.trt_use_synthetic:
if args.use_synthetic:
statistics = benchmark_synthetic(args)
else:
statistics = benchmark_dataset(args)
Expand All @@ -213,4 +218,4 @@ def main(args):


if __name__ == '__main__':
main(parse_args(including_trt=True))
main(parse_args(script='inference'))
1 change: 1 addition & 0 deletions PaddlePaddle/Classification/RN50v1.5/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def dist_optimizer(args, optimizer):
}

dist_strategy.asp = args.asp
dist_strategy.qat = args.qat

optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

python inference.py \
--data-layout NHWC \
--trt-inference-dir ./inference_amp \
--trt-precision FP16 \
--inference-dir ./inference_amp \
--precision FP16 \
--batch-size 256 \
--benchmark-steps 1024 \
--benchmark-warmup-steps 16 \
--trt-use-synthetic True
--use-synthetic True
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

CKPT=${1:-"./output/ResNet50/89"}
MODEL_PREFIX=${2:-"resnet_50_paddle"}

python -m paddle.distributed.launch --gpus=0 export_model.py \
--trt-inference-dir ./inference_tf32 \
--from-checkpoint $CKPT \
--model-prefix ${MODEL_PREFIX}
python inference.py \
--data-layout NHWC \
--inference-dir ./inference_qat \
--precision INT8 \
--batch-size 256 \
--benchmark-steps 1024 \
--benchmark-warmup-steps 16 \
--use-synthetic True
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.

python inference.py \
--trt-inference-dir ./inference_tf32 \
--trt-precision FP32 \
--inference-dir ./inference_tf32 \
--precision FP32 \
--dali-num-threads 8 \
--batch-size 256 \
--benchmark-steps 1024 \
--benchmark-warmup-steps 16 \
--trt-use-synthetic True
--use-synthetic True
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py \
--scale-loss 128.0 \
--use-dynamic-loss-scaling \
--data-layout NHWC \
--fuse-resunit
--fuse-resunit \
--inference-dir ./inference_amp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
CKPT=${1:-"./output/ResNet50/89"}
MODEL_PREFIX=${2:-"resnet_50_paddle"}

python -m paddle.distributed.launch --gpus=0 export_model.py \
--amp \
--data-layout NHWC \
--trt-inference-dir ./inference_amp \
--from-checkpoint ${CKPT} \
--model-prefix ${MODEL_PREFIX}
python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py \
--from-pretrained-params ${CKPT} \
--model-prefix ${MODEL_PREFIX} \
--epochs 10 \
--amp \
--scale-loss 128.0 \
--use-dynamic-loss-scaling \
--data-layout NHWC \
--qat \
--lr 0.00005 \
--inference-dir ./inference_qat
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py --epochs 90
python -m paddle.distributed.launch --gpus=0,1,2,3,4,5,6,7 train.py --epochs 90 --inference-dir ./inference_tf32
17 changes: 15 additions & 2 deletions PaddlePaddle/Classification/RN50v1.5/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
from paddle.static.amp.fp16_utils import cast_model_to_fp16
from paddle.incubate import asp as sparsity
from paddle.static.quantization.quanter import quant_aware


class MetricSummary:
Expand Down Expand Up @@ -107,7 +108,7 @@ def main(args):
eval_step_each_epoch = len(eval_dataloader)
eval_prog = paddle.static.Program()

eval_fetchs, _, _, _ = program.build(
eval_fetchs, _, eval_feeds, _ = program.build(
args,
eval_prog,
startup_prog,
Expand Down Expand Up @@ -147,6 +148,14 @@ def main(args):
sparsity.prune_model(train_prog, mask_algo=args.mask_algo)
logging.info("Pruning model done.")

if args.qat:
if args.run_scope == RunScope.EVAL_ONLY:
eval_prog = quant_aware(eval_prog, device, for_test=True, return_program=True)
else:
optimizer.qat_init(
device,
test_program=eval_prog)

if eval_prog is not None:
eval_prog = program.compile_prog(args, eval_prog, is_train=False)

Expand All @@ -169,7 +178,7 @@ def main(args):

# Save a checkpoint
if epoch_id % args.save_interval == 0:
model_path = os.path.join(args.output_dir, args.model_arch_name)
model_path = os.path.join(args.checkpoint_dir, args.model_arch_name)
save_model(train_prog, model_path, epoch_id, args.model_prefix)

# Evaluation
Expand All @@ -190,6 +199,10 @@ def main(args):
if eval_summary.is_updated:
program.log_info((), eval_summary.metric_dict, Mode.EVAL)

if eval_prog is not None:
model_path = os.path.join(args.inference_dir, args.model_arch_name)
paddle.static.save_inference_model(model_path, [eval_feeds['data']], [eval_fetchs['label'][0]], exe, program=eval_prog)


if __name__ == '__main__':
paddle.enable_static()
Expand Down
Loading

0 comments on commit 38934f9

Please sign in to comment.