Python tensorflow:沿第二维度切片张量

Python tensorflow:沿第二维度切片张量,python,tensorflow,Python,Tensorflow,我有一个张量X,它的形状是(无,56300,1),还有一个张量y,它的形状是(无,15),这些张量的第一维是批量大小,我想用y作为索引得到一个张量z,z的形状是(无,15300,1)。有什么像样的方法可以做到这一点吗 我写了一个简单的代码来测试,因为我发现这对我来说很难,因为在实践中我不知道批量大小(这些张量的第一维是无的) 以下是我的测试代码: 将numpy导入为np 导入tensorflow作为tf #在本测试代码中,批次大小为4。 #参数的形状是(4,3,2,1),实际上是(无,56300

我有一个张量X,它的形状是(无,56300,1),还有一个张量y,它的形状是(无,15),这些张量的第一维是批量大小,我想用y作为索引得到一个张量z,z的形状是(无,15300,1)。有什么像样的方法可以做到这一点吗

我写了一个简单的代码来测试,因为我发现这对我来说很难,因为在实践中我不知道批量大小(这些张量的第一维是无的)

以下是我的测试代码:

将numpy导入为np
导入tensorflow作为tf
#在本测试代码中,批次大小为4。
#参数的形状是(4,3,2,1),实际上是(无,56300,1),
参数=[
[a0']、[b0']、[d0']、[e0']、[f0']、[g0']],
[a1']、[b1']、[d1']、[e1']、[f1']、[g1']],
[a2'],[b2'],[d2'],[e2'],[f2'],[g2'],
[a3']、[b3']、[d3']、[e3']、[f3']、[g3'],
]
#ind的形状是(4,2)(实际上是(无,15)),
#所以我想得到输出,它的形状是(4,2,2,1),(实际上是(无,15300,1))
ind=[[1,0],[0,2],[2,0],[2,1]]
#输出=[
#[d0'],[e0'],[a0'],[b0']],
#[a1']、[b1']、[f1']、[g1']],
#[f2'],[g2'],[a2'],[b2']],
#[['f3'],['g3'],[['d3'],['e3']]
#]
使用tf.variable_scope('gather')作为作用域:
tf_par=tf.常数(参数)
tf_ind=tf.常数(ind)
res=tf.聚集(tf\u par,tf\u ind)
使用tf.Session()作为sess:
init=tf.global_variables_initializer()
打印sess.run(res)
打印资源

对于您假设的结果,您应该使用:

ind = [[0, 1], [0, 0], [1, 0], [1, 2], [2, 2], [2, 0], [3, 2], [3, 1]]
更新

通过当前输入,您可以使用此代码获取所需内容:

with tf.variable_scope('gather') as scope:
    tf_par = tf.constant(params)
    tf_ind = tf.constant(ind)

    tf_par_shape = tf.shape(tf_par)
    tf_ind_shape = tf.shape(tf_ind)
    tf_r = tf.div(tf.range(0, tf_ind_shape[0] * tf_ind_shape[1]), tf_ind_shape[1])
    tf_r = tf.expand_dims(tf_r, 1)
    tf_ind = tf.expand_dims(tf.reshape(tf_ind, shape = [-1]), 1)
    tf_ind = tf.concat([tf_r, tf_ind], axis=1)

    res = tf.gather_nd(tf_par, tf_ind)
    res = tf.reshape(res, shape = (-1, tf_ind_shape[1], tf_par_shape[2], tf_par_shape[3]))

对于您假设的结果,您应该使用:

ind = [[0, 1], [0, 0], [1, 0], [1, 2], [2, 2], [2, 0], [3, 2], [3, 1]]
更新

通过当前输入,您可以使用此代码获取所需内容:

