Python Tensorflow:在时间流中的索引处查找张量

Python Tensorflow:在时间流中的索引处查找张量,python,tensorflow,Python,Tensorflow,我正在培训一个RNN,其中我需要使用索引来查找示例时间流的另一部分中的值 v = tf.constant([ [[.1, .2], [.3, .4]], # timestream 1 values [[.6, .5], [.7, .8]] # timestream 2 values ]) ixs = tf.constant([ [1, 0], # indices into timestream 1 values [0, 1] # indices into

我正在培训一个RNN,其中我需要使用索引来查找示例时间流的另一部分中的值

v = tf.constant([
    [[.1, .2], [.3, .4]],  # timestream 1 values
    [[.6, .5], [.7, .8]]   # timestream 2 values
])
ixs = tf.constant([
    [1, 0], # indices into timestream 1 values
    [0, 1]  # indices into timestream 2 values
])
我正在寻找一个op,它将执行查找并用张量值替换索引,并产生:

[
    [[.3, .4], [.1, .2]],
    [[.6, .5], [.7, .8]]
]
tf.gather和tf.gather听起来可能是正确的路径,但我真的不明白我从中得到了什么结果

v_at_ix = tf.gather(v, ixs, axis=-1)
sess.run(v_at_ix)
array([[[[0.2, 0.1],
         [0.1, 0.2]],
        [[0.4, 0.3],
         [0.3, 0.4]]],
       [[[0.5, 0.6],
         [0.6, 0.5]],
        [[0.8, 0.7],
         [0.7, 0.8]]]], dtype=float32)

v_at_ix = tf.gather_nd(v, ixs)
sess.run(v_at_ix)
array([[0.6, 0.5],
       [0.3, 0.4]], dtype=float32)

有人知道这样做的正确方法吗?

tf.gather只能获取基于指定轴的切片,其索引是并置的。在
v_at_ix=tf.聚集(v,ixs,轴=-1)

[1,0]
中的
1
表示
v
中的
[2]、[4]、[5]、[8]]

[1,0]
中的
0
表示
v
中的
[1]、[3]、[6]、[7]]

[0,1]
中的
0
表示
v
中的
[1]、[3]、[6]、[7]]

[0,1]
中的
1
表示
v
中的
[2]、[4]、[5]、[8]]

tf.gather\u nd能够以指定的索引获取切片,并且其索引是渐进的。在
v_at_ix=tf.聚集(v,ixs)

[1,0]
中的
1
表示
[.6,5],.7,8]]
中的
v

[1,0]
中的
0
表示
[.6,5]
中的
[.6,5],.7,8]

[0,1]
中的
0
表示
[.1,2],.3,4]]
中的
v

[0,1]
中的
1
表示
[.1,2],.3,4]
中的
[.3,4]

因此,当我们使用
tf.gather nd
时,我们需要的是
[[0,1],[0,0],[[1,0],[1,1]]
。它可以由
[[0,0],[1,1]]
[[1,0],[0,1]]
组成。前者是重复的行号,后者是
ixs
。所以我们可以做到

ixs_row = tf.tile(tf.expand_dims(tf.range(v.shape[0]),-1),multiples=[1,v.shape[1]])
ixs = tf.concat([tf.expand_dims(ixs_row,-1),tf.expand_dims(ixs,-1)],axis=-1)
v_at_ix = tf.gather_nd(v,ixs)