Python Tensorflow:在时间流中的索引处查找张量
我正在培训一个RNN,其中我需要使用索引来查找示例时间流的另一部分中的值Python Tensorflow:在时间流中的索引处查找张量,python,tensorflow,Python,Tensorflow,我正在培训一个RNN,其中我需要使用索引来查找示例时间流的另一部分中的值 v = tf.constant([ [[.1, .2], [.3, .4]], # timestream 1 values [[.6, .5], [.7, .8]] # timestream 2 values ]) ixs = tf.constant([ [1, 0], # indices into timestream 1 values [0, 1] # indices into
v = tf.constant([
[[.1, .2], [.3, .4]], # timestream 1 values
[[.6, .5], [.7, .8]] # timestream 2 values
])
ixs = tf.constant([
[1, 0], # indices into timestream 1 values
[0, 1] # indices into timestream 2 values
])
我正在寻找一个op,它将执行查找并用张量值替换索引,并产生:
[
[[.3, .4], [.1, .2]],
[[.6, .5], [.7, .8]]
]
tf.gather和tf.gather听起来可能是正确的路径,但我真的不明白我从中得到了什么结果
v_at_ix = tf.gather(v, ixs, axis=-1)
sess.run(v_at_ix)
array([[[[0.2, 0.1],
[0.1, 0.2]],
[[0.4, 0.3],
[0.3, 0.4]]],
[[[0.5, 0.6],
[0.6, 0.5]],
[[0.8, 0.7],
[0.7, 0.8]]]], dtype=float32)
v_at_ix = tf.gather_nd(v, ixs)
sess.run(v_at_ix)
array([[0.6, 0.5],
[0.3, 0.4]], dtype=float32)
有人知道这样做的正确方法吗?tf.gather只能获取基于指定轴的切片,其索引是并置的。在
v_at_ix=tf.聚集(v,ixs,轴=-1)
:
[1,0]
中的1
表示v
中的[2]、[4]、[5]、[8]]
[1,0]
中的0
表示v
中的[1]、[3]、[6]、[7]]
[0,1]
中的0
表示v
中的[1]、[3]、[6]、[7]]
[0,1]
中的1
表示v
中的[2]、[4]、[5]、[8]]
tf.gather\u nd能够以指定的索引获取切片,并且其索引是渐进的。在v_at_ix=tf.聚集(v,ixs)
:
[1,0]
中的1
表示[.6,5],.7,8]]
中的v
[1,0]
中的0
表示[.6,5]
中的[.6,5],.7,8]
[0,1]
中的0
表示[.1,2],.3,4]]
中的v
[0,1]
中的1
表示[.1,2],.3,4]
中的[.3,4]
因此,当我们使用tf.gather nd
时,我们需要的是[[0,1],[0,0],[[1,0],[1,1]]
。它可以由[[0,0],[1,1]]
和[[1,0],[0,1]]
组成。前者是重复的行号,后者是ixs
。所以我们可以做到
ixs_row = tf.tile(tf.expand_dims(tf.range(v.shape[0]),-1),multiples=[1,v.shape[1]])
ixs = tf.concat([tf.expand_dims(ixs_row,-1),tf.expand_dims(ixs,-1)],axis=-1)
v_at_ix = tf.gather_nd(v,ixs)