Split inputs to save nested structures (#444)
NihalHarish authored Feb 13, 2021
1 parent c91c0dd commit f3b31a6
Showing 8 changed files with 103 additions and 38 deletions.
11 changes: 11 additions & 0 deletions smdebug/analysis/utils.py
@@ -1,4 +1,5 @@
# Standard Library
import re
from contextlib import contextmanager


@@ -21,6 +22,16 @@ def no_refresh(trials):
trial.dynamic_refresh = True


def _tensor_name_sorter(t_name):
# sorts t_names based on their numerical suffix
# currently used to sort internally named input
# and output tensors
if not bool(re.match(r".+_\d+", t_name)):
t_name = f"{t_name}_0"
t_name = t_name.split("_")[-1]
return int(t_name)


@contextmanager
def refresh(trials):
if isinstance(trials, list):
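Note (illustrative, not part of this commit): _tensor_name_sorter orders internally named tensors by their numeric suffix rather than lexicographically; names without a suffix are treated as "_0". A minimal sketch:

    >>> from smdebug.analysis.utils import _tensor_name_sorter
    >>> sorted(["inputs_2", "inputs_10", "inputs_1"], key=_tensor_name_sorter)
    ['inputs_1', 'inputs_2', 'inputs_10']
    >>> _tensor_name_sorter("inputs")  # no numeric suffix, treated as "inputs_0"
    0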
1 change: 1 addition & 0 deletions smdebug/core/tfevent/event_file_reader.py
@@ -34,6 +34,7 @@ def as_dtype(t):
types_pb2.DT_HALF: np.float16,
types_pb2.DT_FLOAT: np.float32,
types_pb2.DT_DOUBLE: np.float64,
types_pb2.DT_INT8: np.uint8,
types_pb2.DT_INT32: np.int32,
types_pb2.DT_INT64: np.int64,
types_pb2.DT_STRING: np.str,
2 changes: 1 addition & 1 deletion smdebug/core/tfevent/util.py
@@ -13,7 +13,7 @@
# hash value of ndarray.dtype is not the same as np.float class
# so we need to convert the type classes below to np.dtype object
_NP_DATATYPE_TO_PROTO_DATATYPE = {
np.dtype(np.float16): "DT_INT32",
np.dtype(np.float16): "DT_HALF",
np.dtype(np.float32): "DT_FLOAT",
np.dtype(np.float64): "DT_DOUBLE",
np.dtype(np.int32): "DT_INT32",
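Note (illustrative, not part of this commit): the old entry mapped np.float16 arrays to "DT_INT32", so half-precision tensors were recorded with the wrong proto dtype; the fix maps them to "DT_HALF". A minimal sketch, assuming access to the module-private mapping:

    import numpy as np
    from smdebug.core.tfevent.util import _NP_DATATYPE_TO_PROTO_DATATYPE

    arr = np.zeros(4, dtype=np.float16)
    _NP_DATATYPE_TO_PROTO_DATATYPE[arr.dtype]  # "DT_HALF" after this change (was "DT_INT32")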
67 changes: 43 additions & 24 deletions smdebug/tensorflow/keras.py
@@ -539,22 +539,34 @@ def save_gradients_from_logs(self, gradients):
g = g.values
self._save_tensor_to_file(export_name, g, collections_to_write)

def _save_model_inputs_and_outputs_helper(self, collection_key, tensors_to_save, prefix):
collections_to_write = (
{self.get_collection(collection_key)}
if self._is_collection_being_saved_for_step(collection_key)
else set()
)
if isinstance(tensors_to_save, (dict, list)):
tensors_to_save = nest.flatten(tensors_to_save)
for idx, t_value in enumerate(tensors_to_save):
t_name = f"{prefix}_{idx}"
self._save_tensor_to_file(t_name, t_value, collections_to_write)
else:
self._save_tensor_to_file(prefix, tensors_to_save, collections_to_write)

def save_smdebug_logs(self, logs):
if logs is None:
return

for key in logs:
tensors_to_save = []
collections_to_write = set()
if SMDEBUG_PREFIX in key:
# Save Model Outputs
if key in ModelOutputs:
export_name = get_model_output_export_name(key)
tensors_to_save.append((export_name, logs[key]))
collections_to_write = (
{self.get_collection(CollectionKeys.OUTPUTS)}
if self._is_collection_being_saved_for_step(CollectionKeys.OUTPUTS)
else set()
if key == ModelOutput.LABELS:
self._save_model_inputs_and_outputs_helper(
CollectionKeys.OUTPUTS, logs[key], prefix="labels"
)
elif key == ModelOutput.PREDICTIONS:
self._save_model_inputs_and_outputs_helper(
CollectionKeys.OUTPUTS, logs[key], prefix="predictions"
)
# Save Gradients
elif key == SMDEBUG_GRADIENTS_KEY:
@@ -564,19 +576,9 @@ def save_smdebug_logs(self, logs):
self._save_layer_values(logs[key])
# Save Model Inputs
elif key in ModelInputs:
export_name = get_model_input_export_name()
tensors_to_save.append((export_name, logs[key]))
collections_to_write = (
{self.get_collection(CollectionKeys.INPUTS)}
if self._is_collection_being_saved_for_step(CollectionKeys.INPUTS)
else set()
self._save_model_inputs_and_outputs_helper(
CollectionKeys.INPUTS, logs[key], prefix="inputs"
)
for t_name, t_value in tensors_to_save:
if isinstance(t_value, dict):
# flatten the inputs and labels
# since we cannot convert dicts into numpy
t_value = nest.flatten(t_value)
self._save_tensor_to_file(t_name, t_value, collections_to_write)

def _save_metrics(self, batch, logs, force_save=False):
# if force_save is True, doesn't check whether collection needs to be saved for steps
@@ -948,14 +950,31 @@ def _save_layer_values(self, logs):
layer_name = layer_name.name
elif isinstance(layer_name, bytes):
layer_name = str(layer_name, "utf-8")
layer_input_tensor_name = get_export_name_for_keras(str(layer_name), "input")
if len(layer_input) == 1:
# Layer Inputs are flattened and passed as a list into
# the next layer. Unpacking it speeds up the _make_numpy fn.
layer_input = layer_input[0]
layer_input_tensor_name = get_export_name_for_keras(str(layer_name), "input")
self._save_tensor_to_file(layer_input_tensor_name, layer_input, collections_to_write)
self._save_tensor_to_file(
layer_input_tensor_name, layer_input, collections_to_write
)
else:
for idx, l_name in enumerate(layer_input):
layer_input_tensor_name_with_idx = f"{layer_input_tensor_name}_{idx}"
self._save_tensor_to_file(
layer_input_tensor_name_with_idx, l_name, collections_to_write
)
layer_output_tensor_name = get_export_name_for_keras(str(layer_name), "output")
self._save_tensor_to_file(layer_output_tensor_name, layer_output, collections_to_write)
if isinstance(layer_output, list):
for idx, l_output in enumerate(layer_output):
layer_output_tensor_name_with_idx = f"{layer_output_tensor_name}_{idx}"
self._save_tensor_to_file(
layer_output_tensor_name_with_idx, l_output, collections_to_write
)
else:
self._save_tensor_to_file(
layer_output_tensor_name, layer_output, collections_to_write
)

def _write_optimizer_variables(self):
optimizer_collections = self.collection_manager.get(CollectionKeys.OPTIMIZER_VARIABLES)
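Note (illustrative, not part of this commit): the new _save_model_inputs_and_outputs_helper flattens nested dict/list inputs with nest.flatten and saves each leaf tensor under an indexed name such as inputs_0, inputs_1. A rough sketch of that naming, using the public tf.nest API and a hypothetical input dict:

    import tensorflow as tf

    model_inputs = {"image": tf.zeros((16, 28, 28)), "mask": tf.zeros((16, 28))}  # hypothetical
    flat = tf.nest.flatten(model_inputs)                  # dict leaves, ordered by key
    names = [f"inputs_{idx}" for idx in range(len(flat))]
    # names == ["inputs_0", "inputs_1"]; each flattened tensor is saved to the INPUTS collection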
43 changes: 40 additions & 3 deletions smdebug/trials/trial.py
@@ -5,9 +5,9 @@
from bisect import bisect_left

# First Party
from smdebug.analysis.utils import refresh
from smdebug.analysis.utils import _tensor_name_sorter, refresh
from smdebug.core.access_layer.utils import has_training_ended
from smdebug.core.collection import Collection
from smdebug.core.collection import Collection, CollectionKeys
from smdebug.core.config_constants import (
INCOMPLETE_STEP_WAIT_WINDOW_DEFAULT,
INCOMPLETE_STEP_WAIT_WINDOW_KEY,
@@ -364,7 +364,44 @@ def _tensors_in_collection(self, collection) -> set:
rval.update(self._tensors_matching_regex(regex))
return rval

def tensor_names(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collection=None) -> list:
def inputs(self, step, mode=ModeKeys.GLOBAL):
input_tensors_names = sorted(
self.tensor_names(
show_prefixed_tensors=True, step=step, mode=mode, collection=CollectionKeys.INPUTS
),
key=_tensor_name_sorter,
)
input_tensors = [
self.tensor(tensor_name).value(step) for tensor_name in input_tensors_names
]
return input_tensors

def _get_output_tensors_helper(self, step, mode, regex):
output_tensors_names = sorted(
self.tensor_names(show_prefixed_tensors=True, step=step, mode=mode, regex=regex),
key=_tensor_name_sorter,
)
output_tensors = [
self.tensor(tensor_name).value(step) for tensor_name in output_tensors_names
]
return output_tensors

def labels(self, step, mode=ModeKeys.GLOBAL):
return self._get_output_tensors_helper(step, mode, regex="labels*")

def predictions(self, step, mode=ModeKeys.GLOBAL):
return self._get_output_tensors_helper(step, mode, regex="predictions*")

# * is used in python to force usage of named arguments
def tensor_names(
self,
show_prefixed_tensors=False,
*,
step=None,
mode=ModeKeys.GLOBAL,
regex=None,
collection=None,
) -> list:
self.maybe_refresh()
ts = set()
if step is None and mode == ModeKeys.GLOBAL:
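Note (illustrative, not part of this commit): a hedged usage sketch of the new Trial helpers, with a placeholder output directory and step:

    from smdebug.trials import create_trial

    trial = create_trial("/path/to/out_dir")   # placeholder path
    xs = trial.inputs(step=0)                  # model inputs, ordered by numeric suffix
    ys = trial.labels(step=0)                  # tensors saved under "labels*" names
    preds = trial.predictions(step=0)          # tensors saved under "predictions*" names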
11 changes: 4 additions & 7 deletions tests/tensorflow2/test_keras.py
@@ -772,7 +772,6 @@ def test_keras_fit_pure_eager(out_dir, tf_eager_mode):
assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == (2 if is_tf_2_2() else 0)


@pytest.mark.skip # skip until aws tf update
def test_model_inputs_and_outputs(out_dir, tf_eager_mode):
# explicitly save INPUTS and OUTPUTS
include_collections = [CollectionKeys.INPUTS, CollectionKeys.OUTPUTS]
@@ -789,13 +788,11 @@ def test_model_inputs_and_outputs(out_dir, tf_eager_mode):
assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == 1

for tname in trial.tensor_names(collection=CollectionKeys.OUTPUTS):
output = trial.tensor(tname)
assert tname in ["y", "y_pred"]
assert output.value(0) is not None
# Check the shape of output tensors
assert trial.tensor("y").value(0).shape[1] == 1 # label
assert trial.tensor("y_pred").value(0).shape[1] == 10 # Output probability for each class
assert trial.labels(step=0)[0].shape == (6000, 1)
assert trial.predictions(step=0)[0].shape == (6000, 10)
# Check the shape of input tensors
assert trial.inputs(step=0)[0].shape == (6000, 28, 28)


@pytest.mark.skip # skip until aws tf update
2 changes: 1 addition & 1 deletion tests/tensorflow2/test_model_subclassing.py
@@ -82,7 +82,7 @@ def test_subclassed_model(out_dir):
assert trial.tensor_names(collection=smd.CollectionKeys.LOSSES) == ["loss"]
if is_tf_2_2():
# Feature to save model inputs and outputs was first added for TF 2.2.0
assert trial.tensor_names(collection=smd.CollectionKeys.INPUTS) == ["model_input"]
assert trial.tensor_names(collection=smd.CollectionKeys.INPUTS) == ["inputs"]
assert trial.tensor_names(collection=smd.CollectionKeys.OUTPUTS) == [
"labels",
"predictions",
4 changes: 2 additions & 2 deletions tests/tensorflow2/test_support_dicts.py
@@ -44,5 +44,5 @@ def test_support_dicts(out_dir):
model.fit(inputs, labels, batch_size=16, epochs=10, callbacks=[smdebug_hook])
model.save(out_dir, save_format="tf")
trial = create_trial(out_dir)
assert trial.tensor_names(collection=CollectionKeys.INPUTS) == ["model_input"]
assert trial.tensor_names(collection=CollectionKeys.OUTPUTS) == ["labels", "predictions"]
assert trial.tensor_names(collection=CollectionKeys.INPUTS) == ["inputs_0"]
assert trial.tensor_names(collection=CollectionKeys.OUTPUTS) == ["labels_0", "predictions"]
