Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change. #676

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="chamfer_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "chamfer_distance_evaluate") -> tf.Tensor:
"""Computes the Chamfer distance for the given two point sets.

Note:
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="hausdorff_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "hausdorff_distance_evaluate") -> tf.Tensor:
"""Computes the Hausdorff distance from point_set_a to point_set_b.

Note:
Expand Down
13 changes: 8 additions & 5 deletions tensorflow_graphics/nn/metric/fscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,23 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable

import tensorflow as tf

from tensorflow_graphics.nn.metric import precision as precision_module
from tensorflow_graphics.nn.metric import recall as recall_module
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth,
prediction,
precision_function=precision_module.evaluate,
recall_function=recall_module.evaluate,
name="fscore_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
precision_function: Callable[..., Any] = precision_module.evaluate,
recall_function: Callable[..., Any] = recall_module.evaluate,
name: str = "fscore_evaluate") -> tf.Tensor:
"""Computes the fscore metric for the given ground truth and predicted labels.

The fscore is calculated as 2 * (precision * recall) / (precision + recall)
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_graphics/nn/metric/intersection_over_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth_labels,
predicted_labels,
grid_size=1,
name="intersection_over_union_evaluate"):
def evaluate(ground_truth_labels: type_alias.TensorLike,
predicted_labels: type_alias.TensorLike,
grid_size: int = 1,
name: str = "intersection_over_union_evaluate") -> tf.Tensor:
"""Computes the Intersection-Over-Union metric for the given ground truth and predicted labels.

Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Union, Tuple

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction):
return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="precision_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "precision_evaluate") -> tf.Tensor:
"""Computes the precision metric for the given ground truth and predictions.

Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction):
return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="recall_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "recall_evaluate") -> tf.Tensor:
"""Computes the recall metric for the given ground truth and predictions.

Note:
Expand Down
34 changes: 24 additions & 10 deletions tensorflow_graphics/notebooks/mesh_segmentation_dataio.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, Dict, List, Tuple

import numpy as np
import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils as conv_utils
from tensorflow_graphics.geometry.representation.mesh import utils as mesh_utils
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias

DEFAULT_IO_PARAMS = {
'batch_size': 8,
Expand All @@ -42,7 +45,11 @@
}


def adjacency_from_edges(edges, weights, num_edges, num_vertices):
def adjacency_from_edges(
edges: type_alias.TensorLike,
weights: type_alias.TensorLike,
num_edges: type_alias.TensorLike,
num_vertices: type_alias.TensorLike) -> tf.SparseTensor:
"""Returns a batched sparse 1-ring adj tensor from edge list tensor.

Args:
Expand Down Expand Up @@ -103,7 +110,9 @@ def adjacency_from_edges(edges, weights, num_edges, num_vertices):
return adjacency


def get_weighted_edges(faces, self_edges=True):
def get_weighted_edges(
faces: np.ndarray,
self_edges: bool = True) -> Tuple[np.ndarray, np.ndarray]:
r"""Gets unique edges and degree weights from a triangular mesh.

The shorthands used below are:
Expand Down Expand Up @@ -136,12 +145,12 @@ def get_weighted_edges(faces, self_edges=True):
return edges, weights


def _tfrecords_to_dataset(tfrecords,
parallel_threads,
shuffle,
repeat,
sloppy,
max_readers=16):
def _tfrecords_to_dataset(tfrecords: List[str],
parallel_threads: int,
shuffle: bool,
repeat: bool,
sloppy: bool,
max_readers: int = 16) -> tf.data.TFRecordDataset:
"""Creates a TFRecordsDataset that iterates over filenames in parallel.

Args:
Expand Down Expand Up @@ -244,7 +253,9 @@ def _parse_mesh_data(mesh_data, mean_center=True):
return mesh_data


def create_dataset_from_tfrecords(tfrecords, params):
def create_dataset_from_tfrecords(
tfrecords: List[str],
params: Dict[str, Any]) -> tf.data.Dataset:
"""Creates a mesh dataset given a list of tf records filenames.

Args:
Expand Down Expand Up @@ -309,7 +320,10 @@ def _set_default_if_none(param, param_dict, default_val):
drop_remainder=is_training)


def create_input_from_dataset(dataset_fn, files, io_params):
def create_input_from_dataset(
dataset_fn: Callable[..., Any],
files: List[str],
io_params: Dict[str, Any]) -> Tuple[Dict[str, Any], tf.Tensor]:
"""Creates input function given dataset generator and input files.

Args:
Expand Down
6 changes: 4 additions & 2 deletions tensorflow_graphics/notebooks/mesh_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Dict

import numpy as np
from tensorflow_graphics.notebooks import threejs_visualization

Expand All @@ -32,7 +34,7 @@
class Viewer(object):
"""A ThreeJS based viewer class for viewing 3D meshes."""

def _mesh_from_data(self, data):
def _mesh_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Creates a dictionary of ThreeJS mesh objects from numpy data."""
if 'vertices' not in data or 'faces' not in data:
raise ValueError('Mesh Data must contain vertices and faces')
Expand All @@ -54,7 +56,7 @@ def _mesh_from_data(self, data):
mesh['material'] = material
return mesh

def __init__(self, source_mesh_data):
def __init__(self, source_mesh_data: Dict[str, Any]):
context = threejs_visualization.build_context()
self.context = context
light1 = context.THREE.PointLight.new_object(0x808080)
Expand Down