-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test time #26
Comments
Sorry for the late reply. 对于 COCO14 val 集,在 test 阶段我们在 CRF 后处理上花费了约 2.5 h。我们重新实现了多进程部分(代码如下),你可以根据机器情况调整参数(例如 `--n_jobs`)。
import torch
import numpy as np
import time
import os
import torch.nn.functional as F
import multiprocessing as mp
from multiprocessing import Process
from omegaconf import OmegaConf
import json
import argparse
from tqdm import tqdm
from libs.utils import DenseCRF, PolynomialLR, scores
from main_v2 import get_dataset, makedirs
def process_crf(i, dataset, logit_dir, postprocessor):
    """Refine one sample's saved logits with the CRF post-processor.

    Args:
        i: Index of the sample to fetch from ``dataset``.
        dataset: Indexable object whose items unpack to
            ``(image_id, image, gt_label)``; ``image`` is indexed as
            channels-first (C, H, W).
        logit_dir: Directory containing one ``<image_id>.npy`` logit array
            per sample.
        postprocessor: Callable ``(image_hwc_uint8, prob) -> prob``; the
            script passes a ``DenseCRF`` instance.

    Returns:
        ``(label, gt_label)``: ``label`` is the per-pixel argmax over the
        class axis of the refined probabilities; ``gt_label`` is returned
        unchanged from the dataset.
    """
    image_id, image, gt_label = dataset[i]
    _, height, width = image.shape

    raw_logit = np.load(os.path.join(logit_dir, image_id + ".npy"))

    # Add a batch axis, resize the logits to the image resolution, then turn
    # them into per-class probabilities (batch axis dropped again).
    batched = torch.FloatTensor(raw_logit).unsqueeze(0)
    resized = F.interpolate(
        batched, size=(height, width), mode="bilinear", align_corners=False
    )
    probabilities = F.softmax(resized, dim=1).squeeze(0).numpy()

    # The CRF expects an HWC uint8 image.
    # NOTE(review): the dataset is built with mean_bgr; if it subtracts the
    # mean, this uint8 cast wraps negative values - confirm that is intended.
    hwc_image = np.transpose(image.astype(np.uint8), (1, 2, 0))

    refined = postprocessor(hwc_image, probabilities)
    return np.argmax(refined, axis=0), gt_label
# Per-worker CRF inputs, installed once in each pool process by
# ``_init_crf_worker`` so the dataset and post-processor are not re-sent
# (re-pickled) with every single task.
_CRF_WORKER_ARGS = None


def _init_crf_worker(dataset, logit_dir, postprocessor):
    """Pool initializer: stash the shared CRF inputs in this worker process."""
    global _CRF_WORKER_ARGS
    _CRF_WORKER_ARGS = (dataset, logit_dir, postprocessor)


def _crf_worker(i):
    """Run ``process_crf`` on sample ``i`` with this worker's stashed inputs."""
    dataset, logit_dir, postprocessor = _CRF_WORKER_ARGS
    return process_crf(i, dataset, logit_dir, postprocessor)


def crf(dataset, logit_dir, postprocessor, num_workers=4):
    """CRF-refine every sample of ``dataset`` in a process pool.

    ``dataset``, ``logit_dir`` and ``postprocessor`` are handed to each worker
    once via the pool initializer instead of being serialized per task. Under
    the "spawn" start method they must still be picklable (once per worker).

    Args:
        dataset: Indexable dataset with ``len()``; each item is consumed by
            ``process_crf``.
        logit_dir: Directory holding the ``<image_id>.npy`` logit files.
        postprocessor: CRF callable forwarded to ``process_crf``.
        num_workers: Number of worker processes.

    Returns:
        List of ``(label, gt_label)`` tuples in dataset order.

    Raises:
        Any exception raised by ``process_crf`` in a worker is re-raised here.
    """
    print("CRF post-processing")
    num_samples = len(dataset)
    with mp.Pool(
        num_workers,
        initializer=_init_crf_worker,
        initargs=(dataset, logit_dir, postprocessor),
    ) as pool:
        # imap yields results in submission order, so results[i] pairs with
        # sample i; tqdm advances as each ordered result arrives and closes
        # itself when the iterator is exhausted.
        results = list(
            tqdm(
                pool.imap(_crf_worker, range(num_samples)),
                total=num_samples,
                desc="CRF post-processing",
                ascii=True,
            )
        )
    print("CRF post-processing finished")
    return results
def main(config_path, n_jobs):
    """Run CRF post-processing on saved logits and write evaluation scores.

    Loads the OmegaConf config at ``config_path``, CRF-refines the per-image
    logits stored under
    ``<EXP.OUTPUT_DIR>/features/<EXP.ID>/<model>/<val split>/logit`` using
    ``n_jobs`` worker processes, and writes the metrics returned by ``scores``
    to ``<EXP.OUTPUT_DIR>/scores/<EXP.ID>/<model>/<val split>/scores_crf_coco.json``.

    Exits with status 1 if the logit directory is missing or no samples were
    processed.
    """
    import sys  # local import: only used for the error exits below

    # Configuration
    CONFIG = OmegaConf.load(config_path)
    torch.set_grad_enabled(False)
    print("# jobs:", n_jobs)

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=False,
    )
    print(dataset)

    # CRF post-processor
    postprocessor = DenseCRF(
        iter_max=CONFIG.CRF.ITER_MAX,
        pos_xy_std=CONFIG.CRF.POS_XY_STD,
        pos_w=CONFIG.CRF.POS_W,
        bi_xy_std=CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
        bi_w=CONFIG.CRF.BI_W,
    )

    # Input logits and output scores share the <EXP.ID>/<model>/<val split>
    # sub-path and differ only in the top-level folder.
    run_subdir = os.path.join(
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.VAL,
    )

    # Path to logit files
    logit_dir = os.path.join(CONFIG.EXP.OUTPUT_DIR, "features", run_subdir, "logit")
    print("Logit src:", logit_dir)
    if not os.path.isdir(logit_dir):
        print("Logit not found, run first: python main.py test [OPTIONS]")
        # sys.exit instead of quit(): quit() is an interactive-shell helper
        # from the site module and is not guaranteed to exist in programs.
        sys.exit(1)

    # Path to save scores
    save_dir = os.path.join(CONFIG.EXP.OUTPUT_DIR, "scores", run_subdir)
    makedirs(save_dir)
    save_path = os.path.join(save_dir, "scores_crf_coco.json")
    print("Score dst:", save_path)

    # CRF
    results = crf(dataset, logit_dir, postprocessor, num_workers=n_jobs)
    if not results:
        # zip(*[]) below would fail with an opaque unpacking error.
        print("No samples were processed; nothing to evaluate.")
        sys.exit(1)

    # Evaluation
    preds, gts = zip(*results)
    # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU
    score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES)
    print(f'mIoU: {score["Mean IoU"]}')
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(score, f, indent=4, sort_keys=True)
if __name__ == "__main__":
    # Command-line entry point; the trailing " |" scrape residue that made the
    # original line a syntax error has been removed.
    parser = argparse.ArgumentParser(
        description="CRF post-processing and evaluation of saved segmentation logits."
    )
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to the config file loaded with OmegaConf.",
    )
    parser.add_argument(
        "--n_jobs",
        type=int,
        default=4,
        help="Number of CRF worker processes (default: 4).",
    )
    args = parser.parse_args()
    main(args.config_path, args.n_jobs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
您好,非常感谢您的代码,为我的工作提供了很多帮助。
请问使用您提供的deeplabv2以及后处理的代码,在coco2014 val数据集上测试大概花费多长时间?
The text was updated successfully, but these errors were encountered: