Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
如何将LSTMCell的变量设置为输入,而不是让它在Tensorflow中创建它?_Tensorflow_Lstm - Fatal编程技术网

如何将LSTMCell的变量设置为输入,而不是让它在Tensorflow中创建它?

如何将LSTMCell的变量设置为输入,而不是让它在Tensorflow中创建它?,tensorflow,lstm,Tensorflow,Lstm,当我创建tf.contrib.rnn.LSTMCell时,它会在初始化期间创建其内核和偏差可训练变量 代码现在的样子: cell_fw = tf.contrib.rnn.LSTMCell(hidden_size_char, state_is_tuple=True) 我希望它看起来像什么: kernel = tf.get_variable(...) bias = tf.get_variable(...) cell_fw = tf.contrib.r

当我创建tf.contrib.rnn.LSTMCell时,它会在初始化期间创建其内核偏差可训练变量

代码现在的样子:

cell_fw = tf.contrib.rnn.LSTMCell(hidden_size_char,
                        state_is_tuple=True)
我希望它看起来像什么:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
cell_fw = tf.contrib.rnn.LSTMCell(kernel, bias, hidden_size,
                        state_is_tuple=True)
我想做的是自己创建这些变量,并在将其实例化为init的输入时将其提供给LSTMCell类


有没有一个简单的方法可以做到这一点?我查看了,但它似乎位于复杂的类层次结构中。

我对LSTMCell类进行了子类化,并更改了它的init和build 方法,以便它们接受给定的变量。如果变量在init中给定 在构建中,我们不再使用get_变量,而是使用给定的内核和bias变量

不过,可能有更干净的方法

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"

class MyLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, num_units,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=None, num_proj_shards=None,
                 forget_bias=1.0, state_is_tuple=True,
                 activation=None, reuse=None, name=None, var_given=False, kernel=None, bias=None):

        super(MyLSTMCell, self).__init__(num_units,
                 use_peepholes=use_peepholes, cell_clip=cell_clip,
                 initializer=initializer, num_proj=num_proj, proj_clip=proj_clip,
                 num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards,
                 forget_bias=forget_bias, state_is_tuple=state_is_tuple,
                 activation=activation, reuse=reuse, name=name)

        self.var_given = var_given
        if self.var_given:
            self._kernel = kernel
            self._bias = bias


    def build(self, inputs_shape):
        if inputs_shape[1].value is None:
            raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                             % inputs_shape)

        input_depth = inputs_shape[1].value
        h_depth = self._num_units if self._num_proj is None else self._num_proj
        maybe_partitioner = (
            partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
            if self._num_unit_shards is not None
            else None)
        if self.var_given:
            # self._kernel and self._bais are already added in init
            pass
        else:
            self._kernel = self.add_variable(
                _WEIGHTS_VARIABLE_NAME,
                shape=[input_depth + h_depth, 4 * self._num_units],
                initializer=self._initializer,
                partitioner=maybe_partitioner)
            self._bias = self.add_variable(
                _BIAS_VARIABLE_NAME,
                shape=[4 * self._num_units],
                initializer=init_ops.zeros_initializer(dtype=self.dtype))
        if self._use_peepholes:
            self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
                                               initializer=self._initializer)

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None)
            self._proj_kernel = self.add_variable(
                "projection/%s" % _WEIGHTS_VARIABLE_NAME,
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner)

        self.built = True
因此,代码如下所示:

kernel = get_variable(...)
bias = get_variable(...)
lstm_fw = MyLSTMCell(....., var_given=True, kernel=kernel, bias=bias)

非常感谢。我得到的
名称“init_ops”未定义
错误。因此,在我的
MyLSTMCell
中,我添加了来自tensorflow.python.ops import init_ops的