Sharing a “multithreaded Python script to speed up downloading the ocr_vqa dataset” #320

Open
tdye24 opened this issue Jul 2, 2024 · 0 comments


tdye24 commented Jul 2, 2024

Here is a simple script I wrote to speed up downloading the ocr_vqa dataset images by fetching them in parallel threads.

import os
import time
import json
import urllib.request as ureq
from multiprocessing.pool import ThreadPool

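def download_image(url, filename):
    """Download one image; return its filename on success, or None on failure."""
    try:
        ureq.urlretrieve(url=url, filename=filename)
        # print("download image successfully:{}".format(url))
        return filename
    except Exception as e:
        # Report which URL failed instead of only the bare exception
        print("failed to download {}: {}".format(url, e))
        return None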

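def download_image_thread(url_lst, filename_lst, num_processes):
    # num_processes is the number of worker *threads*; threads suit this
    # I/O-bound job because the GIL is released while waiting on the network.
    pool = ThreadPool(processes=num_processes)
    thread_list = []
    for (image_url, filename) in zip(url_lst, filename_lst):
        out = pool.apply_async(func=download_image, args=(image_url, filename))
        thread_list.append(out)
    pool.close()
    pool.join()

    return_lst = []
    for p in thread_list:
        image = p.get()
        return_lst.append(image)

    # download_image() returns None on failure, so the failed downloads are
    # the filenames whose result is None.
    failed_lst = [f for f, result in zip(filename_lst, return_lst) if result is None]
    return failed_lst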


if __name__ == "__main__":
    # exist_ok=True lets the script be re-run without crashing on an existing folder
    os.makedirs('./images', exist_ok=True)

    with open('OCR-VQA-200K/dataset.json', 'r') as fp:
        data = json.load(fp)

    url_lst = []
    filename_lst = []
    for k in data.keys():
        # Each key of dataset.json becomes the local file name; keep the URL's extension
        url_lst.append(data[k]['imageURL'])
        ext = os.path.splitext(data[k]['imageURL'])[1]
        filename = 'images/%s%s' % (k, ext)
        filename_lst.append(filename)

    start_time = time.time()
    failed_lst = download_image_thread(url_lst=url_lst, filename_lst=filename_lst, num_processes=64)
    end_time = time.time()
    consume_time = end_time - start_time
    print("Elapsed time: " + str(consume_time) + " seconds.")
    print("Failed image list:", failed_lst)
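For comparison, the same approach can be sketched with the standard library's `concurrent.futures.ThreadPoolExecutor`, extended to skip images already on disk (handy when re-running after an interrupted download) and to retry a failing URL a few times. This is a minimal sketch, not part of the original script: `fetch_one`, `download_all`, and the `retries`, `delay`, and `fetch` parameters are hypothetical names introduced here, and `fetch` defaults to `urllib.request.urlretrieve` (it is injectable only so the logic can be exercised without network access).

```python
import functools
import os
import time
import urllib.request as ureq
from concurrent.futures import ThreadPoolExecutor


def fetch_one(url, filename, fetch=ureq.urlretrieve, retries=3, delay=1.0):
    """Hypothetical helper: download url to filename, retrying on failure.

    Returns filename on success (or if it already exists), None otherwise.
    """
    if os.path.exists(filename):
        # Skip images completed by an earlier run.
        return filename
    # Write to a temporary ".part" file first, so an interrupted download never
    # leaves a truncated image under the final name (which the skip above would trust).
    tmp = filename + ".part"
    for attempt in range(retries):
        try:
            fetch(url, tmp)
            os.replace(tmp, filename)
            return filename
        except Exception as e:
            print("attempt {}/{} failed for {}: {}".format(attempt + 1, retries, url, e))
            if os.path.exists(tmp):
                os.remove(tmp)
            if attempt + 1 < retries:
                time.sleep(delay)
    return None


def download_all(url_lst, filename_lst, num_threads=64, **kwargs):
    """Hypothetical driver: download in parallel, return the failed filenames."""
    worker = functools.partial(fetch_one, **kwargs)
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # pool.map preserves input order, so results line up with filename_lst.
        results = list(pool.map(worker, url_lst, filename_lst))
    return [f for f, r in zip(filename_lst, results) if r is None]
```

In the script above, the call would become `failed_lst = download_all(url_lst, filename_lst, num_threads=64)`. Because finished files are skipped, simply running the script again retries only the images that failed or were never fetched.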