with tf.variable_scope('gather') as scope:
    tf_par = tf.constant(params)
    tf_ind = tf.constant(ind)

    tf_par_shape = tf.shape(tf_par)
    tf_ind_shape = tf.shape(tf_ind)
    tf_r = tf.div(tf.range(0, tf_ind_shape[0] * tf_ind_shape[1]), tf_ind_shape[1])
    tf_r = tf.expand_dims(tf_r, 1)
    tf_ind = tf.expand_dims(tf.reshape(tf_ind, shape = [-1]), 1)
    tf_ind = tf.concat([tf_r, tf_ind], axis=1)

    res = tf.gather_nd(tf_par, tf_ind)
    res = tf.reshape(res, shape = (-1, tf_ind_shape[1], tf_par_shape[2], tf_par_shape[3]))

使用
ind
沿第二维度切片
x
,即切片

  • 形状
    (d0,d1,d2,…)的张量
    x
    d0
    可能为
  • 使用形状
    (d0,n1)
    的索引
    ind
    张量
  • 要获得形状为(d0,n1,d2,…)的张量
    y
您可以使用
tf.gather\n和
tf.shape
在运行时获取形状:

ind_shape = tf.shape(ind)
ndind = tf.stack([tf.tile(tf.range(ind_shape[0])[:, None], [1, ind_shape[1]]),
                  ind], axis=-1)
y = tf.gather_nd(x, ndind)

使用
ind
沿第二维度切片
x
,即切片

  • 形状
    (d0,d1,d2,…)的张量
    x
    d0
    可能为
  • 使用形状
    (d0,n1)
    的索引
    ind
    张量
  • 要获得形状为(d0,n1,d2,…)的张量
    y
您可以使用
tf.gather\n和
tf.shape
在运行时获取形状:

ind_shape = tf.shape(ind)
ndind = tf.stack([tf.tile(tf.range(ind_shape[0])[:, None], [1, ind_shape[1]]),
                  ind], axis=-1)
y = tf.gather_nd(x, ndind)

你能在这里添加你对示例代码的期望结果吗?据我理解,
[1,0]
应该是
[[a1'],['b1']
,但是你的期望是
[[d0'],['e0']
@YuwenYan[[a0'],['b0'],['d0'],['e0'],['f0'],['g0']]是第一个示例。[1,0]对第一个样本的响应,因此答案应该是['d0']、['e0']、['a0']、['b0']]。注意ind的长度是4,每个元素对param中每个样本的响应是否可以将样本编码的预期结果添加到这里,据我理解,
[1,0]
应该是
[['a1'],['b1']
,但是您的预期是
['d0'],['e0']
@YuwenYan[['a0'],['b0'],['d0'],['f0'],['g0']这是第一个样本。[1,0]对第一个样本的响应,因此答案应该是['d0']、['e0']、['a0']、['b0']]。请注意,ind的长度是4,每个元素响应paramsys中的每个样本,我正在尝试这样做。但这样,我必须指定批次大小。是的,我正在尝试这样做。但是用这种方式,我必须指定批量大小。是否要评论否决票?我看到你的代表在弗拉基米尔身上跌了1分,那肯定不是你?是的,我可以解释我投的反对票。即使我将
X
更改为
tf\u par
并将
y
更改为
tf\u ind
,您的代码也不起作用。我认为最好是给出问题中的代码示例。抱歉。@Vladimir是的,由于OP第一段中对问题的描述,有一个硬编码的
15
,这确实与他后面的示例不一致。我用一个更通用的
tf.shape(y)[1]
替换了它,它现在应该在所有情况下都能工作了。要不要评论一下否决票?我看到你的代表在弗拉基米尔身上跌了1分,那肯定不是你?是的,我可以解释我投的反对票。即使我将
X
更改为
tf\u par
并将
y
更改为
tf\u ind
,您的代码也不起作用。我认为最好是给出问题中的代码示例。抱歉。@Vladimir是的,由于OP第一段中对问题的描述,有一个硬编码的
15
,这确实与他后面的示例不一致。我用一个更通用的
tf.shape(y)[1]
替换了它,它现在应该在所有情况下都能工作。