add pretraining part #18

Open · wants to merge 7 commits into base: main
97 changes: 97 additions & 0 deletions Mesh-candidate Bestfit/README.md
@@ -0,0 +1,97 @@
## Mesh-candidate Bestfit

<p align="center">
<img src="../assets/Mesh-candidate Bestfit.png" width=100%> <br>
<i><small>Mesh-candidate Bestfit iteratively inserts elements from a small set of public datasets by searching for the best match between sampled candidates and the available grids in the current layout, ultimately achieving document synthesis.</small></i>
</p>

You can generate large-scale, diverse data for pretraining by applying our proposed Mesh-candidate Bestfit method; just follow the steps below:

### 1. Environment Setup

You need to install [PyMuPDF](https://pypi.org/project/PyMuPDF/1.23.7/) for subsequent rendering via pip:

```bash
cd "Mesh-candidate Bestfit"
pip install pymupdf==1.23.7
```

### 2. Preprocessing

- **Data Preparation**

Two things need to be prepared before starting generation:

1\. **Original Annotation File of Your Dataset**

* The file must be in **JSON format** and follow the **COCO** specification.
* Each instance should have a **unique instance ID**.
* The file should be placed in the `./` directory.

2\. **Element Pool**

You can extract elements of each category from the original annotation file by cropping every instance's bounding box (a minimal sketch is given after the note below). The element pool is required to be structured like this:

```bash
./element_pool
├── advertisement
│ ├── 727.jpg
│ ├── 919.jpg
│ ├── 1423.jpg
│ └── ...
├── algorithm
│ ├── 12653.jpg
│ ├── 17485.jpg
│ ├── 44364.jpg
│ └── ...
└── ...
```

The first-level subdirectories are named after the **specific categories**, and the elements inside are named with their **corresponding instance IDs** from the raw json file of the dataset.

**Note:** For convenience, we provide the original annotation file and element pool for the M6Doc dataset, which can be downloaded from [annotation file](https://drive.google.com/file/d/1ua41Gs3UW8iuoJp21tZ4-lczVrcEm-gP/view?usp=sharing) and [element pool](https://drive.google.com/file/d/1MrIFObKr1bDGgZLBQM_c_Dvobkp6mjFE/view?usp=sharing), respectively.
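
A minimal, hypothetical sketch of this extraction step (the `./images` directory and the `M6Doc.json` file name are assumptions for illustration, not part of the repository):

```python
# Hypothetical sketch: build ./element_pool by cropping every annotated instance
# out of its page image. Assumes page images in ./images and annotations in ./M6Doc.json.
import json
import os

import cv2

with open("./M6Doc.json") as f:
    coco = json.load(f)

cat_names = {c["id"]: c["name"] for c in coco["categories"]}
img_paths = {im["id"]: os.path.join("./images", im["file_name"]) for im in coco["images"]}

for ann in coco["annotations"]:
    x, y, w, h = map(int, ann["bbox"])  # COCO bbox format is [x, y, width, height]
    page = cv2.imread(img_paths[ann["image_id"]])
    crop = page[y:y + h, x:x + w]
    out_dir = os.path.join("./element_pool", cat_names[ann["category_id"]])
    os.makedirs(out_dir, exist_ok=True)
    # Each crop is named with its unique instance ID, as the pipeline requires
    cv2.imwrite(os.path.join(out_dir, f'{ann["id"]}.jpg'), crop)
```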

- **Data Augmentation (Optional)**

If you want to apply our designed augmentation pipeline to your element pool, simply run:

```bash
python augmentation.py --min_count 100 --aug_times 50
```

The script runs the augmentation pipeline `aug_times` times on each element of every category that contains fewer than `min_count` elements.

- **Map Dict**

To facilitate the random selection of candidates during the rendering phase, it is necessary to establish a mapping from each element to all of its candidate paths:

```bash
python map_dict.py --save_path ./map_dict.json --use_aug
```
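
The exact output schema is defined by `map_dict.py` (not shown in this diff); conceptually, the mapping pairs each raw instance ID with the paths of its original crop and, when `--use_aug` is set, its augmented variants. A rough, hypothetical equivalent:

```python
# Hypothetical sketch of the element -> candidate-paths mapping; map_dict.py is authoritative.
import glob
import json
import os

map_dict = {}
for category in os.listdir("./element_pool"):
    category_dir = os.path.join("./element_pool", category)
    for raw in glob.glob(os.path.join(category_dir, "*.jpg")):
        element_id = os.path.splitext(os.path.basename(raw))[0]
        # augmentation.py saves augmented variants as aug/<id>_<timestamp>.jpg
        candidates = [raw] + glob.glob(os.path.join(category_dir, "aug", f"{element_id}_*.jpg"))
        map_dict[element_id] = candidates

with open("./map_dict.json", "w") as f:
    json.dump(map_dict, f)
```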

### 3. Layout Generation

Now you can generate diverse layouts with the Mesh-candidate Bestfit algorithm. To prevent process blocking, each layout is saved separately as soon as it is generated; you can then use the [combine_layouts.py](./combine_layouts.py) script to merge them all into a single file:

```bash
python bestfit_generator.py --generate_num 10000 --json_path ./M6Doc.json --output_dir ./generated_layouts/seperate
python combine_layouts.py --seperate_layouts_dir ./generated_layouts/seperate --save_path ./generated_layouts/combined_layouts.json
```

Afterwards, feel free to delete the separate layouts since they are no longer needed.
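
For reference, each entry of `combined_layouts.json` is one layout as written by `bestfit_generator.py` (see its `output_layout` dict); the values below are made up for illustration:

```python
# Illustrative layout entry (made-up numbers), matching the fields written by bestfit_generator.py
layout = {
    "boxes": [[0.12, 0.08, 0.55, 0.31], [0.60, 0.10, 0.95, 0.45]],  # normalized [x1, y1, x2, y2]
    "categories": [3, 0],  # zero-based ids; "__background__" (category_id 0) is excluded
    "relpaths": ["algorithm/12653.jpg", "advertisement/727.jpg"],  # element file paths
}
```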

### 4. Rendering

Finally, you can render the generated layouts and save the results in YOLO format via the script below:

```bash
python rendering.py --json_path ./generated_layouts/combined_layouts.json --map_dict_path ./map_dict.json --save_dir ./generated_dataset
```
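
As a reminder of the target format (the standard YOLO convention; `rendering.py` is the authoritative source), each label line is `class cx cy w h` with all coordinates normalized to [0, 1]. A hypothetical helper converting one layout box:

```python
# Hypothetical helper: convert a normalized [x1, y1, x2, y2] box into a YOLO label line.
def to_yolo_line(category, box):
    x1, y1, x2, y2 = box
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = x2 - x1, y2 - y1
    return f"{category} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"

print(to_yolo_line(3, [0.12, 0.08, 0.55, 0.31]))  # -> "3 0.335000 0.195000 0.430000 0.230000"
```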

### Visualization

We provide [visualize.ipynb](./visualize.ipynb) to visualize the layouts generated by our proposed method. Some generated cases are displayed below:

<p align="center">
<img src="../assets/visualization.png" width=100%> <br>
</p>
90 changes: 90 additions & 0 deletions Mesh-candidate Bestfit/augmentation.py
@@ -0,0 +1,90 @@
import os
import cv2
import time
import argparse
import numpy as np
from tqdm import tqdm
import albumentations as A


class EdgeDetection(A.ImageOnlyTransform):
    """
    Edge extraction transform based on the Sobel filter.
    """
    def apply(self, img, **params):
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        mag = np.hypot(sobelx, sobely)
        mag = mag / (np.max(mag) + 1e-8) * 255  # avoid division by zero on flat images
        # Convert back to 3 channels so downstream transforms and the final
        # RGB2BGR conversion still receive an RGB image
        return cv2.cvtColor(np.uint8(mag), cv2.COLOR_GRAY2RGB)


def pipeline(h, w):
    """
    Build the full data augmentation pipeline for a given image size.

    Args:
        h (int): Height of the image.
        w (int): Width of the image.
    """
return A.Compose([
A.RandomBrightnessContrast(p=0.5),
A.RandomResizedCrop(height=h, width=w, scale=(0.5, 0.9), ratio=(w / h, w / h), p=0.7), # keep h/w ratio the same
EdgeDetection(p=0.2),
A.ElasticTransform(alpha_affine=5, p=0.2),
A.GaussNoise(var_limit=(100, 1200), p=1),
])


def perform_augmentation(img, transform, save_dir, prefix_id, aug_times):
    """
    Perform augmentation on a single element multiple times.

    Args:
        img (np.ndarray): An element image in RGB order.
        transform (A.Compose): Data augmentation pipeline.
        save_dir (str): Root directory to save results.
        prefix_id (str): Raw instance ID of the element.
        aug_times (int): Number of augmentations to perform.
    """
for _ in range(aug_times):
transformed = transform(image=img)
transformed_image = transformed["image"]
transformed_image_bgr = cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR)

suffix_id = str(time.time()).replace(".", "_")
prefix_id_dir = os.path.join(save_dir, prefix_id)
os.makedirs(prefix_id_dir, exist_ok=True)
aug_element_path = os.path.join(prefix_id_dir, f'{prefix_id}_{suffix_id}.jpg')
cv2.imwrite(aug_element_path, transformed_image_bgr)



if __name__ == "__main__":

parser = argparse.ArgumentParser(description="Perform Image Augmentation")
parser.add_argument("--min_count", type=int, default=100, help="Minimum number of elements for categories that do not require data augmentation")
parser.add_argument("--aug_times", type=int, default=50, help="Number of augmentations per element")
args = parser.parse_args()

root_dir = './element_pool'

    for category in tqdm(os.listdir(root_dir), desc='Categories done'):
        category_dir = os.path.join(root_dir, category)
        raw_elements = [f for f in os.listdir(category_dir) if f.endswith('.jpg')]
        # Only augment categories that contain fewer than `min_count` elements
        if len(raw_elements) >= args.min_count:
            continue
        save_dir = os.path.join(category_dir, 'aug')
        os.makedirs(save_dir, exist_ok=True)
        for raw_element in raw_elements:
            raw_element_path = os.path.join(category_dir, raw_element)

            img = cv2.imread(raw_element_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w, c = img.shape
            element_id = raw_element.split('.')[0]

            transform = pipeline(h, w)

            perform_augmentation(img=img, transform=transform, save_dir=save_dir,
                                 prefix_id=element_id, aug_times=args.aug_times)
142 changes: 142 additions & 0 deletions Mesh-candidate Bestfit/bestfit_generator.py
@@ -0,0 +1,142 @@
import os
import json
import time
import torch
import random
import datetime
import argparse
import itertools
import torchvision
import multiprocessing
from utils.process import *

random.seed(datetime.datetime.now().timestamp())


def bestfit_generator(element_all):
    """
    Apply the Mesh-candidate Bestfit algorithm to generate a single layout.

    Args:
        element_all (dict): Elements loaded from the dataset json file, split
            into 'large' and 'small' groups.

    Returns:
        dict: The generated layout, also saved as a json file under OUTPUT_DIR.
    """
# Default candidate_num = 500
candidate_num = 500
large_elements_idx = random.sample(list(range(len(element_all['large']))), int(candidate_num*0.99))
small_elements_idx = random.sample(list(range(len(element_all['small']))), int(candidate_num*0.01))
cand_elements = [element_all['large'][large_idx] for large_idx in large_elements_idx] + [element_all['small'][small_idx] for small_idx in small_elements_idx]

# Initially, randomly put an element
put_elements = []
e0 = random.choice(cand_elements)
cx = random.uniform(min(e0.w/2, 1-e0.w/2), max(e0.w/2, 1-e0.w/2))
cy = random.uniform(min(e0.h/2, 1-e0.h/2), max(e0.h/2, 1-e0.h/2))
e0.cx, e0.cy = cx, cy
put_elements = [e0]
cand_elements.remove(e0)
small_cnt = 1 if e0.w < 0.05 or e0.h < 0.05 else 0

    # Iteratively insert elements
while True:
# Construct meshgrid based on current layout
put_element_boxes = []
xticks, yticks = [0,1], [0,1]
for e in put_elements:
x1, y1, x2, y2 = e.cx-e.w/2, e.cy-e.h/2, e.cx+e.w/2, e.cy+e.h/2
xticks.append(x1)
xticks.append(x2)
yticks.append(y1)
yticks.append(y2)
put_element_boxes.append([x1, y1, x2, y2])
xticks, yticks = list(set(xticks)), list(set(yticks))
pticks = list(itertools.product(xticks, yticks))
meshgrid = list(itertools.product(pticks, pticks))
put_element_boxes = torch.Tensor(put_element_boxes)

        # Filter out invalid grids
meshgrid = [grid for grid in meshgrid if grid[0][0] < grid[1][0] and grid[0][1] < grid[1][1]]
meshgrid_tensor = torch.Tensor([p1 + p2 for p1, p2 in meshgrid])
iou_res = torchvision.ops.box_iou(meshgrid_tensor, put_element_boxes)
valid_grid_idx = (iou_res.sum(dim=1) == 0).nonzero().flatten().tolist()
meshgrid = meshgrid_tensor[valid_grid_idx].tolist()

# Search for the Mesh-candidate Bestfit pair
max_fill, max_grid_idx, max_element_idx = 0, -1, -1
for element_idx, e in enumerate(cand_elements):
for grid_idx, grid in enumerate(meshgrid):
if e.w > grid[2] - grid[0] or e.h > grid[3] - grid[1]:
continue
element_area = e.w * e.h
grid_area = (grid[2] - grid[0]) * (grid[3] - grid[1])
if element_area/grid_area > max_fill:
max_fill = element_area/grid_area
max_grid_idx = grid_idx
max_element_idx = element_idx

# Termination condition
if max_element_idx == -1 or max_grid_idx == -1:
break
else:
maxfit_element = cand_elements[max_element_idx]
if maxfit_element.w < 0.05 or maxfit_element.h < 0.05:
small_cnt += 1
if small_cnt > 5:
break

# Put the candidate to the center of the grid
cand_elements.remove(maxfit_element)
maxfit_element.cx = (meshgrid[max_grid_idx][0] + meshgrid[max_grid_idx][2])/2
maxfit_element.cy = (meshgrid[max_grid_idx][1] + meshgrid[max_grid_idx][3])/2
put_elements.append(maxfit_element)

# Apply a rescale transform to introduce more diversity
for _, e in enumerate(put_elements):
e.gen_real_bbox()
layout = Layout(cand_elements=put_elements)

# Convert the layout to json file format
boxes, categories, relpaths = [], [], []
for element in layout.cand_elements:
cx, cy, w, h = element.get_real_bbox()
x1, y1, x2, y2 = cx-w/2, cy-h/2, cx+w/2, cy+h/2
boxes.append([x1, y1, x2, y2])
categories.append(element.category-1) # Exclude the "__background__" category (category_id = 0)
relpaths.append(element.filepath)

output_layout = {
"boxes": boxes,
"categories": categories,
"relpaths": relpaths
}

# To prevent process blocking, save the result of each layout in a timely manner.
    with open(os.path.join(OUTPUT_DIR, str(time.time()).replace(".", "_") + '.json'), 'w') as f:
json.dump(output_layout, f)

return output_layout



if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument('--generate_num', default=None, required=True, type=int, help='number of layouts to generate')
parser.add_argument('--json_path', default=None, required=True, type=str, help='original json file of the dataset')
    parser.add_argument('--output_dir', default='./generated_layouts/seperate', type=str, help='output directory for the generated separate layout files')
args = parser.parse_args()

element_all = read_data(args.json_path)
OUTPUT_DIR = args.output_dir
os.makedirs(OUTPUT_DIR,exist_ok=True)

# Using multiprocessing to accelerate generation
n_jobs = 100
with multiprocessing.Pool(n_jobs) as p:
generated_layout = p.starmap(
bestfit_generator, [(element_all,) for _ in range(args.generate_num)]
)
p.close()
p.join()
32 changes: 32 additions & 0 deletions Mesh-candidate Bestfit/combine_layouts.py
@@ -0,0 +1,32 @@
import os
import json
import argparse
from tqdm import tqdm


def combine_layouts(seperate_layouts_dir):
    """
    Combine the separate layout json files into one list.

    Args:
        seperate_layouts_dir (str): Directory holding the separate layout json
            files generated by bestfit_generator.py.
    """
    combined_layouts = []
    for item in tqdm(os.listdir(seperate_layouts_dir), desc='Combining separate layouts'):
        abs_path = os.path.join(seperate_layouts_dir, item)
        with open(abs_path) as f:
            combined_layouts.append(json.load(f))
    return combined_layouts


if __name__ == "__main__":

parser = argparse.ArgumentParser()
    parser.add_argument('--seperate_layouts_dir', default="./generated_layouts/seperate", type=str, help="directory holding the separate layout files")
    parser.add_argument('--save_path', default="./generated_layouts/combined_layouts.json", type=str, help='save path for the combined layouts')
args = parser.parse_args()

combined_layouts = combine_layouts(seperate_layouts_dir=args.seperate_layouts_dir)

    with open(args.save_path, 'w') as f:
        json.dump(combined_layouts, f, indent=4)