-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
230 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Reference: \n", | ||
"- optimistic_restore: https://github.com/tensorflow/tensorflow/issues/1823, https://github.com/tensorflow/tensorflow/issues/312\n", | ||
"- get_uninit_vars: http://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Partial restore + Partial initialization" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import tensorflow as tf" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Only A is saved and restored" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[0, 3]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"tf.reset_default_graph()\n", | ||
"a = tf.train.get_or_create_global_step()\n", | ||
"b = tf.Variable(3)\n", | ||
"sess = tf.Session()\n", | ||
"saver = tf.train.Saver([a])\n", | ||
"\n", | ||
"# Initialize\n", | ||
"sess.run(tf.global_variables_initializer())\n", | ||
"\n", | ||
"# Save\n", | ||
"saver.save(sess, save_path='./restore_reinitialize')\n", | ||
"\n", | ||
"print(sess.run([a, b]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def optimistic_restore(sess, save_path, restore_vars=None, name=None):\n", | ||
" # Reference: https://github.com/tensorflow/tensorflow/issues/312\n", | ||
"\n", | ||
" # Load checkpoint reader\n", | ||
" reader = tf.train.NewCheckpointReader(save_path)\n", | ||
"\n", | ||
" # All variable saved in checkpoint\n", | ||
" # Dict: name => shape\n", | ||
" # {'global_step': [],\n", | ||
" # 'resnet_v1_101/block1/unit_1/bottleneck_v1/conv1/BatchNorm/beta': [64], .. }\n", | ||
" saved_shapes = reader.get_variable_to_shape_map()\n", | ||
"\n", | ||
" # List of all names of global variables in current graph\n", | ||
" # Sort because variables in checkpoints are sorted by their names already.\n", | ||
" # [('global_step:0', 'global_step'),\n", | ||
" # ('resnet_v1_101/block1/unit_1/bottleneck_v1/conv1/BatchNorm/beta:0', .. ]\n", | ||
" var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables()\n", | ||
" if var.name.split(':')[0] in saved_shapes])\n", | ||
"\n", | ||
" # Dict: name => variable\n", | ||
" # Key: 'Decoder/LSTM_initializer/Layer_0/fully_connected/biases'\n", | ||
" # Value: <tf.Variable 'Decoder/LSTM_initializer/Layer_0/fully_connected/biases:0' shape=(512,) dtype=float32_ref>\n", | ||
" name2var = dict(zip(map(lambda x:x.name.split(':')[0], tf.global_variables()), tf.global_variables()))\n", | ||
"\n", | ||
" # List all global variables to restore if they are in checkpoint\n", | ||
" restore_vars = []\n", | ||
" with tf.variable_scope('', reuse=True):\n", | ||
" for var_name, saved_var_name in var_names:\n", | ||
" curr_var = name2var[saved_var_name]\n", | ||
" var_shape = curr_var.get_shape().as_list()\n", | ||
" if var_shape == saved_shapes[saved_var_name]:\n", | ||
" restore_vars.append(curr_var)\n", | ||
"\n", | ||
" # Restore variables\n", | ||
" saver = tf.train.Saver(restore_vars, name=name)\n", | ||
" saver.restore(sess, save_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"INFO:tensorflow:Restoring parameters from ./restore_reinitialize\n", | ||
"Attempting to use uninitialized value Variable\n", | ||
"\t [[Node: _retval_Variable_0_0 = _Retval[T=DT_INT32, index=0, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](Variable)]]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"tf.reset_default_graph()\n", | ||
"a = tf.train.get_or_create_global_step()\n", | ||
"b = tf.Variable(3)\n", | ||
"sess = tf.Session()\n", | ||
"\n", | ||
"optimistic_restore(sess, './restore_reinitialize')\n", | ||
"try:\n", | ||
" print(sess.run([a, b]))\n", | ||
"except Exception as e:\n", | ||
" print(e)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Initialize b" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_uninit_vars(sess, variables=None):\n", | ||
" if variables is None:\n", | ||
" variables = tf.global_variables()\n", | ||
" init_flag = sess.run(\n", | ||
" tf.stack([tf.is_variable_initialized(v) for v in variables]))\n", | ||
" return [v for v, f in zip(variables, init_flag) if not f]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[<tf.Variable 'Variable:0' shape=() dtype=int32_ref>]\n", | ||
"[3]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"uninit_vars = get_uninit_vars(sess)\n", | ||
"print(uninit_vars)\n", | ||
"init_op = tf.variables_initializer(uninit_vars) \n", | ||
"sess.run(init_op)\n", | ||
"print(sess.run(uninit_vars))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[0, 3]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Initialize\n", | ||
"sess.run(tf.variables_initializer([a, b]))\n", | ||
"\n", | ||
"print(sess.run([a, b]))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"anaconda-cloud": {}, | ||
"kernelspec": { | ||
"display_name": "Python [conda env:mldemo]", | ||
"language": "python", | ||
"name": "conda-env-mldemo-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |