
Framework (contrib)

[TOC]

Framework utilities.


tf.contrib.framework.assert_same_float_dtype(tensors=None, dtype=None) {#assert_same_float_dtype}

Validate and return float type based on tensors and dtype.

For ops such as matrix multiplication, inputs and weights must be of the same float type. This function validates that all tensors have the same type, validates that the type is dtype (if supplied), and returns the type. The type must be dtypes.float32 or dtypes.float64. If neither tensors nor dtype is supplied, the function defaults to dtypes.float32.

Args:
  • tensors: Tensors of input values. Can include None elements, which will be ignored.
  • dtype: Expected type.
Returns:

Validated type.

Raises:
  • ValueError: if neither tensors nor dtype is supplied, or result is not float.
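
A minimal usage sketch, assuming a TF 1.x-style graph (the placeholders are illustrative):

  import tensorflow as tf

  x = tf.placeholder(tf.float32, shape=[None, 3])
  w = tf.placeholder(tf.float32, shape=[3, 2])
  # Validates that both operands share a float dtype and returns that dtype.
  dtype = tf.contrib.framework.assert_same_float_dtype([x, w])
  y = tf.matmul(x, w)  # safe: both operands are float32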

tf.contrib.framework.assert_scalar_int(tensor, name=None) {#assert_scalar_int}

Assert tensor is 0-D, of type tf.int32 or tf.int64.

Args:
  • tensor: Tensor to test.
  • name: Name of the op and of the new Tensor if one is created.
Returns:

tensor, for chaining.

Raises:
  • ValueError: if tensor is not 0-D or is not of type tf.int32 or tf.int64.

tf.convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None) {#convert_to_tensor_or_sparse_tensor}

Converts value to a SparseTensor or Tensor.

Args:
  • value: A SparseTensor, SparseTensorValue, or an object whose type has a registered Tensor conversion function.
  • dtype: Optional element type for the returned tensor. If missing, the type is inferred from the type of value.
  • name: Optional name to use if a new Tensor is created.
Returns:

A SparseTensor or Tensor based on value.

Raises:
  • RuntimeError: If result type is incompatible with dtype.
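
A short sketch of the dispatch behavior, assuming TF 1.x names (where SparseTensor takes dense_shape):

  import tensorflow as tf

  # A plain Python value goes through the ordinary Tensor conversion path.
  dense = tf.convert_to_tensor_or_sparse_tensor([[1, 2], [3, 4]])  # -> Tensor

  # A SparseTensor (or SparseTensorValue) passes through as a SparseTensor.
  sp = tf.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[2, 2])
  sparse = tf.convert_to_tensor_or_sparse_tensor(sp)  # -> SparseTensor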

tf.contrib.framework.get_graph_from_inputs(op_input_list, graph=None) {#get_graph_from_inputs}

Returns the appropriate graph to use for the given inputs.

  1. If graph is provided, we validate that all inputs in op_input_list are from the same graph.
  2. Otherwise, we attempt to select a graph from the first Operation- or Tensor-valued input in op_input_list, and validate that all other such inputs are in the same graph.
  3. If the graph was not specified and it could not be inferred from op_input_list, we attempt to use the default graph.
Args:
  • op_input_list: A list of inputs to an operation, which may include Tensor, Operation, and other objects that may be converted to a graph element.
  • graph: (Optional) The explicit graph to use.
Raises:
  • TypeError: If op_input_list is not a list or tuple, or if graph is not a Graph.
  • ValueError: If a graph is explicitly passed and not all inputs are from it, or if the inputs are from multiple graphs, or we could not find a graph and there was no default graph.
Returns:

The appropriate graph to use for the given inputs.


tf.is_numeric_tensor(tensor) {#is_numeric_tensor}


tf.is_non_decreasing(x, name=None) {#is_non_decreasing}

Returns True if x is non-decreasing.

Elements of x are compared in row-major order. The tensor [x[0],...] is non-decreasing if for every adjacent pair we have x[i] <= x[i+1]. If x has fewer than two elements, it is trivially non-decreasing.

See also: is_strictly_increasing

Args:
  • x: Numeric Tensor.
  • name: A name for this operation (optional). Defaults to "is_non_decreasing".
Returns:

Boolean Tensor, equal to True iff x is non-decreasing.

Raises:
  • TypeError: if x is not a numeric tensor.
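
For example (a sketch using a TF 1.x session; the values are illustrative):

  import tensorflow as tf

  x = tf.constant([1, 1, 2, 3])
  check = tf.is_non_decreasing(x)  # scalar boolean Tensor
  with tf.Session() as sess:
    print(sess.run(check))  # True: every adjacent pair satisfies x[i] <= x[i+1]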

tf.is_strictly_increasing(x, name=None) {#is_strictly_increasing}

Returns True if x is strictly increasing.

Elements of x are compared in row-major order. The tensor [x[0],...] is strictly increasing if for every adjacent pair we have x[i] < x[i+1]. If x has fewer than two elements, it is trivially strictly increasing.

See also: is_non_decreasing

Args:
  • x: Numeric Tensor.
  • name: A name for this operation (optional). Defaults to "is_strictly_increasing".
Returns:

Boolean Tensor, equal to True iff x is strictly increasing.

Raises:
  • TypeError: if x is not a numeric tensor.

tf.contrib.framework.is_tensor(x) {#is_tensor}

Check for tensor types.

Check whether an object is a tensor. Equivalent to isinstance(x, (tf.Tensor, tf.SparseTensor, tf.Variable)).

Args:
  • x: A Python object to check.
Returns:

True if x is a tensor, False if not.
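
A quick sketch of what counts as a tensor here (illustrative values):

  import tensorflow as tf

  tf.contrib.framework.is_tensor(tf.constant([1, 2]))  # True: a tf.Tensor
  tf.contrib.framework.is_tensor(tf.Variable(0))       # True: a tf.Variable
  tf.contrib.framework.is_tensor([1, 2])               # False: a plain Python list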


tf.contrib.framework.reduce_sum_n(tensors, name=None) {#reduce_sum_n}

Reduce tensors to a scalar sum.

This reduces each tensor in tensors to a scalar via tf.reduce_sum, then adds them via tf.add_n.

Args:
  • tensors: List of tensors, all of the same numeric type.
  • name: Tensor name, and scope for all other ops.
Returns:

Scalar tensor containing the total sum of tensors.

Raises:
  • ValueError: if tensors is missing or empty.
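
A minimal sketch of the reduction (the values are illustrative):

  import tensorflow as tf

  a = tf.constant([[1., 2.], [3., 4.]])
  b = tf.constant([10., 20.])
  # Each tensor is reduced to a scalar via tf.reduce_sum, then summed via tf.add_n.
  total = tf.contrib.framework.reduce_sum_n([a, b])  # (1+2+3+4) + (10+20) = 40.0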

tf.contrib.framework.with_shape(expected_shape, tensor) {#with_shape}

Asserts tensor has expected shape.

If tensor shape and expected_shape are both fully defined, assert they match. Otherwise, add an assert op that will validate the shape when tensor is evaluated, and set the shape on tensor.

Args:
  • expected_shape: Expected shape to assert, as a 1D array of ints, or tensor of same.
  • tensor: Tensor whose shape we're validating.
Returns:

tensor, perhaps with a dependent assert operation.

Raises:
  • ValueError: if tensor has an invalid shape.
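
A sketch of the dynamic case, assuming a TF 1.x placeholder whose shape is unknown at graph-construction time:

  import tensorflow as tf

  x = tf.placeholder(tf.float32)  # static shape unknown
  # Adds a runtime assert that fires when x is evaluated, and sets the
  # static shape of the returned tensor to [2, 3].
  x = tf.contrib.framework.with_shape([2, 3], x)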

tf.contrib.framework.with_same_shape(expected_tensor, tensor) {#with_same_shape}

Assert tensors are the same shape, from the same graph.

Args:
  • expected_tensor: Tensor with expected shape.
  • tensor: Tensor of actual values.
Returns:

tensor, possibly with a dependent assert op added.

Deprecation


tf.contrib.framework.deprecated(date, instructions) {#deprecated}

Decorator for marking functions or methods deprecated.

This decorator logs a deprecation warning whenever the decorated function is called. It has the following format:

<function> (from <module>) is deprecated and will be removed after <date>. Instructions for updating: <instructions>

<function> will include the class name if it is a method.

It also edits the docstring of the function: ' (deprecated)' is appended to the first line of the docstring and a deprecation notice is prepended to the rest of the docstring.

Args:
  • date: String. The date the function is scheduled to be removed. Must be ISO 8601 (YYYY-MM-DD).
  • instructions: String. Instructions on how to update code using the deprecated function.
Returns:

Decorated function or method.

Raises:
  • ValueError: If date is not in ISO 8601 format, or instructions are empty.
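
A usage sketch; old_fn, new_fn, and the date are hypothetical:

  from tensorflow.contrib.framework import deprecated

  @deprecated('2017-06-01', 'Use new_fn instead.')  # hypothetical date and advice
  def old_fn(x):
    """Original docstring; ' (deprecated)' is appended to this line."""
    return x

  old_fn(1)  # logs the deprecation warning, then runs normally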

tf.contrib.framework.deprecated_args(date, instructions, *deprecated_arg_names_or_tuples) {#deprecated_args}

Decorator for marking specific function arguments as deprecated.

This decorator logs a deprecation warning whenever the decorated function is called with the deprecated argument. It has the following format:

Calling <function> (from <module>) with <deprecated_argument> is deprecated and will be removed after <date>. Instructions for updating: <instructions>

<function> will include the class name if it is a method.

It also edits the docstring of the function: ' (deprecated arguments)' is appended to the first line of the docstring and a deprecation notice is prepended to the rest of the docstring.

Args:
  • date: String. The date the function is scheduled to be removed. Must be ISO 8601 (YYYY-MM-DD).
  • instructions: String. Instructions on how to update code using the deprecated function.
  • *deprecated_arg_names_or_tuples: String or 2-tuple (String, [ok_vals]). The string is the deprecated argument name. Optionally, an ok-value may be provided; if the user-provided argument equals this value, the warning is suppressed.
Returns:

Decorated function or method.

Raises:
  • ValueError: If date is not in ISO 8601 format, instructions are empty, the deprecated arguments are not present in the function signature, or the second element of a deprecated_tuple is not a list.

tf.contrib.framework.deprecated_arg_values(date, instructions, **deprecated_kwargs) {#deprecated_arg_values}

Decorator for marking specific function argument values as deprecated.

This decorator logs a deprecation warning whenever the decorated function is called with the deprecated argument values. It has the following format:

Calling <function> (from <module>) with <arg>=<deprecated_value> is deprecated and will be removed after <date>. Instructions for updating: <instructions>

<function> will include the class name if it is a method.

It also edits the docstring of the function: ' (deprecated arguments)' is appended to the first line of the docstring and a deprecation notice is prepended to the rest of the docstring.

Args:
  • date: String. The date the function is scheduled to be removed. Must be ISO 8601 (YYYY-MM-DD).
  • instructions: String. Instructions on how to update code using the deprecated function.
  • **deprecated_kwargs: The deprecated argument values.
Returns:

Decorated function or method.

Raises:
  • ValueError: If date is not in ISO 8601 format, or instructions are empty.

Arg_Scope


tf.contrib.framework.arg_scope(list_ops_or_scope, **kwargs) {#arg_scope}

Stores the default arguments for the given set of list_ops.

For usage, see the example sketch below.

Args:
  • list_ops_or_scope: List or tuple of operations to set argument scope for, or a dictionary containing the current scope. When list_ops_or_scope is a dict, kwargs must be empty. When list_ops_or_scope is a list or tuple, every op in it needs to be decorated with @add_arg_scope to work.
  • **kwargs: keyword=value that will define the defaults for each op in list_ops. All the ops need to accept the given set of arguments.
Yields:

the current_scope, which is a dictionary of {op: {arg: value}}

Raises:
  • TypeError: if list_ops is not a list or a tuple.
  • ValueError: if any op in list_ops has not been decorated with @add_arg_scope.
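
A self-contained sketch; scale is a hypothetical function standing in for a real op:

  from tensorflow.contrib.framework import add_arg_scope, arg_scope

  @add_arg_scope
  def scale(x, factor=1.0):
    # Hypothetical op standing in for a real layer function.
    return x * factor

  with arg_scope([scale], factor=2.0):
    a = scale(3.0)               # uses the scoped default: 6.0
    b = scale(3.0, factor=10.0)  # an explicit argument overrides the scope: 30.0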

tf.contrib.framework.add_arg_scope(func) {#add_arg_scope}

Decorates a function with args so it can be used within an arg_scope.

Args:
  • func: function to decorate.
Returns:

The decorated function func_with_args().


tf.contrib.framework.has_arg_scope(func) {#has_arg_scope}

Checks whether a func has been decorated with @add_arg_scope or not.

Args:
  • func: function to check.
Returns:

a boolean.


tf.contrib.framework.arg_scoped_arguments(func) {#arg_scoped_arguments}

Returns the list of kwargs that arg_scope can set for func.

Args:
  • func: function which has been decorated with @add_arg_scope.
Returns:

a list of kwargs names.

Variables


tf.contrib.framework.add_model_variable(var) {#add_model_variable}

Adds a variable to the GraphKeys.MODEL_VARIABLES collection.

Args:
  • var: a variable.

tf.train.assert_global_step(global_step_tensor) {#assert_global_step}

Asserts global_step_tensor is a scalar int Variable or Tensor.

Args:
  • global_step_tensor: Tensor to test.

tf.contrib.framework.assert_or_get_global_step(graph=None, global_step_tensor=None) {#assert_or_get_global_step}

Verifies that a global step tensor is valid or gets one if None is given.

If global_step_tensor is not None, check that it is a valid global step tensor (using assert_global_step). Otherwise find a global step tensor using get_global_step and return it.

Args:
  • graph: The graph to find the global step tensor for.
  • global_step_tensor: The tensor to check for suitability as a global step. If None is given (the default), find a global step tensor.
Returns:

A tensor suitable as a global step, or None if none was provided and none was found.


tf.contrib.framework.assign_from_checkpoint(model_path, var_list) {#assign_from_checkpoint}

Creates an operation to assign specific variables from a checkpoint.

Args:
  • model_path: The full path to the model checkpoint. To get the latest checkpoint, use model_path = tf.train.latest_checkpoint(checkpoint_dir).
  • var_list: A list of Variable objects or a dictionary mapping names in the checkpoint to the corresponding variables to initialize. If empty or None, returns (no_op(), None).
Returns:

the restore_op and the feed_dict that need to be run to restore var_list.

Raises:
  • ValueError: If the checkpoint specified at model_path is missing one of the variables in var_list.
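
A sketch of the intended call pattern; checkpoint_dir and variables_to_restore are assumed to already exist:

  import tensorflow as tf

  model_path = tf.train.latest_checkpoint(checkpoint_dir)
  restore_op, feed_dict = tf.contrib.framework.assign_from_checkpoint(
      model_path, variables_to_restore)
  with tf.Session() as sess:
    sess.run(restore_op, feed_dict)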

tf.contrib.framework.assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, reshape_variables=False) {#assign_from_checkpoint_fn}

Returns a function that assigns specific variables from a checkpoint.

Args:
  • model_path: The full path to the model checkpoint. To get the latest checkpoint, use model_path = tf.train.latest_checkpoint(checkpoint_dir).
  • var_list: A list of Variable objects or a dictionary mapping names in the checkpoint to the corresponding variables to initialize. If empty or None, returns (no_op(), None).
  • ignore_missing_vars: Boolean; if True, variables missing in the checkpoint are ignored with a warning instead of raising an error.
  • reshape_variables: Boolean; if True, variables whose shape differs from the one stored in the checkpoint, but which have the same number of elements, are automatically reshaped.
Returns:

A function that takes a single argument, a tf.Session, and applies the assignment operation.

Raises:
  • ValueError: If the checkpoint specified at model_path is missing one of the variables in var_list.

tf.contrib.framework.assign_from_values(var_names_to_values) {#assign_from_values}

Creates an assignment operation from a given mapping.

This function provides a mechanism for performing assignment of variables to values in a way that does not fill the graph with large assignment values.

Args:
  • var_names_to_values: A map from variable names to values.
Returns:
  • assign_op: An Operation that assigns each of the given variables to the requested values.
  • feed_dict: The feed dictionary to use when evaluating assign_op.
Raises:
  • ValueError: if any of the given variable names were not found.

tf.contrib.framework.assign_from_values_fn(var_names_to_values) {#assign_from_values_fn}

Returns a function that assigns specific variables from the given values.

This function provides a mechanism for performing assignment of variables to values in a way that does not fill the graph with large assignment values.

Args:
  • var_names_to_values: A map from variable names to values.
Returns:

A function that takes a single argument, a tf.Session, and applies the assignment operation.

Raises:
  • ValueError: if any of the given variable names were not found.
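
A minimal sketch, assuming a variable named 'w' exists in the default graph:

  import numpy as np
  import tensorflow as tf

  w = tf.get_variable('w', shape=[2], dtype=tf.float32)
  init_fn = tf.contrib.framework.assign_from_values_fn(
      {'w': np.array([1., 2.], dtype=np.float32)})
  with tf.Session() as sess:
    init_fn(sess)  # assigns via a feed, keeping large constants out of the graph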

tf.contrib.framework.create_global_step(graph=None) {#create_global_step}

Create global step tensor in graph.

Args:
  • graph: The graph in which to create the global step. If missing, use default graph.
Returns:

Global step tensor.

Raises:
  • ValueError: if global step key is already defined.

tf.train.get_global_step(graph=None) {#get_global_step}

Get the global step tensor.

The global step tensor must be an integer variable. We first try to find it in the collection GLOBAL_STEP, or by name global_step:0.

Args:
  • graph: The graph to find the global step in. If missing, use default graph.
Returns:

The global step variable, or None if none was found.

Raises:
  • TypeError: If the global step tensor has a non-integer type, or if it is not a Variable.

tf.contrib.framework.get_or_create_global_step(graph=None) {#get_or_create_global_step}

Returns, creating if necessary, the global step variable.

Args:
  • graph: The graph in which to create the global step. If missing, use default graph.
Returns:

the tensor representing the global step variable.


tf.contrib.framework.get_local_variables(scope=None, suffix=None) {#get_local_variables}

Gets the list of local variables, filtered by scope and/or suffix.

Args:
  • scope: an optional scope for filtering the variables to return.
  • suffix: an optional suffix for filtering the variables to return.
Returns:

a list of variables in collection with scope and suffix.


tf.contrib.framework.get_model_variables(scope=None, suffix=None) {#get_model_variables}

Gets the list of model variables, filtered by scope and/or suffix.

Args:
  • scope: an optional scope for filtering the variables to return.
  • suffix: an optional suffix for filtering the variables to return.
Returns:

a list of variables in collection with scope and suffix.


tf.contrib.framework.get_unique_variable(var_op_name) {#get_unique_variable}

Gets the variable uniquely identified by var_op_name.

Args:
  • var_op_name: the full name of the variable op, including the scope.
Returns:

a tensorflow variable.

Raises:
  • ValueError: if no variable uniquely identified by the name exists.

tf.contrib.framework.get_variables_by_name(given_name, scope=None) {#get_variables_by_name}

Gets the list of variables that were given that name.

Args:
  • given_name: name given to the variable without any scope.
  • scope: an optional scope for filtering the variables to return.
Returns:

a copied list of variables with the given name and scope.


tf.contrib.framework.get_variables_by_suffix(suffix, scope=None) {#get_variables_by_suffix}

Gets the list of variables that end with the given suffix.

Args:
  • suffix: suffix for filtering the variables to return.
  • scope: an optional scope for filtering the variables to return.
Returns:

a copied list of variables with the given suffix and scope.


tf.contrib.framework.get_variables_to_restore(include=None, exclude=None) {#get_variables_to_restore}

Gets the list of the variables to restore.

Args:
  • include: an optional list/tuple of scope strings for filtering which variables from the VARIABLES collection to include. None would include all the variables.
  • exclude: an optional list/tuple of scope strings for filtering which variables from the VARIABLES collection to exclude. None would not exclude any.
Returns:

a list of variables to restore.

Raises:
  • TypeError: include or exclude is provided but is not a list or a tuple.
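
For example (the scope names encoder and encoder/logits are hypothetical):

  import tensorflow as tf

  # Restore everything under 'encoder/' except the logits layer.
  variables_to_restore = tf.contrib.framework.get_variables_to_restore(
      include=['encoder'], exclude=['encoder/logits'])
  saver = tf.train.Saver(variables_to_restore)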

tf.contrib.framework.get_variables(scope=None, suffix=None, collection='variables') {#get_variables}

Gets the list of variables, filtered by scope and/or suffix.

Args:
  • scope: an optional scope for filtering the variables to return. Can be a variable scope or a string.
  • suffix: an optional suffix for filtering the variables to return.
  • collection: the collection to search in. Defaults to GraphKeys.GLOBAL_VARIABLES.
Returns:

a list of variables in collection with scope and suffix.


tf.contrib.framework.local_variable(initial_value, validate_shape=True, name=None) {#local_variable}

Create variable and add it to GraphKeys.LOCAL_VARIABLES collection.

Args:
  • initial_value: See variables.Variable.__init__.
  • validate_shape: See variables.Variable.__init__.
  • name: See variables.Variable.__init__.
Returns:

New variable.


tf.contrib.framework.model_variable(*args, **kwargs) {#model_variable}

Gets an existing model variable with these parameters or creates a new one.

Args:
  • name: the name of the new or existing variable.
  • shape: shape of the new or existing variable.
  • dtype: type of the new or existing variable (defaults to DT_FLOAT).
  • initializer: initializer for the variable if one is created.
  • regularizer: a (Tensor -> Tensor or None) function; the result of applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
  • trainable: If True also add the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES (see tf.Variable).
  • collections: A list of collection names to which the Variable will be added. Note that the variable is always also added to the GraphKeys.GLOBAL_VARIABLES and GraphKeys.MODEL_VARIABLES collections.
  • caching_device: Optional device string or function describing where the Variable should be cached for reading. Defaults to the Variable's device.
  • device: Optional device to place the variable. It can be a string or a function that is called to get the device for the variable.
  • partitioner: Optional callable that accepts a fully defined TensorShape and dtype of the Variable to be created, and returns a list of partitions for each axis (currently only one axis can be partitioned).
  • custom_getter: Callable that allows overwriting the internal get_variable method and has to have the same signature.
Returns:

The created or existing variable.


tf.contrib.framework.variable(*args, **kwargs) {#variable}

Gets an existing variable with these parameters or creates a new one.

Args:
  • name: the name of the new or existing variable.
  • shape: shape of the new or existing variable.
  • dtype: type of the new or existing variable (defaults to DT_FLOAT).
  • initializer: initializer for the variable if one is created.
  • regularizer: a (Tensor -> Tensor or None) function; the result of applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
  • trainable: If True also add the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES (see tf.Variable).
  • collections: A list of collection names to which the Variable will be added. If None it would default to tf.GraphKeys.GLOBAL_VARIABLES.
  • caching_device: Optional device string or function describing where the Variable should be cached for reading. Defaults to the Variable's device.
  • device: Optional device to place the variable. It can be a string or a function that is called to get the device for the variable.
  • partitioner: Optional callable that accepts a fully defined TensorShape and dtype of the Variable to be created, and returns a list of partitions for each axis (currently only one axis can be partitioned).
  • custom_getter: Callable that allows overwriting the internal get_variable method and has to have the same signature.
Returns:

The created or existing variable.


class tf.contrib.framework.VariableDeviceChooser {#VariableDeviceChooser}

Device chooser for variables.

When using a parameter server, variables are assigned to tasks in a round-robin fashion. When not using a parameter server, it allows GPU or CPU placement.


tf.contrib.framework.VariableDeviceChooser.__call__(op) {#VariableDeviceChooser.call}


tf.contrib.framework.VariableDeviceChooser.__init__(num_tasks=0, job_name='ps', device_type='CPU', device_index=0) {#VariableDeviceChooser.init}

Initialize VariableDeviceChooser.

Usage:

To use with 2 parameter servers:

  VariableDeviceChooser(2)

To use without parameter servers:

  VariableDeviceChooser()
  VariableDeviceChooser(device_type='GPU')  # For GPU placement

Args:
  • num_tasks: number of tasks.
  • job_name: String, a name for the parameter server job.
  • device_type: Optional device type string (e.g. "CPU" or "GPU")
  • device_index: int. Optional device index. If left unspecified, device represents 'any' device_index.
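
A placement sketch: since the chooser is a callable mapping an op to a device, it can be passed to tf.device (this assumes a 2-task 'ps' job exists in the cluster):

  import tensorflow as tf

  chooser = tf.contrib.framework.VariableDeviceChooser(num_tasks=2)
  with tf.device(chooser):
    # Variables created here are assigned to /job:ps tasks round-robin.
    v = tf.contrib.framework.variable('v', shape=[10])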

tf.contrib.framework.zero_initializer(ref, use_locking=True, name='zero_initializer') {#zero_initializer}

Initialize 'ref' with all zeros. The ref tensor should be uninitialized; if it is already initialized, a ValueError is raised. This op is intended to save memory during initialization.

Args:
  • ref: ref of the tensor that needs to be zero-initialized.
  • name: optional name for this operation.
Returns:

ref, once initialized.

Raises:
  • ValueError: If the ref tensor is already initialized.

Checkpoint utilities


tf.contrib.framework.load_checkpoint(filepattern) {#load_checkpoint}

Returns CheckpointReader for latest checkpoint.

Args:
  • filepattern: Directory with checkpoints file or path to checkpoint.
Returns:

CheckpointReader object.

Raises:
  • ValueError: if filepattern doesn't have a 'checkpoint' file or checkpoints.

tf.contrib.framework.list_variables(checkpoint_dir) {#list_variables}

Returns list of all variables in the latest checkpoint.

Args:
  • checkpoint_dir: Directory with checkpoints file or path to checkpoint.
Returns:

List of tuples (name, shape).
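
A quick inspection sketch; '/tmp/model' is a hypothetical checkpoint directory:

  import tensorflow as tf

  # See what the checkpoint contains before deciding what to restore.
  for name, shape in tf.contrib.framework.list_variables('/tmp/model'):
    print(name, shape)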


tf.contrib.framework.load_variable(checkpoint_dir, name) {#load_variable}

Returns a Tensor with the contents of the given variable in the checkpoint.

Args:
  • checkpoint_dir: Directory with checkpoints file or path to checkpoint.
  • name: Name of the tensor to return.
Returns:

Tensor object.


tf.contrib.framework.init_from_checkpoint(checkpoint_dir, assignment_map) {#init_from_checkpoint}

Initializes current variables with tensors loaded from the checkpoint, using the given assignment map.

Note: This overrides default initialization ops of specified variables and redefines dtype.

The assignment map supports the following syntax:

  • 'checkpoint_scope_name/': 'scope_name/' - will load all variables in the current scope_name from checkpoint_scope_name with matching variable names.
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - will initialize scope_name/variable_name from checkpoint_scope_name/some_other_variable.
  • 'scope_variable_name': variable - will initialize the given tf.Variable object with the variable from the checkpoint.
  • 'scope_variable_name': list(variable) - will initialize a list of partitioned variables with the variable from the checkpoint.
  • '/': 'scope_name/' - will load all variables in the current scope_name from the checkpoint's root (e.g. no scope).

Supports loading into partitioned variables, which are represented as '<variable>/part_<part #>'.

Example:
  # Create variables.
  with tf.variable_scope('test'):
    m = tf.get_variable('my_var')
  with tf.variable_scope('test2'):
    var2 = tf.get_variable('my_var')
  var3 = tf.get_variable(name="my1", shape=[100, 100],
                         partitioner=lambda shape, dtype: [5, 1])
  ...
  # Specify which variables to initialize from checkpoint.
  init_from_checkpoint(checkpoint_dir, {
    'some_var': 'test/my_var',
    'some_scope/': 'test2/'})
  ...
  # Or use `Variable` objects to identify what to initialize.
  init_from_checkpoint(checkpoint_dir, {
    'some_scope/var2': var2,
  })
  # Initialize partitioned variables
  init_from_checkpoint(checkpoint_dir, {
    'some_var_from_ckpt': 'part_var',
  })
  # Or specifying the list of `Variable` objects.
  init_from_checkpoint(checkpoint_dir, {
    'some_var_from_ckpt': var3._get_variable_list(),
  })
  ...
  # Initialize variables as usual.
  session.run(tf.global_variables_initializer())
Args:
  • checkpoint_dir: Directory with checkpoints file or path to checkpoint.
  • assignment_map: Dict, where keys are names of the variables in the checkpoint and values are current variables or names of current variables (in default graph).
Raises:
  • tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
  • ValueError: If missing variables in current graph.