Python 根据多组索引对二维张量的列求和

Python 根据多组索引对二维张量的列求和,python,tensorflow,Python,Tensorflow,在tensorflow中,我想根据多组索引对2D张量的列求和 例如: 对下列张量的列求和 [[1 2 3 4 5] [5 4 3 2 1]] 根据两组索引(第一组为0 1 2列之和,第二组为3 4列之和) 应该给出两列 [[6 9] [12 3]] 备注: 所有列的索引将显示在一组且仅一组索引中 这必须在Tensorflow中完成,以便梯度可以通过此操作 你知道怎么做那个手术吗?我怀疑我需要使用tf.slice,可能还需要使用tf.while\u loop。如果您不介意用NumPy解决

在tensorflow中,我想根据多组索引对2D张量的列求和

例如:

对下列张量的列求和

[[1 2 3 4 5]
 [5 4 3 2 1]]
根据两组索引(第一组为0 1 2列之和,第二组为3 4列之和)

应该给出两列

[[6  9]
 [12 3]]
备注:

  • 所有列的索引将显示在一组且仅一组索引中
  • 这必须在Tensorflow中完成,以便梯度可以通过此操作

  • 你知道怎么做那个手术吗?我怀疑我需要使用tf.slice,可能还需要使用tf.while\u loop。

    如果您不介意用NumPy解决这个问题,我知道在NumPy中有一种粗略的解决方法

    import numpy as np
    
    mat = np.array([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]])
    
    grid1 = np.ix_([0], [0, 1, 2])
    item1 = np.sum(mat[grid1])
    grid2 = np.ix_([1], [0, 1, 2])
    item2 = np.sum(mat[grid2])
    grid3 = np.ix_([0], [3, 4])
    item3 = np.sum(mat[grid3])
    grid4 = np.ix_([1], [3, 4])
    item4 = np.sum(mat[grid4])
    
    result = np.array([[item1, item3], [item2, item4]])
    
    您可以通过以下方式实现:

    输出:

    [[ 6  9]
     [12  3]]
    

    两组列索引是否可以具有公共元素?如果这是真的,那么每个列索引是否总是其中一个集合的一部分?当然,一个通用的解决方案会更好,但是一个受限的版本可能会更容易解决…实际上我真的需要在Tensorflow中解决这个问题。因为梯度需要通过这个操作。根据这个页面将NumPy数组转换成TensorFlow张量会有帮助吗:?我不这么认为。因为源元素(numpy数组)和结果之间的链接仍然是通过numpy而不是tensorflow完成的。这个“mat”(您使用的变量的名称)是使用我正在培训的tf.variable生成的。聚合的结果由tensorflow trainer使用。一切都需要在Tensorflow中发生。很好!非常感谢,这个tf.segment_sum是我一直在寻找的操作,但找不到。@Tom我已将其更改为类似但更通用的解决方案。为了使您的示例生效,我必须在最后一行之前添加:将tf.Session()作为sess:
    import tensorflow as tf
    
    nums = [[1, 2, 3, 4, 5],
            [5, 4, 3, 2, 1]]
    column_idx = [[0, 1, 2], [3, 4]]
    
    with tf.Session() as sess:
        # Data as TF tensor
        data = tf.constant(nums)
        # Make segment ids
        segments = tf.concat([tf.tile([i], [len(lst)]) for i, lst in enumerate(column_idx)], axis=0)
        # Select columns
        data_cols = tf.gather(tf.transpose(data), tf.concat(column_idx, axis=0))
        col_sum = tf.transpose(tf.segment_sum(data_cols, segments))
        print(sess.run(col_sum))
    
    [[ 6  9]
     [12  3]]