diff --git a/mmdetection/tools/train.py b/mmdetection/tools/train.py index 7d5df0f..3c7f04f 100644 --- a/mmdetection/tools/train.py +++ b/mmdetection/tools/train.py @@ -205,7 +205,8 @@ def main(): logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) # set num_classes - cfg.model.roi_head.bbox_head.num_classes = len(cfg.classes) + for bbox_head in cfg.model.roi_head.bbox_head: + bbox_head.num_classes = len(cfg.classes) # init the meta dict to record some important information such as # environment info and seed, which will be logged