Skip to content

Commit

Permalink
[bugfix] fix name2var bug (#458)
Browse files Browse the repository at this point in the history
fix name2var bug
  • Loading branch information
chengmengli06 authored Apr 8, 2024
1 parent 212570b commit ea2610e
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions easy_rec/python/model/easy_rec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile

from easy_rec.python.compat import regularizers
from easy_rec.python.layers import input_layer
Expand Down Expand Up @@ -379,28 +380,30 @@ def _get_restore_vars(self, ckpt_var_map_path):
name2var[var_name] = [one_var] if is_part else one_var

if ckpt_var_map_path != '':
if not tf.gfile.Exists(ckpt_var_map_path):
if not gfile.Exists(ckpt_var_map_path):
logging.warning('%s not exist' % ckpt_var_map_path)
return name2var

# load var map
name_map = {}
with open(ckpt_var_map_path, 'r') as fin:
with gfile.GFile(ckpt_var_map_path, 'r') as fin:
for one_line in fin:
one_line = one_line.strip()
line_tok = [x for x in one_line.split() if x != '']
if len(line_tok) != 2:
logging.warning('Failed to process: %s' % one_line)
continue
name_map[line_tok[0]] = line_tok[1]
var_map = {}
update_map = {}
old_keys = []
for var_name in name2var:
if var_name in name_map:
in_ckpt_name = name_map[var_name]
var_map[in_ckpt_name] = name2var[var_name]
else:
logging.warning('Failed to find in var_map_file(%s): %s' %
(ckpt_var_map_path, var_name))
update_map[in_ckpt_name] = name2var[var_name]
old_keys.append(var_name)
for tmp_key in old_keys:
del name2var[tmp_key]
name2var.update(update_map)
return name2var
else:
var_filter, scope_update = self.get_restore_filter()
Expand Down

0 comments on commit ea2610e

Please sign in to comment.