Python 张量流元素相乘广播?

Python 张量流元素相乘广播?,python,tensorflow,Python,Tensorflow,tensorflow是否为最后一个维度上的元素相乘广播提供了任何功能 以下是我正在尝试做的和不起作用的示例: import tensorflow as tf x = tf.constant(5, shape=(1, 200, 175, 6), dtype=tf.float32) y = tf.constant(1, shape=(1, 200, 175), dtype=tf.float32) tf.math.multiply(x, y) 基本上,我希望对x沿最后一个维度的每一个切片,使用y进行

tensorflow是否为最后一个维度上的元素相乘广播提供了任何功能

以下是我正在尝试做的和不起作用的示例:

import tensorflow as tf
x = tf.constant(5, shape=(1, 200, 175, 6), dtype=tf.float32)
y = tf.constant(1, shape=(1, 200, 175), dtype=tf.float32)
tf.math.multiply(x, y)
基本上,我希望对
x
沿最后一个维度的每一个切片,使用
y
进行元素矩阵乘法

我发现这个问题询问类似的操作:

不幸的是,建议的方法(使用
tf.multiply()
)现在不再有效。相应的
tf.math.multiply
也不起作用,因为上面的代码给出了以下错误:

Traceback (most recent call last):
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1864, in _create_c_op
    c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 175 and 200 for 'Mul' (op: 'Mul') with input shapes: [1,200,175,6], [1,200,175].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py", line 322, in multiply
    return gen_math_ops.mul(x, y, name)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 6490, in mul
    "Mul", x=x, y=y, name=name)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2027, in __init__
    control_input_ops)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1867, in _create_c_op
    raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 175 and 200 for 'Mul' (op: 'Mul') with input shapes: [1,200,175,6], [1,200,175].
回溯(最近一次呼叫最后一次):
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py”,第1864行,在“创建”和“操作”中
c_op=c_api.TF_FinishOperation(操作说明)
tensorflow.python.framework.errors\u impl.InvalidArgumentError:对于输入形状为[1200175,6]、[1200175]的“Mul”(op:'Mul'),维度必须相等,但为175和200。
在处理上述异常期间,发生了另一个异常:
回溯(最近一次呼叫最后一次):
文件“”,第1行,在
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site packages/tensorflow/python/util/dispatch.py”,第180行,在包装器中
返回目标(*args,**kwargs)
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site packages/tensorflow/python/ops/math_ops.py”,第322行,倍增
返回gen_math_ops.mul(x,y,name)
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site packages/tensorflow/python/ops/gen_math_ops.py”,第6490行,mul格式
“Mul”,x=x,y=y,name=name)
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/op_-def_-library.py”,第788行,在“应用”op_-helper中
op_def=op_def)
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site packages/tensorflow/python/util/deprecation.py”,第507行,新函数
返回函数(*args,**kwargs)
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py”,第3616行,在create_-op中
op_def=op_def)
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site packages/tensorflow/python/framework/ops.py”,第2027行,在__
控制(输入操作)
文件“/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py”,第1867行,在“创建”和“操作”中
提升值错误(str(e))
ValueError:尺寸必须相等,但输入形状为[1200175,6]、[1200175]的“Mul”(op:'Mul')的尺寸为175和200。
我可以想出一种工作方法:将
y
复制6次,使其具有与
x
完全相同的形状,然后进行元素乘法


但是在tensorflow中有没有更快、更节省内存的方法呢?

这应该可以实现您想要的:

x = np.array([[[1,2,3],[4,5,6],[7,8,9],[10,11,12]]])
# [[[ 1  2  3]
#   [ 4  5  6]
#   [ 7  8  9]
#   [10 11 12]]]
y = np.array([[1,2,3,4]])
# [[1 2 3 4]]
y = tf.expand_dims(y, axis=-1)
mul = tf.multiply(x, y)
# [[[ 1  2  3]
#   [ 8 10 12]
#   [21 24 27]
#   [40 44 48]]]
最后,使用您需要的形状:

x = np.random.rand(1, 200, 175, 6)
y = np.random.rand(1, 200, 175)
y = tf.expand_dims(y, axis=-1)
mul = tf.multiply(x, y)
with tf.Session() as sess:
    print(sess.run(mul).shape)
    # (1, 200, 175, 6)
​

我懂了。所以我缺少这个
tf.expand_dims()
语句。