Skip to content

Commit

Permalink
Move slim sequence example decoder to open source.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 308685525
Change-Id: Iab06095ad34aa1b36f9e0582b912b363d2729466
  • Loading branch information
TF-Slim Team authored and copybara-github committed Apr 27, 2020
1 parent a6a8f82 commit b258885
Show file tree
Hide file tree
Showing 2 changed files with 625 additions and 0 deletions.
290 changes: 290 additions & 0 deletions tf_slim/data/tfexample_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
import six
from tf_slim.data import data_decoder
# pylint:disable=g-direct-tensorflow-import
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import map_fn
Expand Down Expand Up @@ -457,6 +460,200 @@ def decode_raw():
return image


class BoundingBoxSequence(ItemHandler):
  """An ItemHandler that concatenates SparseTensors to Bounding Boxes."""

  def __init__(self, keys=None, prefix=None, return_dense=True,
               default_value=-1.0):
    """Initialize the bounding box handler.

    Args:
      keys: A list of four key names representing the ymin, xmin, ymax, xmax
        in the Example or SequenceExample. Defaults to
        `['ymin', 'xmin', 'ymax', 'xmax']` when `None`.
      prefix: An optional prefix for each of the bounding box keys in the
        Example or SequenceExample. If provided, `prefix` is prepended to each
        key in `keys`; if `None`, the keys are used unchanged.
      return_dense: if True, returns a dense tensor; if False, returns as
        sparse tensor.
      default_value: The fill value used for empty positions when the sparse
        result is converted to a dense tensor.

    Raises:
      ValueError: if keys is not `None` and also not a list of exactly 4 keys
    """
    if keys is None:
      keys = ['ymin', 'xmin', 'ymax', 'xmax']
    elif len(keys) != 4:
      raise ValueError('BoundingBoxSequence expects 4 keys but got {}'.format(
          len(keys)))
    self._prefix = prefix
    self._keys = keys
    # `prefix` defaults to None; concatenating None with a str would raise
    # TypeError, so treat a missing prefix as the empty string.
    key_prefix = '' if prefix is None else prefix
    self._full_keys = [key_prefix + k for k in keys]
    self._return_dense = return_dense
    self._default_value = default_value
    super(BoundingBoxSequence, self).__init__(self._full_keys)

  def tensors_to_item(self, keys_to_tensors):
    """Maps the given dictionary of tensors to a concatenated list of bboxes.

    Args:
      keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

    Returns:
      [time, num_boxes, 4] tensor of bounding box coordinates, in order
        [y_min, x_min, y_max, x_max]. Whether the tensor is a SparseTensor
        or a dense Tensor is determined by the return_dense parameter. When
        dense, empty positions are filled with `default_value`.
    """
    sides = []
    for key in self._full_keys:
      value = keys_to_tensors[key]
      # Append a trailing unit dimension so the four sides can be concatenated
      # along a new last axis.
      expanded_dims = array_ops.concat(
          [math_ops.cast(array_ops.shape(value), dtypes.int64),
           constant_op.constant([1], dtype=dtypes.int64)], 0)
      side = sparse_ops.sparse_reshape(value, expanded_dims)
      sides.append(side)
    bounding_boxes = sparse_ops.sparse_concat(2, sides)
    if self._return_dense:
      bounding_boxes = sparse_ops.sparse_tensor_to_dense(
          bounding_boxes, default_value=self._default_value)
    return bounding_boxes


class NumBoxesSequence(ItemHandler):
  """An ItemHandler that returns num_boxes per frame for a box sequence.

  `num_boxes` is inferred from a 2D SparseTensor decoded from a field in the
  SequenceExample. The SparseTensor is partially dense and only ragged along
  its second dimension.

  The output is a tf.Tensor of shape [time], which is solely determined by the
  tensor of the first key. However, if `check_consistency` is True, this
  handler also asserts that `num_boxes` agrees across all keys.
  """

  def __init__(self, keys=None, check_consistency=True):
    """Initialization.

    Args:
      keys: A non-empty list of keys of sparse tensors which have exactly 2
        dimensions, with the 1st being the `time` and the 2nd the `boxes` per
        frame.
      check_consistency: if True, check for consistency.

    Raises:
      ValueError: If keys is empty.
    """
    if not keys:
      raise ValueError('keys must not be empty.')
    self._check_consistency = check_consistency
    super(NumBoxesSequence, self).__init__(keys)

  def tensors_to_item(self, keys_to_tensors):
    """Maps the given dictionary of tensors to a num_boxes tensor.

    If check_consistency is True, a runtime assertion is attached that fails
    in TensorFlow when the per-frame counts differ across the keyed tensors.

    Args:
      keys_to_tensors: A mapping of TF-Example keys to parsed tensors.

    Returns:
      [time] tf.Tensor containing the number of boxes per frame.

    Raises:
      ValueError: If any of the keyed tensors is not a SparseTensor. The
        tensors are assumed to be 2-dimensional; rank is not checked
        explicitly here.
    """
    def _count_boxes(sparse_input):
      """Derives the per-row box count from a single 2D SparseTensor."""
      if not isinstance(sparse_input, sparse_tensor.SparseTensor):
        raise ValueError('tensor must be of type tf.SparseTensor.')
      entry_indices = sparse_input.indices
      # Store each entry's own column index as its value, so that densifying
      # exposes the largest occupied column per row (missing entries -> -1).
      column_ids = sparse_tensor.SparseTensor(
          indices=entry_indices,
          values=entry_indices[:, 1],
          dense_shape=sparse_input.dense_shape)
      dense_column_ids = sparse_ops.sparse_tensor_to_dense(
          column_ids, default_value=-1)
      # In the event that the parsed tensor is empty (perhaps due to a
      # negative example), pad with a zero column so the resulting count is 0.
      padded_counts = array_ops.pad(dense_column_ids + 1, [[0, 0], [0, 1]])
      return math_ops.reduce_max(padded_counts, axis=1)

    reference_key = self._keys[0]
    remaining_keys = self._keys[1:]
    reference_counts = _count_boxes(keys_to_tensors[reference_key])

    consistency_checks = []
    if self._check_consistency:
      for other_key in remaining_keys:
        other_counts = _count_boxes(keys_to_tensors[other_key])
        consistency_checks.append(
            check_ops.assert_equal(reference_counts, other_counts))

    with ops.control_dependencies(consistency_checks):
      return array_ops.identity(reference_counts)


class KeypointsSequence(ItemHandler):
  """An ItemHandler that concatenates SparseTensors to Keypoints."""

  def __init__(self, keys=None, prefix=None, return_dense=True,
               default_value=-1.0):
    """Initialize the keypoints handler.

    Args:
      keys: A list of two key names representing the y and x coordinates in the
        Example or SequenceExample. Defaults to `['y', 'x']` when `None`.
      prefix: An optional prefix for each of the keypoint keys in the Example
        or SequenceExample. If provided, `prefix` is prepended to each key in
        `keys`; if `None`, the keys are used unchanged.
      return_dense: if True, returns a dense tensor; if False, returns as
        sparse tensor.
      default_value: The fill value used for empty positions when the sparse
        result is converted to a dense tensor.

    Raises:
      ValueError: if keys is not `None` and also not a list of exactly 2 keys
    """
    if keys is None:
      keys = ['y', 'x']
    elif len(keys) != 2:
      raise ValueError('KeypointsSequence expects 2 keys but got {}'.format(
          len(keys)))
    self._prefix = prefix
    self._keys = keys
    # `prefix` defaults to None; concatenating None with a str would raise
    # TypeError, so treat a missing prefix as the empty string.
    key_prefix = '' if prefix is None else prefix
    self._full_keys = [key_prefix + k for k in keys]
    self._return_dense = return_dense
    self._default_value = default_value
    super(KeypointsSequence, self).__init__(self._full_keys)

  def tensors_to_item(self, keys_to_tensors):
    """Maps the given dictionary of tensors to a concatenated list of keypoints.

    Args:
      keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

    Returns:
      [time, num_keypoints, 2] tensor of keypoint coordinates, in order [y, x].
        Whether the tensor is a SparseTensor or a dense Tensor is determined
        by the return_dense parameter. When dense, empty positions are filled
        with `default_value`.
    """
    coordinates = []
    for key in self._full_keys:
      value = keys_to_tensors[key]
      # Append a trailing unit dimension so the coordinates can be
      # concatenated along a new last axis.
      expanded_dims = array_ops.concat(
          [math_ops.cast(array_ops.shape(value), dtypes.int64),
           constant_op.constant([1], dtype=dtypes.int64)], 0)
      coordinate = sparse_ops.sparse_reshape(value, expanded_dims)
      coordinates.append(coordinate)
    keypoints = sparse_ops.sparse_concat(2, coordinates)
    if self._return_dense:
      keypoints = sparse_ops.sparse_tensor_to_dense(
          keypoints, default_value=self._default_value)
    return keypoints


class TFExampleDecoder(data_decoder.DataDecoder):
"""A decoder for TensorFlow Examples.
Expand Down Expand Up @@ -524,3 +721,96 @@ def decode(self, serialized_example, items=None):
keys_to_tensors = {key: example[key] for key in handler.keys}
outputs.append(handler.tensors_to_item(keys_to_tensors))
return outputs


class TFSequenceExampleDecoder(data_decoder.DataDecoder):
  """A decoder for TensorFlow SequenceExamples.

  Decoding SequenceExample proto buffers is comprised of two stages:
  (1) Example parsing and (2) tensor manipulation.

  In the first stage, the tf.parse_single_sequence_example function is called
  with the context and sequence feature specifications (for example
  FixedLenFeature and VarLenFeature instances). These instances tell TF how to
  parse the example. The output of this stage is a set of tensors.

  In the second stage, the resulting tensors are manipulated to provide the
  requested 'item' tensors.

  To perform this decoding operation, a SequenceExampleDecoder is given a list
  of ItemHandlers. Each ItemHandler indicates the set of features for stage 1
  and contains the instructions for post_processing its tensors for stage 2.
  """

  def __init__(self, keys_to_context_features, keys_to_sequence_features,
               items_to_handlers):
    """Constructs the decoder.

    Args:
      keys_to_context_features: a dictionary from TF-SequenceExample context
        keys to either tf.VarLenFeature or tf.FixedLenFeature instances.
        See tensorflow's parsing_ops.py.
      keys_to_sequence_features: a dictionary from TF-SequenceExample sequence
        keys to either tf.VarLenFeature or tf.FixedLenSequenceFeature instances.
        See tensorflow's parsing_ops.py.
      items_to_handlers: a dictionary from items (strings) to ItemHandler
        instances. Note that the ItemHandler's are provided the keys that they
        use to return the final item Tensors.

    Raises:
      ValueError: if the same key is present for context features and sequence
        features.
    """
    shared_keys = set(keys_to_context_features) & set(
        keys_to_sequence_features)
    if shared_keys:
      # This situation is ambiguous in the decoder's keys_to_tensors variable.
      raise ValueError('Context and sequence keys are not unique. \n'
                       ' Context keys: %s \n Sequence keys: %s' %
                       (list(keys_to_context_features.keys()),
                        list(keys_to_sequence_features.keys())))

    self._keys_to_context_features = keys_to_context_features
    self._keys_to_sequence_features = keys_to_sequence_features
    self._items_to_handlers = items_to_handlers

  def list_items(self):
    """Returns the names of the items this decoder can produce, as a list."""
    # Materialize the keys so callers get a real list on Python 3 as well as
    # Python 2, rather than a live dict view.
    return list(self._items_to_handlers.keys())

  def decode(self, serialized_example, items=None):
    """Decodes the given serialized TF-SequenceExample.

    Args:
      serialized_example: a serialized TF-SequenceExample tensor.
      items: the list of items to decode. These must be a subset of the item
        keys in self._items_to_handlers. If `items` is left as None (or is
        empty), then all of the items in self._items_to_handlers are decoded.

    Returns:
      the decoded items, a list of tensors in the same order as `items`.
    """
    context, feature_list = parsing_ops.parse_single_sequence_example(
        serialized_example, self._keys_to_context_features,
        self._keys_to_sequence_features)

    # Reshape non-sparse elements just once:
    for key, feature in self._keys_to_context_features.items():
      if isinstance(feature, parsing_ops.FixedLenFeature):
        context[key] = array_ops.reshape(context[key], feature.shape)

    if not items:
      items = list(self._items_to_handlers.keys())

    outputs = []
    for item in items:
      handler = self._items_to_handlers[item]
      # Context and sequence key sets are disjoint (enforced in __init__), so
      # each handler key resolves to exactly one of the two parsed dicts.
      keys_to_tensors = {
          key: context[key] if key in context else feature_list[key]
          for key in handler.keys
      }
      outputs.append(handler.tensors_to_item(keys_to_tensors))
    return outputs
Loading

0 comments on commit b258885

Please sign in to comment.