Python Tensorflow,从具有特定令牌id的输入中选择嵌入,并批处理结果

Python Tensorflow,从具有特定令牌id的输入中选择嵌入,并批处理结果,python,tensorflow,neural-network,Python,Tensorflow,Neural Network,我有一个带形状的输入张量IDinput\uids,[B x T]和一个带形状[B x T x D](B:批量大小,T:序列长度,D:维度)。输入ID是词汇表ID,嵌入矩阵包含对应的嵌入 从嵌入矩阵中,我想选择具有特定ID的元素(例如,103)。使用tf.where和tf.gather\u nd很容易做到这一点,但我不知道怎么做,就是将结果组织成一批大小[B x N x D],其中N是序列中id(103)的最大令牌数。我想根据需要使用0张量作为填充 代码可能会更好地显示它(比如说B=2、T=8和D

我有一个带形状的输入张量ID
input\uids
[B x T]
和一个带形状
[B x T x D]
(B:批量大小,T:序列长度,D:维度)
。输入ID是词汇表ID,嵌入矩阵包含对应的嵌入

从嵌入矩阵中,我想选择具有特定ID的元素(例如,
103
)。使用
tf.where
tf.gather\u nd
很容易做到这一点,但我不知道怎么做,就是将结果组织成一批大小
[B x N x D]
,其中
N
是序列中id(
103
)的最大令牌数。我想根据需要使用0张量作为填充

代码可能会更好地显示它(比如说
B=2、T=8和D=3
):

我想从
嵌入
中选择那些与
input\u id==103
相对应的嵌入,并用零填充剩余的结果。 我可以通过以下方式获得:

indices=  tf.where(tf.equal(input_ids, 103))
result = tf.gather_nd(indices=indices, params=embeddings)
#result.shape==[4x3]

# This will result in a [4x3] matrix where 4 = total number of 103 elements in the batch 
# and 3 is their corresponding embeddings dimension
# Now I want to organize this into a batch of the 
# same batch size as input, i.e., desired shape=(2x3)
# where first (1x3) row contains all token `103`'s embeddings
# in the first sequence but but second (1x3) row has only 
# one token 103 embedding (second sequence has only one 103 token)
# the rest are padded with zeros.
通常,这将导致一个
[M x D]
张量(M=批次中103个令牌的总数)。我想要的是
[B x N x D]
其中(N=每个序列中103个令牌的最大数量,对于上述情况,它是3)。我希望描述清楚(有点难以解释确切的问题)


如何实现这一点?

我认为它可以利用
tf.gather\u nd
在参数
索引为负值时返回
0
的属性

首先在
嵌入
中获取某些ID的索引值

将tensorflow导入为tf
tf.enable_eager_execution()
输入=tf.常数([[1019961630106199642231997],
[  101,  103,  3793,  103,  2443,  2000,  103,  2469]])
嵌入=tf.随机_正态((2,8,3))
条件=tf.相等(输入标识,103)
指数_值=tf.其中(条件)
# [[0 3]
#  [1 1]
#  [1 3]
#  [1 6]]
然后我们应该得到每个序列的令牌数和索引值的掩码

length = tf.reduce_sum(tf.cast(condition,tf.int32),axis=-1)
# [1 3]
indices_mask = tf.sequence_mask(length,tf.reduce_max(length))
# [[ True False False]
#  [ True  True  True]]
接下来,我们需要指定索引值在每个序列中的位置

result_indices = tf.scatter_nd(tf.where(indices_mask),
                               indices_value+1,
                               (indices_mask.shape[0],indices_mask.shape[1],tf.rank(input_ids)))-1
# [[[ 0  3]
#   [-1 -1]
#   [-1 -1]]
#
#  [[ 1  1]
#   [ 1  3]
#   [ 1  6]]]
最后我们通过
tf.gather\n
得到结果

result = tf.gather_nd(indices=result_indices, params=embeddings)
print(result)
# [[[ 1.22885     0.77642244 -0.82193506]
#   [ 0.          0.          0.        ]
#   [ 0.          0.          0.        ]]
# 
#  [[-0.0567691   0.07378497 -0.4799046 ]
#   [-1.1627238  -1.994217    0.8443906 ]
#   [ 0.776338   -0.25828102 -1.7915782 ]]]
result = tf.gather_nd(indices=result_indices, params=embeddings)
print(result)
# [[[ 1.22885     0.77642244 -0.82193506]
#   [ 0.          0.          0.        ]
#   [ 0.          0.          0.        ]]
# 
#  [[-0.0567691   0.07378497 -0.4799046 ]
#   [-1.1627238  -1.994217    0.8443906 ]
#   [ 0.776338   -0.25828102 -1.7915782 ]]]