You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import mmcv
import numpy as np

prog_description = '''K-Fold coco split.

To split coco data for semi-supervised object detection:
    python tools/misc/split_coco.py
'''


def parse_args():
    """Parse the command-line options of the split script.

    Returns:
        argparse.Namespace: Parsed options ``data_root``, ``out_dir``,
            ``labeled_percent`` (a list of percentages) and ``fold``.
    """
    # Raw formatter keeps the line breaks of ``prog_description`` in --help.
    parser = argparse.ArgumentParser(
        description=prog_description,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        '--data-root',
        type=str,
        help='The data root of coco dataset.',
        default='./data/coco/')
    parser.add_argument(
        '--out-dir',
        type=str,
        help='The output directory of coco semi-supervised annotations.',
        default='./data/coco_semi_annos/')
    parser.add_argument(
        '--labeled-percent',
        type=float,
        nargs='+',
        help='The percentage of labeled data in the training set.',
        default=[1, 2, 5, 10])
    parser.add_argument(
        '--fold',
        type=int,
        help='K-fold cross validation for semi-supervised object detection.',
        default=5)
    args = parser.parse_args()
    return args


def split_coco(data_root, out_dir, percent, fold):
    """Split COCO data for Semi-supervised object detection.

    Loads ``annotations/instances_train2017.json`` under ``data_root``,
    randomly picks ``percent`` percent of its images as the labeled subset and
    writes two annotation files into ``out_dir``: one for the labeled images
    and one (suffixed ``-unlabeled``) for the remaining images.

    Args:
        data_root (str): The data root of coco dataset.
        out_dir (str): The output directory of coco semi-supervised
            annotations.
        percent (float): The percentage of labeled data in the training set.
        fold (int): The fold of dataset and set as random seed for data split.

    Raises:
        ValueError: If ``percent`` is not within ``[0, 100]``.
    """
    # Sampling without replacement cannot draw more images than exist, so
    # reject out-of-range percentages up front with a clear message.
    if not 0 <= percent <= 100:
        raise ValueError(
            f'percent must be within [0, 100], but got {percent}')

    def save_anns(name, images, annotations):
        """Dump a subset of images/annotations as ``<out_dir>/<name>.json``.

        The ``licenses``, ``categories`` and ``info`` sections are copied
        unchanged from the full annotation file.
        """
        sub_anns = dict()
        sub_anns['images'] = images
        sub_anns['annotations'] = annotations
        sub_anns['licenses'] = anns['licenses']
        sub_anns['categories'] = anns['categories']
        sub_anns['info'] = anns['info']

        mmcv.mkdir_or_exist(out_dir)
        # osp.join avoids a doubled separator when out_dir ends with '/' and
        # a root-level path when out_dir is empty.
        mmcv.dump(sub_anns, osp.join(out_dir, f'{name}.json'))

    # set random seed with the fold
    np.random.seed(fold)
    ann_file = osp.join(data_root, 'annotations/instances_train2017.json')
    anns = mmcv.load(ann_file)

    image_list = anns['images']
    labeled_total = int(percent / 100. * len(image_list))
    # replace=False is essential: the default samples with replacement, and
    # the duplicates collapsed by set() would leave fewer than
    # ``labeled_total`` labeled images.
    labeled_inds = set(
        np.random.choice(
            len(image_list), size=labeled_total, replace=False).tolist())

    labeled_images, unlabeled_images = [], []
    for i, image in enumerate(image_list):
        if i in labeled_inds:
            labeled_images.append(image)
        else:
            unlabeled_images.append(image)

    # get all annotations of labeled images
    labeled_ids = {image['id'] for image in labeled_images}
    labeled_annotations, unlabeled_annotations = [], []
    for ann in anns['annotations']:
        if ann['image_id'] in labeled_ids:
            labeled_annotations.append(ann)
        else:
            unlabeled_annotations.append(ann)

    # save labeled and unlabeled
    labeled_name = f'instances_train2017.{fold}@{percent}'
    unlabeled_name = f'instances_train2017.{fold}@{percent}-unlabeled'

    save_anns(labeled_name, labeled_images, labeled_annotations)
    save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations)
def multi_wrapper(args):
    """Call ``split_coco`` with a packed argument tuple.

    Args:
        args (tuple): Positional arguments forwarded to ``split_coco``,
            i.e. ``(data_root, out_dir, percent, fold)``.

    Returns:
        The return value of ``split_coco``.
    """
    return split_coco(*args)
if __name__ == '__main__':
    args = parse_args()
    # Build one task per (fold, percent) pair; folds are numbered from 1 and
    # vary slowest, percentages vary fastest.
    arguments_list = []
    for fold_idx in range(1, args.fold + 1):
        for labeled_percent in args.labeled_percent:
            arguments_list.append(
                (args.data_root, args.out_dir, labeled_percent, fold_idx))
    # The number of folds is passed as the third argument of
    # track_parallel_progress.
    mmcv.track_parallel_progress(multi_wrapper, arguments_list, args.fold)
The text was updated successfully, but these errors were encountered:
https://github.com/open-mmlab/mmdetection/blob/master/tools/misc/split_coco.py
The text was updated successfully, but these errors were encountered: