-
Notifications
You must be signed in to change notification settings - Fork 0
/
ksdd2_preprocess.py
70 lines (63 loc) · 2.64 KB
/
ksdd2_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import os
import shutil
from glob import glob

import cv2
import pandas as pd
from tqdm import tqdm
def reshape_ksdd2(src_dir, dst_dir, RES=(224, 632)):
    """Resize every image of the ``train`` and ``test`` splits into ``dst_dir``.

    For each split, every entry of ``<src_dir>/<split>`` is read with
    ``cv2.imread``, resized to ``RES`` and written under the same file name
    to ``<dst_dir>/<split>``, which is created if it does not exist.

    Entries that OpenCV cannot decode (non-image files, sub-directories,
    corrupt images) are reported and skipped instead of aborting the run.

    Args:
        src_dir: Dataset root containing the ``train`` and ``test`` folders.
        dst_dir: Output root; one folder per split is created inside it.
        RES: Target size handed to ``cv2.resize`` as ``(width, height)``.
    """
    for split in ('train', 'test'):
        src_split_dir = os.path.join(src_dir, split)
        dst_split_dir = os.path.join(dst_dir, split)
        os.makedirs(dst_split_dir, exist_ok=True)

        file_names = os.listdir(src_split_dir)
        for file_name in tqdm(file_names, desc=f"Reshaping {split} images", unit="file", total=len(file_names)):
            src_path = os.path.join(src_split_dir, file_name)
            image = cv2.imread(src_path)
            if image is None:
                # cv2.imread does not raise on failure; it returns None,
                # which would otherwise crash inside cv2.resize.
                tqdm.write(f"Skipping unreadable file: {src_path}")
                continue
            resized = cv2.resize(image, RES)
            cv2.imwrite(os.path.join(dst_split_dir, file_name), resized)
def copy_files(src_dir, dst_dir):
    """Copy every ``*.pyb`` file located directly in ``src_dir`` to ``dst_dir``.

    Files keep their base name. ``dst_dir`` must already exist.

    Args:
        src_dir: Directory searched (non-recursively) for ``*.pyb`` files.
        dst_dir: Existing directory that receives the copies.
    """
    target_files = glob(os.path.join(src_dir, '*.pyb'))
    for file in tqdm(target_files, desc="Copying .pyb files", unit="file", total=len(target_files)):
        dst_file = os.path.join(dst_dir, os.path.basename(file))
        # shutil.copy avoids the shell entirely: paths with spaces or shell
        # metacharacters are safe, it works on every platform, and a failed
        # copy raises instead of being silently ignored.
        shutil.copy(file, dst_file)
def make_csv(dst_dir):
    """Write ``train.csv`` and ``test.csv`` labelling each image by its mask.

    For each split, every ``*_GT.png`` mask in ``<dst_dir>/<split>`` yields
    one CSV row: ``path`` is the mask's file name with ``_GT.png`` replaced
    by ``.png`` and ``label`` is ``"positive"`` when the mask contains any
    non-zero pixel, ``"negative"`` otherwise. Rows are sorted by mask file
    name so the output is reproducible. The CSV files are written directly
    into ``dst_dir``.

    Masks that OpenCV cannot decode are reported and left out of the CSV.

    Args:
        dst_dir: Root directory holding the ``train`` and ``test`` folders.
    """
    mask_suffix = "_GT.png"
    for split in ('train', 'test'):
        img_dir = os.path.join(dst_dir, split)
        # os.listdir returns entries in arbitrary order; sort for stable CSVs.
        all_masks = sorted(name for name in os.listdir(img_dir) if name.endswith(mask_suffix))

        imgs_dict = {"path": [], "label": []}
        for mask_name in tqdm(all_masks, desc=f"Creating {split}.csv", unit="file", total=len(all_masks)):
            mask_path = os.path.join(img_dir, mask_name)
            mask = cv2.imread(mask_path)
            if mask is None:
                # cv2.imread returns None instead of raising on failure.
                tqdm.write(f"Skipping unreadable mask: {mask_path}")
                continue
            imgs_dict["path"].append(mask_name.replace("_GT.png", ".png"))
            # Any non-zero pixel marks a defect; ndarray.any() checks this in
            # native code rather than iterating over every pixel in Python.
            imgs_dict["label"].append("positive" if mask.any() else "negative")

        df = pd.DataFrame(imgs_dict)
        df.to_csv(os.path.join(dst_dir, f"{split}.csv"), index=False)
def main(args):
    """Run the full preprocessing pipeline for the parsed CLI arguments.

    Steps, in order: create the destination directory, copy the ``.pyb``
    files, resize the split images, then generate the per-split CSV files.

    Args:
        args: Namespace exposing ``src_dir`` and ``dst_dir`` attributes.
    """
    source_root, target_root = args.src_dir, args.dst_dir
    target_size = (224, 632)  # (width, height)

    print(f"Copying files from {source_root} to {target_root}")
    os.makedirs(target_root, exist_ok=True)

    # 1) bring over the .pyb files
    copy_files(source_root, target_root)
    # 2) give every image the same size so samples can be batched
    reshape_ksdd2(source_root, target_root, RES=target_size)
    # 3) build train.csv / test.csv from the resized masks
    make_csv(target_root)
if __name__ == "__main__":
    # Command-line entry point: both directory options are mandatory.
    cli = argparse.ArgumentParser()
    for option, description in (
        ("--src_dir", "Path to the KSDD2 dataset root"),
        ("--dst_dir", "Path to the destination directory"),
    ):
        cli.add_argument(option, type=str, required=True, help=description)
    args = cli.parse_args()
    main(args)