Warning: file_get_contents(/data/phpspider/zhask/data//catemap/3/flash/4.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Matrix 如何在tensorflow中实现某种乘法_Matrix_Tensorflow - Fatal编程技术网

Matrix 如何在tensorflow中实现某种乘法

Matrix 如何在tensorflow中实现某种乘法,matrix,tensorflow,Matrix,Tensorflow,我有一个维度为[A,b,c,d]的张量A,另一个维度为[b,b,d,e]的张量b,和c,一个从0到b的整数列表。我需要生成维度为[a,b,c,e]的张量D,由 D[i,j,k,l]=A[i,C[i],k,m]*B[C[i],j,m,l]的m=0..D之和 b足够小(通常是3或5个?),所以我不介意在b单独的操作中这样做——但我不能浪费时间去做需要b^2内存或时间的事情,而这个操作在b中显然应该是线性的。这似乎是点态乘法(带广播?)和张量收缩(矩阵在公共m维度上相乘)的某种组合,但我无法确定 如果

我有一个维度为
[A,b,c,d]
的张量
A
,另一个维度为
[b,b,d,e]
的张量
b
,和
c
,一个从0到
b
的整数列表。我需要生成维度为[a,b,c,e]的张量
D
,由

D[i,j,k,l]=A[i,C[i],k,m]*B[C[i],j,m,l]的m=0..D之和

b
足够小(通常是3或5个?),所以我不介意在
b
单独的操作中这样做——但我不能浪费时间去做需要
b^2
内存或时间的事情,而这个操作在
b
中显然应该是线性的。这似乎是点态乘法(带广播?)和张量收缩(矩阵在公共
m
维度上相乘)的某种组合,但我无法确定

如果有人真的能让我相信这在使用tensorflow提供的操作的
O(b)
触发器中是不可能的,那么好吧,但我肯定想要
O(b^2)


更新:看起来经过适当修改的
A
张量可以使用
tf.gather\nd
单独构建;如果这可以与
B
以某种方式配对,也许?不幸的是,到目前为止,我在这方面的实验发现了
tf.gather和
本身的一个bug,它使事情变得缓慢。

我找到了合理有效地完成这一任务的方法。首先使用
tf.gather
构建一个修改版的
B
,在第一个索引中包含适当的部分:

B2 = tf.gather(B, C)
然后使用
tf.gather\n和
仅拉出
A
张量的相关部分。我们将拉出一组索引对,其形式为
[0,C[0]]、[1,C[1]]、[2,C[2]]…
等等,因此首先我们需要构建索引张量

a = tf.shape(A)[0]
A2_indices = tf.stack([tf.range(a), C], axis=0)
A2 = tf.gather_nd(A, A2_indices)
用形状
[a,c,d]
生产
A2
。现在我们需要适当地乘以
A2
B2
。它在
m
索引(分别为2和3)中是张量收缩,但在
i
索引中是逐点乘法(两者均为0)。这意味着,遗憾的是,结果项不是张量收缩或逐点乘法!一种选择是计算张量积并仅对
m
进行收缩,然后对两个
i
索引进行
tf.diag
——但这将浪费大量计算来构建我们不需要的矩阵的其余部分。相反,我们可以将其视为批处理矩阵乘法:它过去被称为
tf.batched_matmul
,但现在它只是
matmul
。不过,这有一个警告,即除了每个输入张量中的2个矩阵维数外,其余的都必须是逐点乘法
B
B2
不符合此标准,因为它们具有附加的
j
索引。但是,我们可以使用
l
output维度“包装它”,然后稍后删除它。这意味着首先调用
tf.transpose
j
l
放在一起,然后
tf.restrape
将其转换为一个
j*l
输出维度,然后执行
tf.matmul
,然后执行另一个
tf.restrape
tf.transpose
以返回原始形式。所以

a, b, d, e = B2.get_shape().as_list()
B2_trans = tf.transpose(B2, perm=[0,2,1,3])
B2_jl = tf.reshape(B2, [a,d,b*e])
product_jl = tf.matmul(A2, B2_jl)
product_trans = tf.reshape(product_jl, [a,d,b,e])
result = tf.transpose(product_trans, perm=[0,2,1,3])
就这样结束了!当然,在实践中,
B
可能仅在这一种情况下才需要,在这种情况下,
B
可能已经在“压缩”状态下启动,从而节省转置(和廉价的重塑);或者,如果
A2
将被展平或转置,那么它也可以保存转置。但总体而言,一切都非常简单。:)