Python Tensorflow表格查找int->;浮动

Python Tensorflow表格查找int->;浮动,python,tensorflow,Python,Tensorflow,给定一个未知维度的二维张量[?,?],其中包含整数(表示类),我想获得一个相同形状的新张量,但其值替换为从查找表中获取的浮点(表示类权重) 例如: inputs = [ [1,3,3], [2,4,2] ] lookup table: {1: 0.2, 2: 0.25, 3: 0.1, 4: 0.45} output: [ [0.2, 0.1, 0.1], [0.25, 0.45, 0.25] ] 我尝试用tf.map_fn链接两个lambda函数,迭代每一行,然后迭代每个元素: elem_i

给定一个未知维度的二维张量[?,?],其中包含整数(表示类),我想获得一个相同形状的新张量,但其值替换为从查找表中获取的浮点(表示类权重)

例如:

inputs = [ [1,3,3], [2,4,2] ]
lookup table: {1: 0.2, 2: 0.25, 3: 0.1, 4: 0.45}
output: [ [0.2, 0.1, 0.1], [0.25, 0.45, 0.25] ]
我尝试用tf.map_fn链接两个lambda函数,迭代每一行,然后迭代每个元素:

elem_iter = lambda y: unknown_lookup_function(y)
row_iter = lambda x: elem_iter(x)
weights = tf.map_fn(row_iter, inputs, dtype=tf.float32)
但找不到定义查找函数的正确方法。 关于如何实施这种行为有什么建议吗?是否有一个本地op可以代替map\fn使用?

我想您应该使用:

其思想是将查找表存储为数组。在
i
索引处,存储输入
i
的查找值。如果键不是整数而是字符串,则需要使用

# Note I pad a dummpy element at index-0.
lookup_table = tf.constant([0, 0.2, 0.25, 0.1, 0.45])

inputs = tf.constant([ [1,3,3], [2,4,2] ])
output = tf.gather(lookup_table, inputs)
with tf.Session() as sess:
  print sess.run(output)

> 
  [[ 0.2         0.1         0.1       ]
   [ 0.25        0.44999999  0.25      ]]