Tensorflow 如何重用rnn单元进行推理

Tensorflow 如何重用rnn单元进行推理,tensorflow,Tensorflow,我的一些图形定义是用于培训的。看起来像这样 with tf.variable_scope('RNN', initializer=tf.contrib.layers.xavier_initializer()): self.rnn_cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell') self.init_state = tf.get_variable('init', [1, HID_SZ], tf.float32)

我的图中有一部分定义用于训练,如下所示:

# Training part of the graph: a GRU cell with a trainable initial state,
# unrolled over the input sequence by tf.nn.dynamic_rnn.
with tf.variable_scope('RNN', initializer=tf.contrib.layers.xavier_initializer()):
     # The cell object owns its weights; keep a reference so the same object
     # (and therefore the same weights) can be called again later.
     self.rnn_cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
     # Trainable initial hidden state of shape [1, HID_SZ].
     self.init_state = tf.get_variable('init', [1, HID_SZ], tf.float32)

     # Repeat the single initial state along axis 0 -> [SZ_BATCH, HID_SZ].
     self.init_state_train = tf.tile(self.init_state, [SZ_BATCH, 1])

     # time_major=True: presumably emb is laid out [time, batch, features] — TODO confirm against caller.
     outputs, state = tf.nn.dynamic_rnn(self.rnn_cell, emb, initial_state=self.init_state_train, dtype=tf.float32, time_major=True)
然后我定义了用于推理的部分,目前如下所示:

# Inference part of the graph: run the *same* cell object on a single step so
# it reuses the weights that were created for training.
with tf.variable_scope("", reuse=True):
    [...]
    # tf.get_variable('RNN/rnncell') cannot work here. get_variable looks up
    # (or, lacking a fully defined shape, fails to create) a single variable,
    # while 'RNN/rnncell' names the cell's scope, not a variable. The cell
    # object already owns its variables, so calling it again reuses them.
    self.rnn_infer = self.rnn_cell
    # Prepend a size-1 axis (used as the batch dimension) to input and state.
    inputs_single = tf.expand_dims(emb_single, 0)
    input_state_ = tf.expand_dims(self.input_state, 0)
    # NOTE(review): confirm this TF version's cell __call__ accepts `name=`.
    output, hidden = self.rnn_infer(inputs_single, input_state_, name='rnncall')
但是调用 `tf.get_variable('RNN/rnncell')` 会导致以下错误:

ValueError: You can only pass an initializer function that expects no arguments to its callable when the shape is not fully defined. The given initializer function expects the following args ['self', 'shape', 'dtype', 'partition_info']

我想在推理时重用分配给 `self.rnn_cell` 的变量,应该怎么做?

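关键在于:创建一个 RNN 单元(cell)并在图中调用它时,它的权重和运算(ops)会像其他变量一样创建在图上。因此,可以像恢复普通变量一样,用 `tf.train.Saver` 恢复这些权重。`tf.get_variable('RNN/rnncell')` 之所以报错,是因为 `get_variable` 只能获取(或在形状已知时创建)单个变量,而 `'RNN/rnncell'` 是单元的作用域名,不是变量名。若要在同一个图中推理,直接再次调用同一个单元对象(`self.rnn_cell`)即可复用其权重。下面的示例演示如何在新建的图中恢复这些权重: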

import tensorflow as tf
import numpy as np
import os


def build_and_train(ckpt_path=None):
    """Build a one-step GRU graph, checkpoint its variables, and print the output.

    No optimisation step is run: the variables keep their initial values, and
    those values are what gets saved. The point of the demo is that the saved
    weights can later be restored into a freshly built, identical graph.

    Args:
        ckpt_path: Checkpoint prefix to write. Defaults to ``model.ckpt`` in
            the current working directory.
    """
    if ckpt_path is None:
        # os.path.join instead of a hard-coded '\\' so the path is valid on
        # every OS; kept absolute because Saver.save needs an existing parent
        # directory.
        ckpt_path = os.path.join(os.getcwd(), 'model.ckpt')

    HID_SZ = 1
    graph = tf.Graph()
    ones = np.ones([2, 3])

    with graph.as_default():
        in_ = tf.placeholder(tf.float32, [2, 3])
        # Calling the cell creates its weights as ordinary graph variables, so
        # the Saver below picks them up automatically.
        cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
        state = tf.zeros([2, HID_SZ])
        out, state = cell(in_, state)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

    # The context manager closes the session on exit; the original leaked it.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        saver.save(sess, ckpt_path)
        print('Cell output after training')
        print(sess.run(out, feed_dict={in_: ones}))

def infer(ckpt_path=None):
    """Rebuild the GRU graph and restore its weights from a checkpoint.

    Prints the cell output twice: first with freshly initialised weights, then
    after restoring the checkpoint. This shows that the restored cell
    reproduces the saved output. The graph must be built exactly as when it
    was saved (same cell name and HID_SZ) so the variable names match.

    Args:
        ckpt_path: Checkpoint prefix to restore. Defaults to ``model.ckpt`` in
            the current working directory.
    """
    if ckpt_path is None:
        # Same location as the original relative 'model.ckpt', made explicit.
        ckpt_path = os.path.join(os.getcwd(), 'model.ckpt')

    HID_SZ = 1
    graph = tf.Graph()
    ones = np.ones([2, 3])

    with graph.as_default():
        in_ = tf.placeholder(tf.float32, [2, 3])
        cell = tf.nn.rnn_cell.GRUCell(HID_SZ, name='rnncell')
        state = tf.zeros([2, HID_SZ])
        out, state = cell(in_, state)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

    # The context manager closes the session on exit; the original leaked it.
    with tf.Session(graph=graph) as sess:
        # Initialise with random weights first so the restore has a visible effect.
        sess.run(init_op)
        print('random cell output')
        print(sess.run(out, feed_dict={in_: ones}))

        saver.restore(sess, ckpt_path)

        print('Trained cell output')
        print(sess.run(out, feed_dict={in_: ones}))


# Run the demo end to end: write the checkpoint, then restore it into a new graph.
for _demo_step in (build_and_train, infer):
    _demo_step()
运行结果如下:

Cell output after training
[[0.02710133]
 [0.02710133]]
random cell output
[[0.2458247]
 [0.2458247]]
Trained cell output
[[0.02710133]
 [0.02710133]]