Python 如何使用索引列表对张量进行切片并合成新的张量

Python 如何使用索引列表对张量进行切片并合成新的张量,python,tensorflow,Python,Tensorflow,如何利用张量流循环对张量进行切片并合成新的张量 只是指: text_embeding =tf.constant( #index 0 index 1 index 2 [[[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]], [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]],

如何利用张量流循环对张量进行切片并合成新的张量 只是指:

    text_embeding =tf.constant(
                       #index 0       index 1      index 2
                   [[[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]], 
                    [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]],
                    [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]]
                   ] 
               )
我想让批处理中的每个张量根据索引的组合得到一个新的值列表 索引列表=[[0,0],[1,1],[2,2],[0,1],[1,2],[0,2]]

我想要得到价值 '''

''' 我的代码是这样的,但批处理\u大小=输出\u层\u序列。在会话图准备就绪之前,形状[0]是无的,这是错误

vsp = tf.batch_gather(output_layer_sequence, tf.tile([[j, j + i]],multiples=[output_layer_sequence.shape[0],1]))  # batch * 2 * hidden_size
谢谢

使用
tf.gather()


你是对的;我只是使用output=tf.gather(文本嵌入,索引列表,axis=1)然后我实现了,非常感谢!
vsp = tf.batch_gather(output_layer_sequence, tf.tile([[j, j + i]],multiples=[output_layer_sequence.shape[0],1]))  # batch * 2 * hidden_size
for i in range(2):
    for j in range(2):
        vsp = tf.batch_gather(output_layer_sequence, tf.tile([[j, j + i]],multiples=[16,1]))  # batch * 2 * hidden_size
        # vsp = tf.batch_gather(output_layer_sequence, tf.tile([[j, j + i]],multiples=[output_layer_sequence.shape[0],1]))  # batch * 2 * hidden_size
        vsp_start, vsp_end = tf.split(vsp, 2, 1)  # batch * 1 * hiddensize
        vsp_start = tf.squeeze(vsp_start)  # batch  * hiddensize
             
        vsp_end = tf.squeeze(vsp_end)  # batch * hiddensize
        vsp = tf.concat([vsp_start, vsp_end], axis=-1, name='concat')  # [batch ,2*hiddensize]

        span_logits = tf.matmul(vsp, output_span_weight, transpose_b=True)  # output:[batch,class_labels]

        span_logits = tf.nn.bias_add(span_logits, output_span_bias)  # [batch,class_labels]
        span_logit_sum.append(span_logits)

text_embeding =tf.constant(
                    #index 0       index 1      index 2
                [[[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]], 
                [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]],
                [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.1,0.2,0.3]]
                ] 
            )
index_list = tf.constant([[0,0],[1,1],[2,2],[0,1],[1,2],[0,2]])
output = tf.gather(text_embeding, index_list, axis=2)