Warning: file_get_contents(/data/phpspider/zhask/data//catemap/8/logging/2.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
批量内的Tensorflow乘法广播_Tensorflow - Fatal编程技术网

批量内的Tensorflow乘法广播

批量内的Tensorflow乘法广播,tensorflow,Tensorflow,我们知道tf.multiply可以这样广播: import tensorflow as tf import numpy as np a = tf.Variable(np.arange(12).reshape(3, 4)) b = tf.Variable(np.arange(4)) sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) sess.run(tf.multiply(a, b)) 这将给我们

我们知道tf.multiply可以这样广播:

import tensorflow as tf
import numpy as np
a = tf.Variable(np.arange(12).reshape(3, 4))
b = tf.Variable(np.arange(4))
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(tf.multiply(a, b))
这将给我们

[[0, 1, 4, 9],
 [0, 5, 12, 21],
 [0, 9, 20, 33]]
但我的问题是,如果
a
b
都是成批的,我该怎么办?就是

a = tf.Variable(np.arange(24).reshape(2, 3, 4))
b = tf.Variable(np.arange(8).reshape(2, 4))
那么,我怎样才能得到每批向量乘以(广播)矩阵的结果呢?我喜欢下面的答案:

[[[0, 1, 4, 9],
  [0, 5, 12, 21],
  [0, 9, 20, 33]],

 [[48, 65, 84, 105],
  [64, 85, 108, 133],
  [80, 105, 132, 161]]]

谢谢

广播首先在左侧添加单例维度,直到秩匹配为止。在第一种情况下,添加批处理维度。但在第二种情况下,您已经有批处理维度,因此需要在第二个位置手动插入单例维度:

a = tf.reshape(tf.range(24), (2, 3, 4))
b = tf.reshape(tf.range(8), (2, 4))
sess.run(tf.mul(a, tf.expand_dims(b, 1)))