Python: Restoring a trained TensorFlow model gives KeyError: 'BlockLSTM' - Fatal编程技术网



So I'm trying to load my trained TensorFlow model, but I get this strange error, and I can't find any answers about this specific error.

Here is my saving code:

```python
with tf.Session(graph=self.graph) as sess:
    saver = tf.train.Saver()
    for i in range(self.c.epochs):
        batch_data, batch_labels = self.get_batch(train_keys, self.c.doc_len, self.c.num_classes, batch_size=self.c.batch_size)

        _, batch_loss = sess.run([self.optimizer, self.loss], feed_dict={self.input_data: batch_data, self.labels: batch_labels, self.dropout_rate: 0.5})

        # Checkpoint every 2nd epoch (skipping epoch 0) and after the last epoch
        if (i % 2 == 0 and i != 0) or i == self.c.epochs - 1:
            # Use the epoch index as global_step; a constant (the original used 2)
            # makes every save overwrite the same checkpoint files
            saver.save(sess, save_model_file, global_step=i)
```
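Separately from the KeyError, a constant `global_step=2` means every checkpoint is written under the same name, because `Saver.save` appends `-<global_step>` to the save path. The pure-Python sketch below (no TensorFlow needed; `checkpoint_prefixes` is a hypothetical helper, and `"model"` stands in for `save_model_file`) replays the save schedule of the loop and lists the checkpoint prefixes it would produce.

```python
def checkpoint_prefixes(save_model_file, epochs, fixed_step=None):
    """Hypothetical helper: list the checkpoint prefixes Saver.save would
    write for the schedule in the training loop.

    If fixed_step is given, every save uses that constant global_step
    (like the original global_step=2); otherwise the epoch index i is used.
    """
    prefixes = []
    for i in range(epochs):
        # Same condition as the training loop
        if (i % 2 == 0 and i != 0) or i == epochs - 1:
            step = fixed_step if fixed_step is not None else i
            # Saver.save names the checkpoint "<save_path>-<global_step>"
            prefixes.append("{}-{}".format(save_model_file, step))
    return prefixes


print(checkpoint_prefixes("model", 6, fixed_step=2))  # ['model-2', 'model-2', 'model-2']
print(checkpoint_prefixes("model", 6))                # ['model-2', 'model-4', 'model-5']
```

With a constant step, only the last save survives under `model-2`; with `global_step=i` each save gets its own prefix, and the matching `.meta` file is that prefix plus `.meta`.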
Here is my restore function:

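```python
import tensorflow as tf

tf.reset_default_graph()
# Rebuild the graph structure from the .meta file (this call raises the KeyError)
saver = tf.train.import_meta_graph(trained_model_name)

with tf.Session() as sess:
    # Load the trained variable values into the imported graph
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    graph = tf.get_default_graph()

    # Tensor names need an output index ("<op_name>:0"); a bare op name raises a ValueError
    input_data = graph.get_tensor_by_name("input_data:0")
    preds = graph.get_tensor_by_name("preds:0")

    # Do not run tf.global_variables_initializer() after restore(): it would reset the
    # restored variables. A new X_init placeholder created here is not connected to the
    # imported graph, so feeding lexvec_model into it has no effect.
    # sess is closed when the with block ends, so run inference inside it.
    pred = sess.run(preds, feed_dict={input_data: model_input})
```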
The goal is to use the restored model for inference, but I get an error at `saver = tf.train.import_meta_graph(trained_model_name)`. Some help would be great :)

Error output:

```
Traceback (most recent call last):
  File "C:/Users/.../main/Predictor.py", line 94, in <module>
    prediction = predictor.predict(text_doc=doc)
  File "C:/Users/.../main/Predictor.py", line 57, in predict
    saver = tf.train.import_meta_graph(trained_model_name)
  File "C:\Users\...\Python36\lib\site-packages\tensorflow\python\training\saver.py", line 1927, in import_meta_graph
    **kwargs)
  File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 741, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\framework\importer.py", line 457, in import_graph_def
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
  File "C:\Users\...\Python\Python36\lib\site-packages\tensorflow\python\framework\importer.py", line 227, in _RemoveDefaultAttrs
    op_def = op_dict[node.op]
KeyError: 'BlockLSTM'
```

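I ran into the same problem when using LSTMBlockFusedCell(). The solution is to access `tf.contrib.rnn` before importing the meta graph. In the TensorFlow 1.x versions involved here, `tf.contrib` is loaded lazily, and the BlockLSTM op is only registered once the contrib RNN module has been loaded: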

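```python
import tensorflow as tf

# for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369
# This bare attribute access forces the lazily loaded contrib module to load,
# which registers the BlockLSTM op before the meta graph is imported
tf.contrib.rnn
# restore meta graph
meta_file = args.restore + '.meta'
loader = tf.train.import_meta_graph(meta_file, clear_devices=True)
...
```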
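To make the mechanism concrete, here is a toy model in plain Python. It is not TensorFlow's code: `OP_REGISTRY`, `LazyContrib` and `import_graph` are hypothetical names. It only mimics two facts: the importer looks up each node's op in a dict of registered ops (the `op_dict[node.op]` line in the traceback), and `tf.contrib` is loaded lazily, so an op defined there is missing from that dict until something under `tf.contrib` is touched.

```python
# Toy model, NOT TensorFlow's implementation.
# Ops known to the "core library" at startup (hypothetical set)
OP_REGISTRY = {"MatMul", "Add", "Placeholder"}


class LazyContrib:
    """Stands in for the lazily loaded tf.contrib module (hypothetical)."""

    def __init__(self):
        self._loaded = False

    def __getattr__(self, name):
        # Only called for attributes not found normally, e.g. contrib.rnn
        if not self._loaded:
            OP_REGISTRY.add("BlockLSTM")  # loading contrib registers its ops
            self._loaded = True
        return name


def import_graph(node_ops):
    """Mimic the importer: look each node's op up in the registered ops."""
    op_dict = {op: "OpDef(%s)" % op for op in OP_REGISTRY}
    return [op_dict[op] for op in node_ops]  # KeyError for unregistered ops


contrib = LazyContrib()

try:
    import_graph(["Placeholder", "BlockLSTM"])
except KeyError as e:
    print("KeyError:", e)  # KeyError: 'BlockLSTM'

contrib.rnn  # like the bare `tf.contrib.rnn` line in the answer
print(import_graph(["Placeholder", "BlockLSTM"]))  # ['OpDef(Placeholder)', 'OpDef(BlockLSTM)']
```

The point of the toy is only the ordering: the lookup fails if it happens before the registering module is loaded, and succeeds afterwards, which is why the bare `tf.contrib.rnn` line has to come before `import_meta_graph`.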