Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code to split the COCO dataset #19

Open
LimePencil opened this issue Nov 21, 2022 · 0 comments
Open

Code to split the COCO dataset #19

LimePencil opened this issue Nov 21, 2022 · 0 comments
Labels
documentation Improvements or additions to documentation

Comments

@LimePencil
Copy link
Contributor

https://github.com/open-mmlab/mmdetection/blob/master/tools/misc/split_coco.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import os
import os.path as osp

import mmcv
import numpy as np

prog_description = '''K-Fold coco split.

To split coco data for semi-supervised object detection:
    python tools/misc/split_coco.py
'''


def parse_args():
    """Parse command-line options for the COCO semi-supervised split.

    Returns:
        argparse.Namespace: with attributes ``data_root`` (str),
        ``out_dir`` (str), ``labeled_percent`` (list of percentages) and
        ``fold`` (int).
    """
    # ``prog_description`` used to be defined but never handed to the
    # parser, so ``--help`` showed no description. RawDescriptionHelpFormatter
    # keeps its line breaks and the indented usage example intact.
    parser = argparse.ArgumentParser(
        description=prog_description,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        '--data-root',
        type=str,
        help='The data root of coco dataset.',
        default='./data/coco/')
    parser.add_argument(
        '--out-dir',
        type=str,
        help='The output directory of coco semi-supervised annotations.',
        default='./data/coco_semi_annos/')
    # NOTE: argparse applies ``type`` only to string defaults, so the default
    # percentages stay ints (1, 2, 5, 10) while values given on the command
    # line become floats. This is visible in output file names
    # (e.g. ``@10`` for the default vs ``@10.0`` when passed explicitly).
    parser.add_argument(
        '--labeled-percent',
        type=float,
        nargs='+',
        help='The percentage of labeled data in the training set.',
        default=[1, 2, 5, 10])
    parser.add_argument(
        '--fold',
        type=int,
        help='K-fold cross validation for semi-supervised object detection.',
        default=5)
    args = parser.parse_args()
    return args


def split_coco(data_root, out_dir, percent, fold):
    """Split COCO data for Semi-supervised object detection.

    A random ``percent``% of the train2017 images, sampled without
    replacement and seeded by ``fold``, forms the labeled split. All
    remaining images form the unlabeled split. Each annotation goes to the
    split that contains its ``image_id``. Two files are written to
    ``out_dir``: ``instances_train2017.{fold}@{percent}.json`` and
    ``instances_train2017.{fold}@{percent}-unlabeled.json``.

    Args:
        data_root (str): The data root of coco dataset. Annotations are read
            from ``annotations/instances_train2017.json`` under it.
        out_dir (str): The output directory of coco semi-supervised
            annotations. Created if it does not exist.
        percent (float): The percentage (0-100) of labeled data in the
            training set.
        fold (int): The fold of dataset and set as random seed for data split.

    Raises:
        ValueError: If ``percent`` is outside ``[0, 100]``.
    """
    if not 0 <= percent <= 100:
        raise ValueError(f'percent must be in [0, 100], got {percent}')

    # Annotation files are read/written with the standard library, so this
    # function does not depend on mmcv's file helpers (mmcv 2.x moved
    # load/dump/mkdir_or_exist to mmengine). Only local paths are supported.
    ann_file = osp.join(data_root, 'annotations/instances_train2017.json')
    with open(ann_file, encoding='utf-8') as f:
        anns = json.load(f)

    image_list = anns['images']
    labeled_total = int(percent / 100. * len(image_list))
    # Private RNG seeded by the fold: reproducible per fold without mutating
    # numpy's global random state. Sampling must be WITHOUT replacement --
    # duplicate draws would otherwise collapse in the set and make the
    # labeled split smaller than the requested percentage.
    rng = np.random.RandomState(fold)
    labeled_inds = set(
        rng.choice(
            len(image_list), size=labeled_total, replace=False).tolist())

    labeled_images, unlabeled_images = [], []
    for idx, image in enumerate(image_list):
        if idx in labeled_inds:
            labeled_images.append(image)
        else:
            unlabeled_images.append(image)

    # Route every annotation to the split that owns its image.
    labeled_ids = {image['id'] for image in labeled_images}
    labeled_annotations, unlabeled_annotations = [], []
    for ann in anns['annotations']:
        if ann['image_id'] in labeled_ids:
            labeled_annotations.append(ann)
        else:
            unlabeled_annotations.append(ann)

    def save_anns(name, images, annotations):
        # Key order matches the original layout of the dumped JSON.
        sub_anns = dict()
        sub_anns['images'] = images
        sub_anns['annotations'] = annotations
        sub_anns['licenses'] = anns['licenses']
        sub_anns['categories'] = anns['categories']
        sub_anns['info'] = anns['info']

        os.makedirs(out_dir, exist_ok=True)
        out_file = osp.join(out_dir, f'{name}.json')
        with open(out_file, 'w', encoding='utf-8') as fout:
            json.dump(sub_anns, fout)

    # save labeled and unlabeled
    labeled_name = f'instances_train2017.{fold}@{percent}'
    unlabeled_name = f'instances_train2017.{fold}@{percent}-unlabeled'

    save_anns(labeled_name, labeled_images, labeled_annotations)
    save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations)


def multi_wrapper(args):
    """Call ``split_coco`` with one packed task.

    Args:
        args: An iterable of ``(data_root, out_dir, percent, fold)``, as
            built in the ``__main__`` block. The parallel runner presumably
            passes each task as a single argument, hence this adapter.

    Returns:
        Whatever ``split_coco`` returns.
    """
    positional = tuple(args)
    return split_coco(*positional)


if __name__ == '__main__':
    cli_args = parse_args()
    # One task per (fold, percentage) pair; folds are numbered from 1 and
    # double as the random seed inside split_coco.
    tasks = []
    for fold_idx in range(1, cli_args.fold + 1):
        for labeled_pct in cli_args.labeled_percent:
            tasks.append(
                (cli_args.data_root, cli_args.out_dir, labeled_pct, fold_idx))
    # The third argument is the number of worker processes: one per fold.
    mmcv.track_parallel_progress(multi_wrapper, tasks, cli_args.fold)
@LimePencil LimePencil added the documentation Improvements or additions to documentation label Nov 21, 2022
@LimePencil LimePencil changed the title from "Code to splitting the COCO dataset" to "Code to split the COCO dataset" on Nov 29, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation
Projects
None yet
Development

No branches or pull requests

1 participant