Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loading dataset: save to image and json. #7

Open
DidiD1 opened this issue Jun 25, 2024 · 1 comment
Open

Loading dataset: save to image and json. #7

DidiD1 opened this issue Jun 25, 2024 · 1 comment

Comments

@DidiD1
Copy link

DidiD1 commented Jun 25, 2024

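A simple demo that reads the TFRecord dataset, matches each record with its caption and image from the Pick-a-Pic v1 dataset (`yuvalkirstain/pickapic_v1`), and saves each record as an image plus a JSON file.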

import base64
import json
import pickle 
import tensorflow as tf
from datasets import load_dataset
import os.path
from tqdm import tqdm

def parse_tfrecord(record):
    """Deserialize a single TFRecord entry into a ``tf.train.Example``.

    Args:
        record: One element yielded by iterating a ``tf.data.TFRecordDataset``;
            its ``.numpy()`` value is the serialized protobuf bytes.

    Returns:
        The parsed ``tf.train.Example`` message.
    """
    serialized = record.numpy()
    parsed = tf.train.Example()
    parsed.ParseFromString(serialized)
    return parsed

def build_lookup_table(pickapic):
    """Index every row of ``pickapic['train']`` under both of its image uids.

    Each row is stored twice: once keyed by ``row['image_0_uid']`` and once
    by ``row['image_1_uid']``. If the same uid occurs in several rows, the
    row visited last wins.

    NOTE(review): entire rows are kept as values, so any image bytes they
    carry (``jpg_0``/``jpg_1`` are read downstream) stay in memory for the
    lifetime of the table. Confirm this fits the memory budget for the full
    train split.
    """
    return {
        uid: row
        for row in tqdm(pickapic['train'])
        for uid in (row['image_0_uid'], row['image_1_uid'])
    }

def get_caption_and_images(lookup_table, filename):
    """Look up the caption and image bytes that belong to ``filename``.

    Args:
        lookup_table: Mapping from image uid to a dataset row, as produced by
            ``build_lookup_table``. A row is expected to carry
            ``image_0_uid``/``image_1_uid`` and, optionally, ``caption``,
            ``jpg_0`` and ``jpg_1``.
        filename: Path-like string whose basename, up to the first ``.``,
            is the image uid (e.g. ``"images/<uid>.jpg"``). May be falsy.

    Returns:
        A list with at most one dict with keys ``'filename'`` (the uid),
        ``'caption'`` and ``'jpg'``. Missing ``caption``/``jpg_*`` fields
        default to ``''``. The list is empty when ``filename`` is falsy, the
        uid is not in ``lookup_table``, or the matching row lists the uid in
        neither image slot.
    """
    results = []
    if not filename:
        return results

    # Drop the directory part and everything from the first '.' onward.
    uid = filename.split('/')[-1].split('.')[0]
    matching_item = lookup_table.get(uid)
    if matching_item is None:
        return results

    # Pick the image column that corresponds to the slot holding this uid.
    if matching_item.get('image_0_uid') == uid:
        image_key = 'jpg_0'
    elif matching_item.get('image_1_uid') == uid:
        image_key = 'jpg_1'
    else:
        # Inconsistent row: previously this fell through and raised
        # UnboundLocalError on `caption`/`image`.
        return results

    results.append({
        'filename': uid,
        'caption': matching_item.get('caption', ''),
        'jpg': matching_item.get(image_key, ''),
    })
    return results

def _decode_example(example, save_path):
    """Flatten a ``tf.train.Example`` into a JSON-serializable dict.

    Only the first value of each feature is kept:
      * byte features that decode as UTF-8 become ``str``;
      * byte features that do not are written verbatim to
        ``<save_path>/<key>.jpg``, and the dict stores that file name instead;
      * float and int64 features become Python numbers.
    Features with no values are omitted.
    """
    record = {}
    for key, value in example.features.feature.items():
        if value.bytes_list.value:
            raw = value.bytes_list.value[0]
            try:
                record[key] = raw.decode('utf-8')
            except UnicodeDecodeError:
                # Non-UTF-8 bytes are assumed to be encoded image data
                # (saved with a .jpg suffix) — TODO confirm against the schema.
                image_name = f'{key}.jpg'
                with open(os.path.join(save_path, image_name), 'wb') as f:
                    f.write(raw)
                record[key] = image_name
        elif value.float_list.value:
            record[key] = value.float_list.value[0]
        elif value.int64_list.value:
            record[key] = value.int64_list.value[0]
    return record


def read_tfrecord_file(file_path, num_proc=64, target_dir_path=None):
    """Export every record of a TFRecord file as a folder of image + JSON.

    Record number ``i`` (0-based) is written to ``<target_dir_path>/<i>/``:
      * ``output.json``: the decoded features (see ``_decode_example``),
        plus a ``caption`` key when a Pick-a-Pic row matches the record's
        ``filename`` feature;
      * ``image.jpg``: the matching Pick-a-Pic image bytes, when available;
      * ``<key>.jpg``: one file per byte feature that is not valid UTF-8.

    Args:
        file_path: Path (or paths) accepted by ``tf.data.TFRecordDataset``.
        num_proc: Number of processes passed to ``datasets.load_dataset``.
            Pass a smaller value, or ``None`` to disable multiprocessing, if
            loading fails with "One of the subprocesses has abruptly died
            during map operation".
        target_dir_path: Output root directory. Defaults to ``train/`` next
            to this script.

    Returns:
        The number of records exported.
    """
    raw_dataset = tf.data.TFRecordDataset(file_path)
    if target_dir_path is None:
        now_dir = os.path.dirname(os.path.abspath(__file__))
        target_dir_path = os.path.join(now_dir, "train")

    pickapic = load_dataset("yuvalkirstain/pickapic_v1", num_proc=num_proc)
    lookup_table = build_lookup_table(pickapic)

    count = 0
    # All per-record work lives inside the loop. Previously the image/caption/
    # JSON writes and `i += 1` sat after it, so every record went to `train/0`
    # and only the last record produced `output.json`.
    for i, raw_record in enumerate(tqdm(raw_dataset)):
        save_path = os.path.join(target_dir_path, str(i))
        os.makedirs(save_path, exist_ok=True)
        record = _decode_example(parse_tfrecord(raw_record), save_path)

        matches = get_caption_and_images(lookup_table, record.get('filename'))
        if matches:
            match = matches[0]
            # Skip the write when there is no image payload ('' default).
            if match['jpg']:
                with open(os.path.join(save_path, 'image.jpg'), 'wb') as f:
                    f.write(match['jpg'])
            record['caption'] = match['caption']

        with open(os.path.join(save_path, 'output.json'), 'w') as json_file:
            json.dump(record, json_file, indent=4)
        count = i + 1
    return count

if __name__ == "__main__":
    import sys

    # The original hard-coded placeholder " " could never name a real file;
    # take the TFRecord path from the command line instead.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <path/to/file.tfrecord>")
    file_path = sys.argv[1]
    records = read_tfrecord_file(file_path)
@ucasyjz
Copy link

ucasyjz commented Aug 1, 2024

There is a problem: running this raises `RuntimeError: One of the subprocesses has abruptly died during map operation.To debug the error, disable multiprocessing.` How can I solve this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants