Skip to content

Commit

Permalink
fix open-mmlab#253: speedup; fix assert
Browse files Browse the repository at this point in the history
  • Loading branch information
cuhk-hbsun committed Jun 6, 2021
1 parent 70715ce commit 92c57aa
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 27 deletions.
49 changes: 30 additions & 19 deletions mmocr/datasets/pipelines/dbnet_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,36 +84,47 @@ def __call__(self, results):
def may_augment_annotation(self, aug, shape, target_shape, results):
if aug is None:
return results

# augment polygon mask
for key in results['mask_fields']:
# augment polygon mask
masks = []
for mask in results[key]:
masks.append(
[self.may_augment_poly(aug, shape, target_shape, mask[0])])
masks = self.may_augment_poly(aug, shape, results[key])
if len(masks) > 0:
results[key] = PolygonMasks(masks, *target_shape[:2])

# augment bbox
for key in results['bbox_fields']:
# augment bbox
bboxes = []
for bbox in results[key]:
bbox = self.may_augment_poly(aug, shape, target_shape, bbox)
bboxes.append(bbox)
bboxes = self.may_augment_poly(
aug, shape, results[key], mask_flag=False)
results[key] = np.zeros(0)
if len(bboxes) > 0:
results[key] = np.stack(bboxes)

return results

def may_augment_poly(self, aug, img_shape, target_shape, poly):
# poly n x 2
poly = poly.reshape(-1, 2)
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
[imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints
poly = [[p.x, p.y] for p in keypoints]
poly = np.array(poly).flatten()
return poly
def may_augment_poly(self, aug, img_shape, polys, mask_flag=True):
key_points, poly_point_nums = [], []
for poly in polys:
if mask_flag:
poly = poly[0]
poly = poly.reshape(-1, 2)
key_points.extend([imgaug.Keypoint(p[0], p[1]) for p in poly])
poly_point_nums.append(poly.shape[0])
key_points = aug.augment_keypoints(
[imgaug.KeypointsOnImage(keypoints=key_points,
shape=img_shape)])[0].keypoints

new_polys = []
start_idx = 0
for poly_point_num in poly_point_nums:
new_poly = []
for key_point in key_points[start_idx:(start_idx +
poly_point_num)]:
new_poly.append([key_point.x, key_point.y])
start_idx += poly_point_num
new_poly = np.array(new_poly).flatten()
new_polys.append([new_poly] if mask_flag else new_poly)

return new_polys

def __repr__(self):
repr_str = self.__class__.__name__
Expand Down
34 changes: 26 additions & 8 deletions mmocr/datasets/pipelines/textdet_targets/dbnet_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def ignore_texts(self, results, ignore_tags):
mask for i, mask in enumerate(results['gt_labels'])
if not ignore_tags[i]
])
new_ignore_tags = [ignore for ignore in ignore_tags if not ignore]

return results
return results, new_ignore_tags

def generate_thr_map(self, img_size, polygons):
"""Generate threshold map.
Expand Down Expand Up @@ -149,12 +150,15 @@ def draw_border_map(self, polygon, canvas, mask):
else:
print(f'padding {polygon} with {distance} gets {padded_polygon}')
padded_polygon = polygon.copy().astype(np.int32)
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)

x_min = padded_polygon[:, 0].min()
x_max = padded_polygon[:, 0].max()
y_min = padded_polygon[:, 1].min()
y_max = padded_polygon[:, 1].max()

if x_max <= 0 or y_max <= 0:
return

width = x_max - x_min + 1
height = y_max - y_min + 1

Expand All @@ -180,6 +184,16 @@ def draw_border_map(self, polygon, canvas, mask):
x_max_valid = min(max(0, x_max), canvas.shape[1] - 1)
y_min_valid = min(max(0, y_min), canvas.shape[0] - 1)
y_max_valid = min(max(0, y_max), canvas.shape[0] - 1)

if x_min_valid - x_min >= distance_map.shape[
1] or y_min_valid - y_min >= distance_map.shape[0]:
return
if x_max_valid - x_max + width <= 0:
return
if y_max_valid - y_max + height <= 0:
return

cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
canvas[y_min_valid:y_max_valid + 1,
x_min_valid:x_max_valid + 1] = np.fmax(
1 - distance_map[y_min_valid - y_min:y_max_valid - y_max +
Expand All @@ -198,25 +212,29 @@ def generate_targets(self, results):
results (dict): The output result dictionary.
"""
assert isinstance(results, dict)
polygons = results['gt_masks'].masks

if 'bbox_fields' in results:
results['bbox_fields'].clear()

ignore_tags = self.find_invalid(results)
results, ignore_tags = self.ignore_texts(results, ignore_tags)

h, w, _ = results['img_shape']
polygons = results['gt_masks'].masks

# generate gt_shrink_kernel
gt_shrink, ignore_tags = self.generate_kernels((h, w),
polygons,
self.shrink_ratio,
ignore_tags=ignore_tags)

results = self.ignore_texts(results, ignore_tags)

# polygons and polygons_ignore reassignment.
polygons = results['gt_masks'].masks
results, ignore_tags = self.ignore_texts(results, ignore_tags)
# genenrate gt_shrink_mask
polygons_ignore = results['gt_masks_ignore'].masks

gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore)

# generate gt_threshold and gt_threshold_mask
polygons = results['gt_masks'].masks
gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons)

results['mask_fields'].clear() # rm gt_masks encoded by polygons
Expand Down

0 comments on commit 92c57aa

Please sign in to comment.