From 9ef4c59eba9314381e9735708319d905ac2b6694 Mon Sep 17 00:00:00 2001 From: prickly-u Date: Mon, 17 Feb 2020 13:05:27 +0300 Subject: [PATCH] Fix image preprocessing in evaluate.py --- keras_retinanet/bin/evaluate.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/keras_retinanet/bin/evaluate.py b/keras_retinanet/bin/evaluate.py index 1a10f23..3c03aa1 100755 --- a/keras_retinanet/bin/evaluate.py +++ b/keras_retinanet/bin/evaluate.py @@ -28,6 +28,7 @@ from .. import models from ..preprocessing.csv_generator import CSVGenerator from ..preprocessing.pascal_voc import PascalVocGenerator +from ..utils.anchors import make_shapes_callback from ..utils.config import read_config_file, parse_anchor_parameters from ..utils.eval import evaluate from ..utils.gpu import setup_gpu @@ -35,9 +36,13 @@ from ..utils.tf_version import check_tf_version -def create_generator(args): +def create_generator(args, preprocess_image): """ Create generators for evaluation. """ + common_args = { + 'preprocess_image' : preprocess_image, + } + if args.dataset_type == 'coco': # import here to prevent unnecessary dependency on cocoapi from ..preprocessing.coco import CocoGenerator @@ -49,6 +54,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) elif args.dataset_type == 'pascal': validation_generator = PascalVocGenerator( @@ -58,6 +64,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) elif args.dataset_type == 'csv': validation_generator = CSVGenerator( @@ -67,6 +74,7 @@ def create_generator(args): image_max_side=args.image_max_side, config=args.config, shuffle_groups=False, + **common_args ) else: raise ValueError('Invalid data type received: {}'.format(args.dataset_type)) @@ -129,7 +137,8 @@ def main(args=None): args.config = read_config_file(args.config) # create the generator - generator = create_generator(args) + backbone = models.backbone(args.backbone) + generator = create_generator(args, backbone.preprocess_image) # optionally load anchor parameters anchor_params = None @@ -140,6 +149,8 @@ def main(args=None): print('Loading model, this may take a second...') model = models.load_model(args.model, backbone_name=args.backbone) + generator.compute_shapes = make_shapes_callback(model) + # optionally convert the model if args.convert_model: model = models.convert_model(model, anchor_params=anchor_params)