在tensorflow中加载旧的检查点
我使用Tensorflow r0.12训练了一些模型并保存了它。后来我更新到r1.0.1。一些模型加载时没有任何问题,但如果模型中有RNN单元,加载失败,检查点中未发现在tensorflow中加载旧的检查点,tensorflow,Tensorflow,我使用Tensorflow r0.12训练了一些模型并保存了它。后来我更新到r1.0.1。一些模型加载时没有任何问题,但如果模型中有RNN单元,加载失败,检查点中未发现关键层5/双向RNN/bw/多RNN单元/单元1/基本RNN单元/偏差。 另外,如果我检查model.index文件,我会看到类似的条目,例如:5/BiRNN/BW/MultiRNNCell/Cell0/basicrncell/Linear/Bias 带有RNN单元格的包现在位于tf.contrib.RNN(在0.12中是tf.n
关键层5/双向RNN/bw/多RNN单元/单元1/基本RNN单元/偏差。
另外,如果我检查model.index
文件,我会看到类似的条目,例如:5/BiRNN/BW/MultiRNNCell/Cell0/basicrncell/Linear/Bias
带有RNN单元格的包现在位于tf.contrib.RNN
(在0.12中是tf.nn.RNN\u cell
),因此我认为某些命名已更改
问题是:
是否有办法加载我的模型,重新映射其张量并保存,以便张量名称与r1.0兼容
另外,如果有帮助的话,我还有model.meta
文件
谢谢 如果有人遇到同样的问题,下面是我使用的解决方案。它是tensorflow.python.tools
中的inspect\u checkpoint.py
中tensor打印函数的修改版本
def resave_tensors(file_name, rename_map, dry_run=False):
"""
Updates checkpoint by renaming tensors in it.
:param file_name: Filename with checkpoint.
:param rename_map: Map from old names to new ones
:param dry_run: If True, just print new tensors.
"""
renames_count = 0
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
print("tensor_name: ", key)
tensor_val = reader.get_tensor(key)
print('shape: {}'.format(tensor_val.shape))
if key in rename_map:
renames_count += 1
key = rename_map[key]
tf.Variable(tensor_val, dtype=tensor_val.dtype, name=key)
saver = tf.train.Saver()
if not dry_run:
with tf.Session() as session:
session.run(tf.global_variables_initializer())
saver.save(session, file_name)
print('Renamed vars: {}'.format(renames_count))