Tensorflow 使用K.tile()复制张量

Tensorflow 使用K.tile()复制张量,tensorflow,keras,copy,tensor,replicate,Tensorflow,Keras,Copy,Tensor,Replicate,我有张量(None,196),在重塑后,它变成(None,14,14)。 现在,我想把它复制到通道轴,这样形状应该是(无,14,14,512)。最后,我想复制到timestep axis,因此它变成(None,10,14,14512)。我使用以下代码片段完成这些步骤: def replicate(tensor, input_target): batch_size = K.shape(tensor)[0] nf, h, w, c = input_target x = K.reshape

我有张量
(None,196)
,在重塑后,它变成
(None,14,14)
。 现在,我想把它复制到通道轴,这样形状应该是
(无,14,14,512)
。最后,我想复制到timestep axis,因此它变成
(None,10,14,14512)
。我使用以下代码片段完成这些步骤:

def replicate(tensor, input_target):
  batch_size = K.shape(tensor)[0]
  nf, h, w, c = input_target
  x = K.reshape(tensor, [batch_size, 1, h, w, 1])

  # Replicate to channel dimension
  x = K.tile(x, [batch_size, 1, 1, 1, c])

  # Replicate to timesteps dimension
  x = K.tile(x, [batch_size, nf, 1, 1, 1])

  return x

x = ...
x = Lambda(replicate, arguments={'input_target':input_shape})(x)
another_x = Input(shape=input_shape) # shape (10, 14, 14, 512)

x = layers.multiply([x, another_x])
x = ...
我绘制模型,输出的形状就像我想要的那样。但是,问题出现在模型训练中。我将批量大小设置为2。这将显示错误消息:

tensorflow.python.framework.errors\u impl.InvalidArgumentError:不兼容的形状:[8,10,14,14512]与[2,10,14,14512]
[{node multiply_1/mul}}=mul[T=DT_FLOAT,[u class=[“loc:@training/Adam/gradients/multiply_1/mul_grad/Sum”],[u device=“/job:localhost/replica:0/task:0/device:GPU:0]”(Lambda_2/Tile_1,[u arg_另一个_x_0_0/]
[{node metrics/top\u k\u categorical\u accurity/Mean\u 1/\u 265}}=\u Recv[client\u terminated=false,Recv\u device=“/job:localhost/replica:0/task:0/device:CPU:0”,send\u device=“/job:localhost/replica:0/task:0/device:GPU:0”,send\u device\u device\u化身=1,tensor\u name=“edge\u 6346\u metrics/top\u k\u categority/Mean\u 1”,tensor\u type=DT\u设备,浮点=“/job:localhost/replica:0/task:0/device:CPU:0”]()]]

看起来,
K.tile()
将批大小从2增加到8。当我将批大小设置为10时,它将变为1000

所以,我的问题是如何实现我想要的结果?使用
tile()
?还是应该使用
repeat\u elements()
?谢谢


我使用的是Tensorflow 1.12.0和Keras 2.2.4。

作为经验法则,尽量避免在
Lambda
层中发生的转换带来批量大小

当您使用
tile
操作时,您只设置了only需要更改的维度(例如,您在tile操作中有
batch\u size
值,这是错误的)。此外,我使用的是
tf.tile
而不是
K.tile
(tf 1.12似乎在Keras后端没有tile)

简单例子 给

>>> (?, 10, 14, 14, 512)

根据经验,尽量避免在
Lambda
层中发生的转换带来批量大小

当您使用
tile
操作时,您只设置了only需要更改的维度(例如,您在tile操作中有
batch\u size
值,这是错误的)。此外,我使用的是
tf.tile
而不是
K.tile
(tf 1.12似乎在Keras后端没有tile)

简单例子 给

>>> (?, 10, 14, 14, 512)

它能工作,谢谢!实际上我正在使用
K.tile
。顺便问一下,哪一个可以复制我的例子的张量:
tile()
repeat\u elements()
?我不认为这对你的情况有多大关系,因为你从1->某个数字
n
,两者都会产生相同的结果。但是如果你从m->n开始,你会得到不同的输出,这取决于你使用的函数。它可以工作,谢谢!实际上我使用的是
K.tile
。顺便说一句,是哪个对于我的例子,复制张量是很好的:
tile()
repeat\u elements()
?我不认为这对你的情况有多大关系,因为你从1->某个数字
n
,两者都会产生相同的结果。但是如果你从m->n开始,那么根据你使用的函数,你会得到不同的输出。
>>> (?, 10, 14, 14, 512)