diff --git a/tensorflow_graphics/nn/loss/chamfer_distance.py b/tensorflow_graphics/nn/loss/chamfer_distance.py index 61c733f1f..e22e40425 100644 --- a/tensorflow_graphics/nn/loss/chamfer_distance.py +++ b/tensorflow_graphics/nn/loss/chamfer_distance.py @@ -21,9 +21,12 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(point_set_a, point_set_b, name="chamfer_distance_evaluate"): +def evaluate(point_set_a: type_alias.TensorLike, + point_set_b: type_alias.TensorLike, + name: str = "chamfer_distance_evaluate") -> tf.Tensor: """Computes the Chamfer distance for the given two point sets. Note: diff --git a/tensorflow_graphics/nn/loss/hausdorff_distance.py b/tensorflow_graphics/nn/loss/hausdorff_distance.py index 46fb0946f..279f1bbbd 100644 --- a/tensorflow_graphics/nn/loss/hausdorff_distance.py +++ b/tensorflow_graphics/nn/loss/hausdorff_distance.py @@ -21,9 +21,12 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(point_set_a, point_set_b, name="hausdorff_distance_evaluate"): +def evaluate(point_set_a: type_alias.TensorLike, + point_set_b: type_alias.TensorLike, + name: str = "hausdorff_distance_evaluate") -> tf.Tensor: """Computes the Hausdorff distance from point_set_a to point_set_b. Note: diff --git a/tensorflow_graphics/nn/metric/fscore.py b/tensorflow_graphics/nn/metric/fscore.py index d979eee3e..8bb696c39 100644 --- a/tensorflow_graphics/nn/metric/fscore.py +++ b/tensorflow_graphics/nn/metric/fscore.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable + import tensorflow as tf from tensorflow_graphics.nn.metric import precision as precision_module @@ -24,13 +26,14 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(ground_truth, - prediction, - precision_function=precision_module.evaluate, - recall_function=recall_module.evaluate, - name="fscore_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + precision_function: Callable[..., Any] = precision_module.evaluate, + recall_function: Callable[..., Any] = recall_module.evaluate, + name: str = "fscore_evaluate") -> tf.Tensor: """Computes the fscore metric for the given ground truth and predicted labels. The fscore is calculated as 2 * (precision * recall) / (precision + recall) diff --git a/tensorflow_graphics/nn/metric/intersection_over_union.py b/tensorflow_graphics/nn/metric/intersection_over_union.py index 5b1896649..1fe718057 100644 --- a/tensorflow_graphics/nn/metric/intersection_over_union.py +++ b/tensorflow_graphics/nn/metric/intersection_over_union.py @@ -23,12 +23,13 @@ from tensorflow_graphics.util import asserts from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(ground_truth_labels, - predicted_labels, - grid_size=1, - name="intersection_over_union_evaluate"): +def evaluate(ground_truth_labels: type_alias.TensorLike, + predicted_labels: type_alias.TensorLike, + grid_size: int = 1, + name: str = "intersection_over_union_evaluate") -> tf.Tensor: """Computes the Intersection-Over-Union metric for the given ground truth and predicted labels. Note: diff --git a/tensorflow_graphics/nn/metric/precision.py b/tensorflow_graphics/nn/metric/precision.py index 913801456..814a82fb1 100644 --- a/tensorflow_graphics/nn/metric/precision.py +++ b/tensorflow_graphics/nn/metric/precision.py @@ -17,23 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, List, Optional, Union, Tuple + import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias def _cast_to_int(prediction): return tf.cast(x=prediction, dtype=tf.int32) -def evaluate(ground_truth, - prediction, - classes=None, - reduce_average=True, - prediction_to_category_function=_cast_to_int, - name="precision_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + classes: Optional[Union[int, List[int], Tuple[int]]] = None, + reduce_average: bool = True, + prediction_to_category_function: Callable[..., Any] = _cast_to_int, + name: str = "precision_evaluate") -> tf.Tensor: """Computes the precision metric for the given ground truth and predictions. Note: diff --git a/tensorflow_graphics/nn/metric/recall.py b/tensorflow_graphics/nn/metric/recall.py index 250f9e544..7e82be870 100644 --- a/tensorflow_graphics/nn/metric/recall.py +++ b/tensorflow_graphics/nn/metric/recall.py @@ -17,23 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, List, Optional, Tuple, Union + import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias def _cast_to_int(prediction): return tf.cast(x=prediction, dtype=tf.int32) -def evaluate(ground_truth, - prediction, - classes=None, - reduce_average=True, - prediction_to_category_function=_cast_to_int, - name="recall_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + classes: Optional[Union[int, List[int], Tuple[int]]] = None, + reduce_average: bool = True, + prediction_to_category_function: Callable[..., Any] = _cast_to_int, + name: str = "recall_evaluate") -> tf.Tensor: """Computes the recall metric for the given ground truth and predictions. Note: diff --git a/tensorflow_graphics/notebooks/mesh_segmentation_dataio.py b/tensorflow_graphics/notebooks/mesh_segmentation_dataio.py index 6d43d099e..ea07b3e26 100644 --- a/tensorflow_graphics/notebooks/mesh_segmentation_dataio.py +++ b/tensorflow_graphics/notebooks/mesh_segmentation_dataio.py @@ -24,12 +24,15 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, Dict, List, Tuple + import numpy as np import tensorflow as tf from tensorflow_graphics.geometry.convolution import utils as conv_utils from tensorflow_graphics.geometry.representation.mesh import utils as mesh_utils from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias DEFAULT_IO_PARAMS = { 'batch_size': 8, @@ -42,7 +45,11 @@ } -def adjacency_from_edges(edges, weights, num_edges, num_vertices): +def adjacency_from_edges( + edges: type_alias.TensorLike, + weights: type_alias.TensorLike, + num_edges: type_alias.TensorLike, + num_vertices: type_alias.TensorLike) -> tf.SparseTensor: """Returns a batched sparse 1-ring adj tensor from edge list tensor. Args: @@ -103,7 +110,9 @@ def adjacency_from_edges(edges, weights, num_edges, num_vertices): return adjacency -def get_weighted_edges(faces, self_edges=True): +def get_weighted_edges( + faces: np.ndarray, + self_edges: bool = True) -> Tuple[np.ndarray, np.ndarray]: r"""Gets unique edges and degree weights from a triangular mesh. The shorthands used below are: @@ -136,12 +145,12 @@ def get_weighted_edges(faces, self_edges=True): return edges, weights -def _tfrecords_to_dataset(tfrecords, - parallel_threads, - shuffle, - repeat, - sloppy, - max_readers=16): +def _tfrecords_to_dataset(tfrecords: List[str], + parallel_threads: int, + shuffle: bool, + repeat: bool, + sloppy: bool, + max_readers: int = 16) -> tf.data.TFRecordDataset: """Creates a TFRecordsDataset that iterates over filenames in parallel. Args: @@ -244,7 +253,9 @@ def _parse_mesh_data(mesh_data, mean_center=True): return mesh_data -def create_dataset_from_tfrecords(tfrecords, params): +def create_dataset_from_tfrecords( + tfrecords: List[str], + params: Dict[str, Any]) -> tf.data.Dataset: """Creates a mesh dataset given a list of tf records filenames. Args: @@ -309,7 +320,10 @@ def _set_default_if_none(param, param_dict, default_val): drop_remainder=is_training) -def create_input_from_dataset(dataset_fn, files, io_params): +def create_input_from_dataset( + dataset_fn: Callable[..., Any], + files: List[str], + io_params: Dict[str, Any]) -> Tuple[Dict[str, Any], tf.Tensor]: """Creates input function given dataset generator and input files. Args: diff --git a/tensorflow_graphics/notebooks/mesh_viewer.py b/tensorflow_graphics/notebooks/mesh_viewer.py index bf6779073..60dd39141 100644 --- a/tensorflow_graphics/notebooks/mesh_viewer.py +++ b/tensorflow_graphics/notebooks/mesh_viewer.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +from typing import Any, Dict + import numpy as np from tensorflow_graphics.notebooks import threejs_visualization @@ -32,7 +34,7 @@ class Viewer(object): """A ThreeJS based viewer class for viewing 3D meshes.""" - def _mesh_from_data(self, data): + def _mesh_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]: """Creates a dictionary of ThreeJS mesh objects from numpy data.""" if 'vertices' not in data or 'faces' not in data: raise ValueError('Mesh Data must contain vertices and faces') @@ -54,7 +56,7 @@ def _mesh_from_data(self, data): mesh['material'] = material return mesh - def __init__(self, source_mesh_data): + def __init__(self, source_mesh_data: Dict[str, Any]): context = threejs_visualization.build_context() self.context = context light1 = context.THREE.PointLight.new_object(0x808080)