Python 无法从检查点还原:双向/向后\lstm/bias

Python 无法从检查点还原:双向/向后\lstm/bias,python,tensorflow,tensor2tensor,Python,Tensorflow,Tensor2tensor,我试图在tensor2tensor中创建一个简单的基于LSTM的RNN 到目前为止,培训似乎有效,但我无法恢复模型。尝试这样做将抛出一个NotFoundError,指出LSTM中的偏差节点: NotFoundError:。。 检查点中未找到键双向/向后\lstm/偏移 我不知道为什么会这样 这实际上是另一个问题的解决方法,我可以使用tensor2tensor()中的LSTM来解决类似问题 环境 $pip freeze | grep张量 网格张量流==0.0.5 张量2传感器==1.12.0 张

我试图在tensor2tensor中创建一个简单的基于LSTM的RNN

到目前为止,培训似乎有效,但我无法恢复模型。尝试这样做将抛出一个
NotFoundError
,指出LSTM中的偏差节点:

NotFoundError:。。
检查点中未找到键双向/向后\lstm/偏移
我不知道为什么会这样

这实际上是另一个问题的解决方法,我可以使用tensor2tensor()中的LSTM来解决类似问题

环境
$pip freeze | grep张量
网格张量流==0.0.5
张量2传感器==1.12.0
张力板==1.12.0
tensorflow数据集==1.0.2
张量流估计器==1.13.0
tensorflow gpu==1.12.0
tensorflow元数据==0.9.0
张量流概率==0.5.0
模型体 完全错误
NotFoundError(回溯见上文):从检查点还原失败。这很可能是由于检查点中缺少变量名或其他图形键造成的。请确保您没有根据检查点更改预期的图形。原始错误:
检查点中未找到while/lstm_keras/parallel_0_4/lstm_keras/lstm_keras/body/bidirectional/backward_lstm/bias键
[[node save/RestoreV2(在/home/sfalk/tmp/pycharm_project_265/asr/model/persistence.py:282中定义)=RestoreV2[dtypes=[DT_FLOAT,DT_FLOAT,DT_FLOAT,DT_FLOAT,DT_FLOAT,DT_FLOAT,DT_FLOAT,DT_FLOAT,DT_FLOAT],_设备=“/job:localhost/replica:0/任务:0/设备:CPU:0”](_arg_save/Const_0_0,save/RestoreV2/tensor_name,save/RestoreV2/shape_和_切片)]]

任何问题可能是什么以及如何解决?

这似乎与使用tensor2tensor从检查点恢复期间的“while”似乎在键名前面。似乎是一个未解决的错误,请在github上输入

如果可以的话,我会对此发表评论,但我的声誉太低了。干杯

def body(self, features):

    inputs = features['inputs'][:,:,0,:]

    hparams = self._hparams
    problem = hparams.problem
    encoders = problem.feature_info

    max_input_length = 350
    max_output_length = 350 

    encoder = Bidirectional(LSTM(128, return_sequences=True, unroll=False), merge_mode='concat')(inputs)
    encoder_last = encoder[:, -1, :]

    decoder = LSTM(256, return_sequences=True, unroll=False)(inputs, initial_state=[encoder_last, encoder_last])

    attention = dot([decoder, encoder], axes=[2, 2])
    attention = Activation('softmax', name='attention')(attention)

    context = dot([attention, encoder], axes=[2, 1])
    concat = concatenate([context, decoder])

    return tf.expand_dims(concat, 2)