Skip to content

Commit

Permalink
Add partial init
Browse files Browse the repository at this point in the history
  • Loading branch information
j-min committed May 22, 2017
1 parent 58be3f3 commit b375f8b
Showing 1 changed file with 230 additions and 0 deletions.
230 changes: 230 additions & 0 deletions Save_Restore/Partial_restore_initialization.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reference: \n",
"- optimistic_restore: https://github.com/tensorflow/tensorflow/issues/1823, https://github.com/tensorflow/tensorflow/issues/312\n",
"- get_uninit_vars: http://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Partial restore + Partial initialization"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Only A is saved and restored"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 3]\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"a = tf.train.get_or_create_global_step()\n",
"b = tf.Variable(3)\n",
"sess = tf.Session()\n",
"saver = tf.train.Saver([a])\n",
"\n",
"# Initialize\n",
"sess.run(tf.global_variables_initializer())\n",
"\n",
"# Save\n",
"saver.save(sess, save_path='./restore_reinitialize')\n",
"\n",
"print(sess.run([a, b]))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def optimistic_restore(sess, save_path, restore_vars=None, name=None):\n",
" # Reference: https://github.com/tensorflow/tensorflow/issues/312\n",
"\n",
" # Load checkpoint reader\n",
" reader = tf.train.NewCheckpointReader(save_path)\n",
"\n",
" # All variable saved in checkpoint\n",
" # Dict: name => shape\n",
" # {'global_step': [],\n",
" # 'resnet_v1_101/block1/unit_1/bottleneck_v1/conv1/BatchNorm/beta': [64], .. }\n",
" saved_shapes = reader.get_variable_to_shape_map()\n",
"\n",
" # List of all names of global variables in current graph\n",
" # Sort because variables in checkpoints are sorted by their names already.\n",
" # [('global_step:0', 'global_step'),\n",
" # ('resnet_v1_101/block1/unit_1/bottleneck_v1/conv1/BatchNorm/beta:0', .. ]\n",
" var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables()\n",
" if var.name.split(':')[0] in saved_shapes])\n",
"\n",
" # Dict: name => variable\n",
" # Key: 'Decoder/LSTM_initializer/Layer_0/fully_connected/biases'\n",
" # Value: <tf.Variable 'Decoder/LSTM_initializer/Layer_0/fully_connected/biases:0' shape=(512,) dtype=float32_ref>\n",
" name2var = dict(zip(map(lambda x:x.name.split(':')[0], tf.global_variables()), tf.global_variables()))\n",
"\n",
" # List all global variables to restore if they are in checkpoint\n",
" restore_vars = []\n",
" with tf.variable_scope('', reuse=True):\n",
" for var_name, saved_var_name in var_names:\n",
" curr_var = name2var[saved_var_name]\n",
" var_shape = curr_var.get_shape().as_list()\n",
" if var_shape == saved_shapes[saved_var_name]:\n",
" restore_vars.append(curr_var)\n",
"\n",
" # Restore variables\n",
" saver = tf.train.Saver(restore_vars, name=name)\n",
" saver.restore(sess, save_path)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from ./restore_reinitialize\n",
"Attempting to use uninitialized value Variable\n",
"\t [[Node: _retval_Variable_0_0 = _Retval[T=DT_INT32, index=0, _device=\"/job:localhost/replica:0/task:0/cpu:0\"](Variable)]]\n"
]
}
],
"source": [
"tf.reset_default_graph()\n",
"a = tf.train.get_or_create_global_step()\n",
"b = tf.Variable(3)\n",
"sess = tf.Session()\n",
"\n",
"optimistic_restore(sess, './restore_reinitialize')\n",
"try:\n",
" print(sess.run([a, b]))\n",
"except Exception as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize b"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def get_uninit_vars(sess, variables=None):\n",
" if variables is None:\n",
" variables = tf.global_variables()\n",
" init_flag = sess.run(\n",
" tf.stack([tf.is_variable_initialized(v) for v in variables]))\n",
" return [v for v, f in zip(variables, init_flag) if not f]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[<tf.Variable 'Variable:0' shape=() dtype=int32_ref>]\n",
"[3]\n"
]
}
],
"source": [
"uninit_vars = get_uninit_vars(sess)\n",
"print(uninit_vars)\n",
"init_op = tf.variables_initializer(uninit_vars) \n",
"sess.run(init_op)\n",
"print(sess.run(uninit_vars))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 3]\n"
]
}
],
"source": [
"# Initialize\n",
"sess.run(tf.variables_initializer([a, b]))\n",
"\n",
"print(sess.run([a, b]))"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:mldemo]",
"language": "python",
"name": "conda-env-mldemo-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit b375f8b

Please sign in to comment.