Python 节点';合并/合并摘要';有来自不同框架的输入:这意味着什么?

Python 节点';合并/合并摘要';有来自不同框架的输入:这意味着什么?,python,tensorflow,deep-learning,summary,Python,Tensorflow,Deep Learning,Summary,尝试合并所有摘要时,我有一个错误,即merge/MergeSummary的输入来自不同的帧。首先,什么是框架?你能给我指一下TF文档中关于这些东西的地方吗?--当然,我在谷歌上搜索了一下,但几乎什么也没找到。如何解决此问题?下面的代码再现错误。提前谢谢 import numpy as np import tensorflow as tf tf.reset_default_graph() tf.set_random_seed(23) BATCH = 2 LENGTH = 4 SIZE = 5

尝试合并所有摘要时,我有一个错误,即
merge/MergeSummary
的输入来自不同的帧。首先,什么是框架?你能给我指一下TF文档中关于这些东西的地方吗?--当然,我在谷歌上搜索了一下,但几乎什么也没找到。如何解决此问题?下面的代码再现错误。提前谢谢

import numpy as np
import tensorflow as tf

tf.reset_default_graph()
tf.set_random_seed(23)

BATCH = 2
LENGTH = 4
SIZE = 5
ATT_SIZE = 3
NUM_QUERIES = 2

def linear(inputs, output_size, use_bias=True, activation_fn=None):
    """Linear projection."""

    input_shape = inputs.get_shape().as_list()
    input_size = input_shape[-1]
    output_shape = input_shape[:-1] + [output_size]
    if len(output_shape) > 2:
        output_shape_tensor = tf.unstack(tf.shape(inputs))
        output_shape_tensor[-1] = output_size
        output_shape_tensor = tf.stack(output_shape_tensor)
        inputs = tf.reshape(inputs, [-1, input_size])

    kernel = tf.get_variable("kernel", [input_size, output_size])
    output = tf.matmul(inputs, kernel)
    if use_bias:
        output = output + tf.get_variable('bias', [output_size])

    if len(output_shape) > 2:
        output = tf.reshape(output, output_shape_tensor)
        output.set_shape(output_shape)  # pylint: disable=I0011,E1101

    if activation_fn is not None:
        return activation_fn(output)
    return output


class Attention(object):
    """Attention mechanism implementation."""

    def __init__(self, attention_states, attention_size):
        """Initializes a new instance of the Attention class."""
        self._states = attention_states
        self._attention_size = attention_size
        self._batch = tf.shape(self._states)[0]
        self._length = tf.shape(self._states)[1]
        self._size = self._states.get_shape()[2].value
        self._features = None

    def _init_features(self):
        states = tf.reshape(
            self._states, [self._batch, self._length, 1, self._size])
        weights = tf.get_variable(
            "kernel", [1, 1, self._size, self._attention_size])
        self._features = tf.nn.conv2d(states, weights, [1, 1, 1, 1], "SAME")

    def get_weights(self, query, scope=None):
        """Reurns the attention weights for the given query."""
        with tf.variable_scope(scope or "Attention"):
            if self._features is None:
                self._init_features()
            else:
                tf.get_variable_scope().reuse_variables()
            vect = tf.get_variable("Vector", [self._attention_size])
            with tf.variable_scope("Query"):
                query_features = linear(query, self._attention_size, False)
                query_features = tf.reshape(
                    query_features, [-1, 1, 1, self._attention_size])

        activations = vect * tf.tanh(self._features + query_features)
        activations = tf.reduce_sum(activations, [2, 3])
        with tf.name_scope('summaries'):
            tf.summary.histogram('histogram', activations)
        return tf.nn.softmax(activations)

states = tf.placeholder(tf.float32, shape=[BATCH, None, SIZE])  # unknown length
queries = tf.placeholder(tf.float32, shape=[NUM_QUERIES, BATCH, ATT_SIZE])
attention = Attention(states, ATT_SIZE)
func = lambda x: attention.get_weights(x, "Softmax")
weights = tf.map_fn(func, queries)
for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    name = var.name.replace(':', '_')
    tf.summary.histogram(name, var)
summary_op = tf.summary.merge_all()

states_np = np.random.rand(BATCH, LENGTH, SIZE)
queries_np = np.random.rand(NUM_QUERIES, BATCH, ATT_SIZE)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    weights_np, summary_str = sess.run([weights, summary_op], {states: states_np, queries: queries_np})
    print weights_np

错误消息确实不是用户友好的。它已更新为

ValueError: Cannot use 'map/while/summaries/histogram' as input to 'Merge/MergeSummary' because 'map/while/summaries/histogram' is in a while loop. See info log for more details.
正如新消息所说,问题在于您无法从while循环内部生成摘要。原始消息引用的
是while循环的“执行帧”——while循环每次迭代的所有状态都保存在

在这种情况下,
while\u循环
tf.map\u fn
创建,其中的摘要是
tf.summary.histogram(“直方图”,激活)

有几种方法可以解决这个问题。您可以从
get\u weights
中提取摘要,让
get\u weights
返回激活,使用
tf.map\u fn
调用中新返回的激活创建摘要

另一种方法是,如果NUM_查询是常量且很小,则可以静态展开循环,而不是使用
tf.map_fn
。以下是执行此操作的代码:

# TOP PART OF THE CODE IS THE SAME

states = tf.placeholder(tf.float32, shape=[BATCH, None, SIZE])  # unknown length
queries = tf.placeholder(tf.float32, shape=[NUM_QUERIES, BATCH, ATT_SIZE])
attention = Attention(states, ATT_SIZE)
func = lambda x: attention.get_weights(x, "Softmax")

# NEW CODE BEGIN
split_queries = tf.split(queries, NUM_QUERIES)
weights = []
for query in split_queries:
    weights.append(func(query))
# NEW CODE END

for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    name = var.name.replace(':', '_')
    tf.summary.histogram(name, var)
summary_op = tf.summary.merge_all()

states_np = np.random.rand(BATCH, LENGTH, SIZE)
queries_np = np.random.rand(NUM_QUERIES, BATCH, ATT_SIZE)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # NEW CODE BEGIN
    results = sess.run(weights + [summary_op], {states: states_np, queries: queries_np})
    weights_np, summary_str = results[:-1], results[-1]
    # NEW CODE END
    print weights_np