Unexpected TensorFlow (1.12) broadcasting efficiency compared to NumPy


While implementing an FM-like model with latent vectors in TensorFlow, I ran into an unexpected broadcasting-efficiency problem with element-wise multiplication: multiplying the same arrays after different transpositions costs very different amounts of time.

With NumPy, however, the efficiency difference is not nearly as significant.

So, are there differences in the broadcasting rules between TensorFlow 1.12 and NumPy?

PS: tf 1.14 and tf 2 work fine. Does anyone know which major update fixed this issue?

Simple code:

[b, d, n, 1] * [n, k]     # ok
[b, n, d, 1] * [n, 1, k]  # slow
[b, n, 1, d] * [n, k, 1]  # very very slow
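For reference, under NumPy's broadcasting rules all three patterns are legal and produce the same number of output elements, just with the axes in a different order. A quick shape check (using `np.broadcast_shapes`, which needs NumPy ≥ 1.20):

```python
import numpy as np

b, n, d, k = 1024, 20, 32, 8

# Shapes are aligned from the right; size-1 axes are stretched.
print(np.broadcast_shapes((b, d, n, 1), (n, k)))     # -> (1024, 32, 20, 8)
print(np.broadcast_shapes((b, n, d, 1), (n, 1, k)))  # -> (1024, 20, 32, 8)
print(np.broadcast_shapes((b, n, 1, d), (n, k, 1)))  # -> (1024, 20, 8, 32)
```

So the work done is identical in size for all three; only the memory layout of the output differs, which is why the timing gap in tf 1.12 is surprising.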
Full code:

import tensorflow as tf
import numpy as np
from time import time
import timeit
sess = tf.InteractiveSession()

batch_size = 1024
k = 8
d = 32 # emb_size
n = 20 # slot count
input_var = np.random.randn(batch_size, n, d, 1)

v_nk = np.random.randn(n, k)
v_nkd = np.random.randn(n, k, d)
v_nk1 = np.reshape(v_nk, [n, k, 1])
v_n1k = np.reshape(v_nk, [n, 1, k])

input_var_bdn1 = np.transpose(input_var, [0, 2, 1, 3]).copy() # [b, d, n, 1]
input_var_bn1d = np.transpose(input_var, [0, 1, 3, 2]).copy() # [b, n, 1, d]
input_var_b1nd = np.transpose(input_var, [0, 3, 1, 2]).copy() # [b, 1, n, d]


# numpy
print('with numpy: ')

print("X_nk COST: ", timeit.timeit(lambda: input_var_bdn1 * v_nk, number=100))   # 3.1s
print("X_nk1 COST: ", timeit.timeit(lambda: input_var_bn1d * v_nk1, number=100)) # 2.5s
print("X_n1k COST: ", timeit.timeit(lambda: input_var * v_n1k, number=100))      # 3.0s
print("X_nkd COST: ", timeit.timeit(lambda: input_var_bn1d * v_nkd, number=100)) # 2.5s


input_var = tf.constant(input_var)
input_var_bdn1 = tf.constant(input_var_bdn1)
input_var_bn1d = tf.constant(input_var_bn1d)
input_var_b1nd = tf.constant(input_var_b1nd)

v_nk = tf.constant(v_nk)
v_nk1 = tf.constant(v_nk1)
v_n1k = tf.constant(v_n1k)
v_nkd = tf.constant(v_nkd)

input_X_nk = input_var_bdn1 * v_nk
input_X_n1k = input_var * v_n1k
input_X_nk1 = input_var_bn1d * v_nk1
input_X_nkd = input_var_bn1d * v_nkd

print()
print('with tf: ')

print("X_nk COST: ", timeit.timeit(lambda: sess.run(input_X_nk), number=100))   # 0.2s
print("X_nk1 COST: ", timeit.timeit(lambda: sess.run(input_X_nk1), number=100)) # 2.2s
print("X_n1k COST: ", timeit.timeit(lambda: sess.run(input_X_n1k), number=100)) # 0.6s
print("X_nkd COST: ", timeit.timeit(lambda: sess.run(input_X_nkd), number=100)) # 0.55s


for _ in range(10):
    input_X_nk += input_var_bdn1 * v_nk
    input_X_n1k += input_var * v_n1k
    input_X_nk1 += input_var_bn1d * v_nk1
    input_X_nkd += input_var_bn1d * v_nkd

print()
print('with tf directly (10 chained multiplies, single run): ')

print("X_nk COST: ", timeit.timeit(lambda: sess.run(input_X_nk), number=1))   # 0.8s
print("X_nk1 COST: ", timeit.timeit(lambda: sess.run(input_X_nk1), number=1)) # 6.1s
print("X_n1k COST: ", timeit.timeit(lambda: sess.run(input_X_n1k), number=1)) # 1.7s
print("X_nkd COST: ", timeit.timeit(lambda: sess.run(input_X_nkd), number=1)) # 1.6s

Test environment:
tensorflow 1.12.0 (tf 1.14 / tf 2 work fine)
python 3.6.9 & 2.7.18
CentOS 7.4 and macOS 10.14
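To rule out any difference in the values themselves, here is a small NumPy check (small sizes, chosen just for this check) confirming that the three layouts compute the same tensor up to an axis permutation, so only the memory layout differs:

```python
import numpy as np

b, n, d, k = 4, 3, 5, 2  # small sizes for a quick check
x = np.random.randn(b, n, d, 1)
v = np.random.randn(n, k)

r1 = np.transpose(x, [0, 2, 1, 3]) * v                   # [b, d, n, k], the "ok" layout
r2 = x * v.reshape(n, 1, k)                              # [b, n, d, k], the "slow" layout
r3 = np.transpose(x, [0, 1, 3, 2]) * v.reshape(n, k, 1)  # [b, n, k, d], the "very slow" layout

# All three are the same tensor with axes permuted.
assert np.allclose(np.transpose(r1, [0, 2, 1, 3]), r2)
assert np.allclose(np.transpose(r3, [0, 1, 3, 2]), r2)
```

Given this, a possible workaround on tf 1.12 (untested on my side, just a guess from the timings above) would be to always multiply in the fast `[b, d, n, 1] * [n, k]` layout and transpose the result afterwards.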