Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test time #26

Open
Sunny599 opened this issue Aug 18, 2023 · 2 comments
Open

test time #26

Sunny599 opened this issue Aug 18, 2023 · 2 comments

Comments

@Sunny599
Copy link

您好，非常感谢您的代码，它为我的工作提供了很多帮助。
请问使用您提供的 DeepLabv2 以及后处理代码，在 COCO2014 val 数据集上测试大概需要多长时间？

@Sierkinhane
Copy link
Member

@Tiiiktak

@Tiiiktak
Copy link

Sorry for the late reply.

对于 COCO14 val 集，test 阶段我们花费了约 2.5 h。

在 CRF 阶段，我们重新实现了多进程部分，得到下面的 crf_coco.py；
设置 n_jobs=16 时，用时约 30 h。

你可以调整参数 --n_jobs（并行进程数）以进一步提速。

crf_coco.py:

import torch 
import numpy as np 
import time 
import os 
import torch.nn.functional as F
import multiprocessing as mp
from multiprocessing import Process
from omegaconf import OmegaConf
import json 
import argparse
from tqdm import tqdm

from libs.utils import DenseCRF, PolynomialLR, scores

from main_v2 import get_dataset, makedirs


def process_crf(i, dataset, logit_dir, postprocessor):
    """Refine one sample's cached logits with the CRF post-processor.

    Args:
        i: Index of the sample to fetch from ``dataset``.
        dataset: Indexable object whose items unpack to
            ``(image_id, image, gt_label)``; ``image`` is indexed here as a
            3-D channel-first array (C, H, W).
        logit_dir: Directory holding one ``<image_id>.npy`` logit file per image.
        postprocessor: Callable ``(image_hwc_uint8, prob) -> prob`` such as the
            ``DenseCRF`` instance built in ``main``.

    Returns:
        ``(label, gt_label)``: ``label`` is the per-pixel argmax over axis 0 of
        the refined probabilities; ``gt_label`` is returned unchanged.
    """
    sample_id, img, target = dataset.__getitem__(i)
    raw = np.load(os.path.join(logit_dir, sample_id + ".npy"))
    _, height, width = img.shape

    # Add a leading batch axis so F.interpolate can resize the logits to the
    # image resolution, then turn them into per-class probabilities.
    scores_map = torch.FloatTensor(raw).unsqueeze(0)
    scores_map = F.interpolate(
        scores_map, size=(height, width), mode="bilinear", align_corners=False
    )
    probs = F.softmax(scores_map, dim=1)[0].numpy()

    # The post-processor receives the image as HWC uint8.
    # NOTE(review): main() passes mean_bgr to the dataset; if the returned
    # image is mean-subtracted, astype(np.uint8) wraps negative values —
    # confirm against the dataset implementation.
    rgb_like = img.astype(np.uint8).transpose(1, 2, 0)
    refined = postprocessor(rgb_like, probs)
    return np.argmax(refined, axis=0), target
    
# Per-worker cache filled once by ``_init_crf_worker``. Handing the shared
# inputs to each worker a single time avoids re-pickling the whole dataset and
# post-processor for every one of the len(dataset) tasks.
_CRF_WORKER_STATE = {}


def _init_crf_worker(dataset, logit_dir, postprocessor):
    """Pool initializer: stash the shared CRF inputs in this worker process."""
    _CRF_WORKER_STATE["dataset"] = dataset
    _CRF_WORKER_STATE["logit_dir"] = logit_dir
    _CRF_WORKER_STATE["postprocessor"] = postprocessor


def _crf_worker(i):
    """Run ``process_crf`` for sample ``i`` using this worker's cached inputs."""
    return process_crf(
        i,
        _CRF_WORKER_STATE["dataset"],
        _CRF_WORKER_STATE["logit_dir"],
        _CRF_WORKER_STATE["postprocessor"],
    )


def crf(dataset, logit_dir, postprocessor, num_workers=4):
    """Apply CRF post-processing to every sample of ``dataset`` in parallel.

    Args:
        dataset: Indexable dataset passed through to ``process_crf``.
        logit_dir: Directory of per-image ``<image_id>.npy`` logit files.
        postprocessor: CRF callable passed through to ``process_crf``.
        num_workers: Number of worker processes in the pool.

    Returns:
        List of ``(label, gt_label)`` pairs, one per sample, in dataset order.

    Raises:
        Re-raises an exception from ``process_crf`` as soon as iteration
        reaches the failing sample; the pool is terminated on the way out.
    """
    print("CRF post-processing")
    n_samples = len(dataset)
    results = []
    with mp.Pool(
        num_workers,
        initializer=_init_crf_worker,
        initargs=(dataset, logit_dir, postprocessor),
    ) as pool, tqdm(total=n_samples, desc="CRF post-processing", ascii=True) as pbar:
        # imap yields in submission order, so results[i] belongs to sample i,
        # and only the integer index is pickled per task.
        for result in pool.imap(_crf_worker, range(n_samples)):
            results.append(result)
            pbar.update()

    print("CRF post-processing finished")
    return results

def main(config_path, n_jobs):
    """Run CRF post-processing on saved logits and write the evaluation scores.

    Loads the config at ``config_path`` with ``OmegaConf.load`` and builds the
    validation-split dataset and a ``DenseCRF`` post-processor from it. It then
    refines every cached logit map with ``crf`` using ``n_jobs`` worker
    processes and dumps the dict returned by ``scores`` to
    ``scores_crf_coco.json`` under the experiment's scores directory.

    Args:
        config_path: Path to the experiment config file.
        n_jobs: Number of worker processes used for CRF post-processing.

    Raises:
        SystemExit: With status 1 if the logit directory does not exist (the
            test step has not produced logits yet) or if there are no samples
            to evaluate.
    """
    # Configuration
    CONFIG = OmegaConf.load(config_path)
    # Inference only: disable autograd bookkeeping in this process.
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset: the validation split, without augmentation.
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Path to logit files: <OUTPUT_DIR>/features/<EXP.ID>/<model>/<val split>/logit
    logit_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "features",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
        "logit",
    )
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        print("Logit not found, run first: python main.py test [OPTIONS]")
        # Exit with a failure status. The site-provided quit() exits with 0
        # and is unavailable when Python runs without the site module.
        raise SystemExit(1)

    # Path to save scores
    save_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "scores",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf_coco.json")
    print("Score dst:", save_path)

    # CRF
    results = crf(dataset, logit_dir, postprocessor, num_workers=n_jobs)
    if not results:
        # zip(*[]) below would otherwise fail with an opaque unpacking error.
        print("No samples to evaluate: the dataset is empty.")
        raise SystemExit(1)

    # Evaluation: split the (label, gt_label) pairs into two parallel tuples.
    preds, gts = zip(*results)

    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)
    print(f'mIoU: {score["Mean IoU"]}')
    with open(save_path, "w") as f:
        json.dump(score, f, indent=4, sort_keys=True)

if __name__ == "__main__":
    # Command-line entry point: a positional config path plus an optional
    # worker-process count, forwarded to main() as keyword arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument("config_path", type=str)
    cli.add_argument("--n_jobs", type=int, default=4)
    main(**vars(cli.parse_args()))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants