Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 419892069
  • Loading branch information
G4G authored and copybara-github committed Jan 5, 2022
1 parent bab2352 commit a075635
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 35 deletions.
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="chamfer_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "chamfer_distance_evaluate") -> tf.Tensor:
"""Computes the Chamfer distance for the given two point sets.
Note:
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="hausdorff_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "hausdorff_distance_evaluate") -> tf.Tensor:
"""Computes the Hausdorff distance from point_set_a to point_set_b.
Note:
Expand Down
13 changes: 8 additions & 5 deletions tensorflow_graphics/nn/metric/fscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,23 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable

import tensorflow as tf

from tensorflow_graphics.nn.metric import precision as precision_module
from tensorflow_graphics.nn.metric import recall as recall_module
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth,
prediction,
precision_function=precision_module.evaluate,
recall_function=recall_module.evaluate,
name="fscore_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
precision_function: Callable[..., Any] = precision_module.evaluate,
recall_function: Callable[..., Any] = recall_module.evaluate,
name: str = "fscore_evaluate") -> tf.Tensor:
"""Computes the fscore metric for the given ground truth and predicted labels.
The fscore is calculated as 2 * (precision * recall) / (precision + recall)
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_graphics/nn/metric/intersection_over_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth_labels,
predicted_labels,
grid_size=1,
name="intersection_over_union_evaluate"):
def evaluate(ground_truth_labels: type_alias.TensorLike,
predicted_labels: type_alias.TensorLike,
grid_size: int = 1,
name: str = "intersection_over_union_evaluate") -> tf.Tensor:
"""Computes the Intersection-Over-Union metric for the given ground truth and predicted labels.
Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Union, Tuple

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction):
return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="precision_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "precision_evaluate") -> tf.Tensor:
"""Computes the precision metric for the given ground truth and predictions.
Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction):
return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="recall_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "recall_evaluate") -> tf.Tensor:
"""Computes the recall metric for the given ground truth and predictions.
Note:
Expand Down
34 changes: 24 additions & 10 deletions tensorflow_graphics/notebooks/mesh_segmentation_dataio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, Dict, List, Tuple

import numpy as np
import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils as conv_utils
from tensorflow_graphics.geometry.representation.mesh import utils as mesh_utils
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias

DEFAULT_IO_PARAMS = {
'batch_size': 8,
Expand All @@ -42,7 +45,11 @@
}


def adjacency_from_edges(edges, weights, num_edges, num_vertices):
def adjacency_from_edges(
edges: type_alias.TensorLike,
weights: type_alias.TensorLike,
num_edges: type_alias.TensorLike,
num_vertices: type_alias.TensorLike) -> tf.SparseTensor:
"""Returns a batched sparse 1-ring adj tensor from edge list tensor.
Args:
Expand Down Expand Up @@ -103,7 +110,9 @@ def adjacency_from_edges(edges, weights, num_edges, num_vertices):
return adjacency


def get_weighted_edges(faces, self_edges=True):
def get_weighted_edges(
faces: np.ndarray,
self_edges: bool = True) -> Tuple[np.ndarray, np.ndarray]:
r"""Gets unique edges and degree weights from a triangular mesh.
The shorthands used below are:
Expand Down Expand Up @@ -136,12 +145,12 @@ def get_weighted_edges(faces, self_edges=True):
return edges, weights


def _tfrecords_to_dataset(tfrecords,
parallel_threads,
shuffle,
repeat,
sloppy,
max_readers=16):
def _tfrecords_to_dataset(tfrecords: List[str],
parallel_threads: int,
shuffle: bool,
repeat: bool,
sloppy: bool,
max_readers: int = 16) -> tf.data.TFRecordDataset:
"""Creates a TFRecordsDataset that iterates over filenames in parallel.
Args:
Expand Down Expand Up @@ -244,7 +253,9 @@ def _parse_mesh_data(mesh_data, mean_center=True):
return mesh_data


def create_dataset_from_tfrecords(tfrecords, params):
def create_dataset_from_tfrecords(
tfrecords: List[str],
params: Dict[str, Any]) -> tf.data.Dataset:
"""Creates a mesh dataset given a list of tf records filenames.
Args:
Expand Down Expand Up @@ -309,7 +320,10 @@ def _set_default_if_none(param, param_dict, default_val):
drop_remainder=is_training)


def create_input_from_dataset(dataset_fn, files, io_params):
def create_input_from_dataset(
dataset_fn: Callable[..., Any],
files: List[str],
io_params: Dict[str, Any]) -> Tuple[Dict[str, Any], tf.Tensor]:
"""Creates input function given dataset generator and input files.
Args:
Expand Down
6 changes: 4 additions & 2 deletions tensorflow_graphics/notebooks/mesh_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Dict

import numpy as np
from tensorflow_graphics.notebooks import threejs_visualization

Expand All @@ -32,7 +34,7 @@
class Viewer(object):
"""A ThreeJS based viewer class for viewing 3D meshes."""

def _mesh_from_data(self, data):
def _mesh_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Creates a dictionary of ThreeJS mesh objects from numpy data."""
if 'vertices' not in data or 'faces' not in data:
raise ValueError('Mesh Data must contain vertices and faces')
Expand All @@ -54,7 +56,7 @@ def _mesh_from_data(self, data):
mesh['material'] = material
return mesh

def __init__(self, source_mesh_data):
def __init__(self, source_mesh_data: Dict[str, Any]):
context = threejs_visualization.build_context()
self.context = context
light1 = context.THREE.PointLight.new_object(0x808080)
Expand Down

0 comments on commit a075635

Please sign in to comment.