Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Horizontal Flipping for Test Time Augmentation #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/openpifpaf/decoder/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch

from .utils.hflip import hflip_average_fields_batch
from .. import annotation, visualizer

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -111,10 +112,31 @@ def apply(f, items):
LOG.debug('nn processing time: %.1fms', (time.time() - start) * 1000.0)
return heads

def batch(self, model, image_batch, *, device=None, gt_anns_batch=None):
def batch(self, model, image_batch, *, device=None, hflip=False, gt_anns_batch=None):
"""From image batch straight to annotations batch."""
start_nn = time.perf_counter()
fields_batch = self.fields_batch(model, image_batch, device=device)

if hflip:
# Horizontal-flip test-time augmentation: run the network on both the original and the
# horizontally mirrored image, then average the two sets of fields before decoding them
# into the final prediction. Averaging the two passes reduces prediction noise.

# Take horizontal flipped image and generate fields.
hflip_image_batch = torch.flip(image_batch, [-1])
combined_image_batch = torch.cat((image_batch, hflip_image_batch), dim=0)
combined_fields_batch = self.fields_batch(model, combined_image_batch, device=device)
cfb_len = len(combined_fields_batch)
assert cfb_len % 2 == 0
fields_batch = combined_fields_batch[:cfb_len // 2]
hflip_fields_batch = combined_fields_batch[cfb_len // 2:]

# Average the fields with the original fields before decoding to the final prediction.
fields_batch = hflip_average_fields_batch(
fields_batch=fields_batch, hflip_fields_batch=hflip_fields_batch, head_metas=model.head_metas
)
else:
fields_batch = self.fields_batch(model, image_batch, device=device)

self.last_nn_time = time.perf_counter() - start_nn

if gt_anns_batch is None:
Expand Down
54 changes: 54 additions & 0 deletions src/openpifpaf/decoder/utils/hflip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Helper methods for horizontally flipping field representations of the image during evaluation.
"""

import torch


def hflip_average_fields_batch(fields_batch, hflip_fields_batch, head_metas):
    """Average the fields of the original images with those of their mirrored copies.

    Entrypoint function for horizontal flipping.

    Args:
        fields_batch: one entry per image, each a mutable sequence holding one
            field set per head, predicted for the original images. The inner
            sequences are updated in place.
        hflip_fields_batch: field sets predicted for the horizontally flipped
            images, laid out exactly like ``fields_batch``.
        head_metas: one meta object per head. Only its ``name`` attribute is
            read, to select the head-specific routine that maps a flipped
            field set back into the original image frame.

    Returns:
        ``fields_batch`` itself, where every field set has been replaced by
        the mean of the original field set and the un-flipped mirrored one.

    Raises:
        ValueError: if a head is not supported for flipping, if the two
            batches differ in length, or if an image does not carry exactly
            one field set per head.
    """
    # Resolve the un-flip routine for every head up front so an unsupported
    # head is rejected before any field set has been modified.
    hflip_funcs = []
    for head_meta in head_metas:
        if head_meta.name != 'cifdet':
            raise ValueError(f'Unsupported head meta for hflip: {head_meta.name}.')
        hflip_funcs.append(hflip_average_cifdet_fields_batch)

    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``
    # and a length mismatch would otherwise surface as a bare IndexError or,
    # worse, as silently ignored predictions.
    if len(fields_batch) != len(hflip_fields_batch):
        raise ValueError(
            'Original and flipped field batches differ in length: '
            f'{len(fields_batch)} != {len(hflip_fields_batch)}.'
        )

    n_heads = len(hflip_funcs)
    for image_fields, hflip_image_fields in zip(fields_batch, hflip_fields_batch):
        if len(image_fields) != n_heads or len(hflip_image_fields) != n_heads:
            raise ValueError(
                f'Expected {n_heads} field sets per image, got '
                f'{len(image_fields)} (original) and {len(hflip_image_fields)} (flipped).'
            )

        # Iterate over the head functions rather than over the list that is
        # being overwritten.
        for head_index, unflip in enumerate(hflip_funcs):
            # Additional processing for the flipped field set, specific to the head.
            unflipped = unflip(hflip_image_fields[head_index])
            # The mean of both passes is the final field prediction.
            image_fields[head_index] = (image_fields[head_index] + unflipped) / 2

    return fields_batch


def hflip_handle_reg_x_offset(hflip_field_set, offset_field_index=2):
    """Mirror the x regression field of an already flipped field set, in place.

    Handles the set of x regression fields that the cifdet head produces.
    The caller is expected to have flipped every field along the last
    (width) axis already; this function then corrects the values stored in
    the x regression field so that they refer to the un-flipped image.

    Args:
        hflip_field_set: 4-dimensional tensor indexed as
            ``[:, field, :, column]``. It is modified in place.
        offset_field_index: index along dimension 1 of the x regression
            field. NOTE(review): the default of 2 presumably matches the
            cifdet head layout -- confirm against the head implementation.

    Returns:
        The same tensor object that was passed in.
    """
    # The previous implementation flipped the field, subtracted the column
    # index, flipped back, negated and re-added the column index. Writing the
    # value in column j as x[j], that sequence yields
    #     -(x[j] - (W - 1 - j)) + j  ==  (W - 1) - x[j],
    # so the whole correction is a reflection about the last column index.
    # Computing it directly needs no auxiliary index tensor, which the old
    # code always allocated on the default device and therefore broke for
    # field sets living on another device. Results may differ from the old
    # code in the last floating-point bit because fewer roundings take place.
    last_column = hflip_field_set.shape[3] - 1
    x_regression = hflip_field_set[:, offset_field_index, :, :]
    hflip_field_set[:, offset_field_index, :, :] = last_column - x_regression
    return hflip_field_set


def hflip_average_cifdet_fields_batch(hflip_field_set):
    """Map cifdet fields predicted on a mirrored image back to the original frame.

    Used for object detection tasks: every field is reversed along its last
    axis, then the x regression field is corrected by
    ``hflip_handle_reg_x_offset``.

    Args:
        hflip_field_set: tensor of cifdet fields predicted for the
            horizontally flipped image. It is not modified; the flip
            produces a new tensor.

    Returns:
        The flipped and corrected field set.
    """
    # Reverse along the last dimension; this yields a copy, so the in-place
    # correction below never touches the caller's tensor.
    mirrored = hflip_field_set.flip(-1)
    return hflip_handle_reg_x_offset(mirrored)
9 changes: 8 additions & 1 deletion src/openpifpaf/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def cli(cls, parser: argparse.ArgumentParser, *,
group.add_argument('--precise-rescaling', dest='fast_rescaling',
default=True, action='store_false',
help='use more exact image rescaling (requires scipy)')
group.add_argument('--tta-hflip', dest='tta_hflip',
default=False, action='store_true',
help='apply horizontal flipping as test time augmentation')

@classmethod
def configure(cls, args: argparse.Namespace):
Expand All @@ -80,6 +83,7 @@ def configure(cls, args: argparse.Namespace):
cls.fast_rescaling = args.fast_rescaling
cls.loader_workers = args.loader_workers
cls.long_edge = args.long_edge
cls.tta_hflip = args.tta_hflip

def preprocess_factory(self):
rescale_t = None
Expand Down Expand Up @@ -125,7 +129,10 @@ def enumerated_dataloader(self, enumerated_dataloader):
if self.visualize_processed_image:
visualizer.Base.processed_image(processed_image_batch[0])

pred_batch = self.processor.batch(self.model, processed_image_batch, device=self.device)
pred_batch = self.processor.batch(
self.model, processed_image_batch, hflip=self.tta_hflip, device=self.device
)

self.last_decoder_time = self.processor.last_decoder_time
self.last_nn_time = self.processor.last_nn_time
self.total_decoder_time += self.processor.last_decoder_time
Expand Down