Python 批量大小未知的Keras重复元素

Python 批量大小未知的Keras重复元素,python,tensorflow,keras,backend,Python,Tensorflow,Keras,Backend,我有一个函数,我需要用尺寸为(?,61,80)的张量和尺寸为(40,61)的二维张量做Keras批点。维度?用于自定义层中的批次大小。在使用Kerasrepeat\u元素时,我们需要指定批大小,使其成为的张量(批大小,40,61)。但是,重复元素不适用于?批量大小 代码是 M1 = K.expand_dims(M,axis=0) BatchM = K.repeat_elements(x=M1,rep=batch_size,axis=0) out1 = K.batch_dot(BatchM,Ash

我有一个函数,我需要用尺寸为
(?,61,80)
的张量和尺寸为
(40,61)
的二维张量做Keras批点。维度
用于自定义层中的批次大小。在使用Keras
repeat\u元素时,我们需要指定批大小,使其成为
的张量(批大小,40,61)
。但是,
重复元素
不适用于
批量大小

代码是

M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=batch_size,axis=0)
out1 = K.batch_dot(BatchM,Ash1,axes=[2,1])
这里的
M
是大小为
(40,61)
的二维张量<代码>批次M
应给出
(批次尺寸,40,61)
Ash1
的尺寸
(?,61,80)

编辑1:

A= Input(shape=(61,80))
M= K.variable(np.random.rand(40,61))
n=1

import tensorflow as tf
M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=tf.shape(A)[0],axis=0)
out1 = K.batch_dot(BatchM,Ash1,axes=[2,1])
此返回错误显示:

Traceback (most recent call last)

 File "<ipython-input-7-edc5ef31181b>", line 3, in <module>
    BatchM = K.repeat_elements(x=M1,rep=tf.shape(A)[0],axis=0)

  File "/home/hanumant/.conda/envs/kerasenv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2092, in repeat_elements
    x_rep = [s for s in splits for _ in range(rep)]

  File "/home/hanumant/.conda/envs/kerasenv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2092, in <listcomp>
    x_rep = [s for s in splits for _ in range(rep)]

TypeError: 'Tensor' object cannot be interpreted as an integer
回溯(最近一次呼叫上次)
文件“”,第3行,在
BatchM=K.重复元素(x=M1,rep=tf.形状(A)[0],轴=0)
文件“/home/hanumat/.conda/envs/kerasenv/lib/python3.6/site packages/keras/backend/tensorflow\u backend.py”,第2092行,在repeat\u元素中
x_rep=[s代表s,在范围内(rep)]
文件“/home/hanumat/.conda/envs/kerasenv/lib/python3.6/site packages/keras/backend/tensorflow_backend.py”,第2092行,在
x_rep=[s代表s,在范围内(rep)]
TypeError:“Tensor”对象不能解释为整数

事实上,您不需要以未知的批量大小重复元素。出于相同目的,您可以直接使用
K.dot()
K.permute\u维度

def customer_dot(a,b):
    a = K.permute_dimensions(a, (0, 2, 1))  # x = (?,80,61)
    b = K.permute_dimensions(b, (1, 0))  # kernel = (61,40)
    ab_dot = K.permute_dimensions(K.dot(a, b), (0, 2, 1)) # ab_dot = (?,40,80)
    return ab_dot

A = Input(shape=(61,80))
M = K.variable(np.random.rand(40,61))

result = customer_dot(A,M)
print(result.shape)

# print
(?, 40, 80)
您可以使用以下示例来查看结果是否与代码操作的结果相同

# print
A = K.constant(np.random.rand(3,2,4))
M = K.constant(np.random.rand(5,2))

M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=K.int_shape(A)[0],axis=0)
out1 = K.batch_dot(BatchM,A,axes=[2,1])
print(K.eval(out1))
result = customer_dot(A,M)
print(K.eval(result))

[[[0.07588554 0.19896106 0.4122516  0.16694324]
  [0.02837059 0.07994501 0.15250334 0.05631477]
  [0.02922964 0.03180532 0.17185953 0.11346529]
  [0.24399586 0.64474815 1.3240533  0.53126353]
  [0.06582426 0.0952256  0.38014278 0.22963922]]

 [[0.05856805 0.31629622 0.37190455 0.15167782]
  [0.02006819 0.12145159 0.1384899  0.0497717 ]
  [0.03729554 0.09602766 0.14768752 0.11432388]
  [0.18666261 1.0198846  1.1952925  0.481425  ]
  [0.07623056 0.2298356  0.33025196 0.22802524]]

 [[0.29545793 0.27023914 0.14775626 0.22487558]
  [0.10839225 0.10083499 0.05140937 0.07595014]
  [0.13047284 0.10567644 0.08779343 0.15208915]
  [0.9481214  0.868726   0.47162086 0.7157058 ]
  [0.28504598 0.23714545 0.18145116 0.30803293]]]
[[[0.07588554 0.19896106 0.4122516  0.16694324]
  [0.02837059 0.07994501 0.15250334 0.05631477]
  [0.02922964 0.03180532 0.17185953 0.11346529]
  [0.24399586 0.64474815 1.3240533  0.53126353]
  [0.06582426 0.0952256  0.38014278 0.22963922]]

 [[0.05856805 0.31629622 0.37190455 0.15167782]
  [0.02006819 0.12145159 0.1384899  0.0497717 ]
  [0.03729554 0.09602766 0.14768752 0.11432388]
  [0.18666261 1.0198846  1.1952925  0.481425  ]
  [0.07623056 0.2298356  0.33025196 0.22802524]]

 [[0.29545793 0.27023914 0.14775626 0.22487558]
  [0.10839225 0.10083499 0.05140937 0.07595014]
  [0.13047284 0.10567644 0.08779343 0.15208915]
  [0.9481214  0.868726   0.47162086 0.7157058 ]
  [0.28504598 0.23714545 0.18145116 0.30803293]]]

批量大小到底是多少???如果你输入整数,问题是什么???。。。。。整数为批次大小实际上,此操作需要在自定义keras层内执行。如果我有大小为61x80的图像,keras层会自动添加一个批次维度,从而向层(?,61,80)进行输入。使用model.fit时将给出批处理大小。从fit的初始输入形状自动添加,对吗?哪个是存储在可变批次大小中的整数值?是。该程序在model.fit之前甚至无法运行。只有当它试图获得批处理形状时,才会在中途显示错误。请给出一个代码,我可以在其中重现问题,只是给出与您相同错误的最小部分