Tensorflow 使用tf.gather\n沿轴获取所有可能的排列

Tensorflow 使用tf.gather\n沿轴获取所有可能的排列,tensorflow,Tensorflow,我试图从张量中提取所有可能的排列,沿着特定的轴。我的输入是a[B,S,L]张量(B批长度为L的S向量),我想提取这些向量(S!置换)中所有可能的置换,即a[B,S!,S,L]张量作为输出。 这就是我现在尝试的,但我正在努力获得正确的输出形状。我想我的错误可能是我正在创建一个批处理范围,但我也应该创建一个排列范围 import tensorflow as tf import numpy as np from itertools import permutations S = 3 B = 5 L

我试图从张量中提取所有可能的排列,沿着特定的轴。我的输入是a
[B,S,L]
张量(B批长度为L的S向量),我想提取这些向量(S!置换)中所有可能的置换,即a
[B,S!,S,L]
张量作为输出。 这就是我现在尝试的,但我正在努力获得正确的输出形状。我想我的错误可能是我正在创建一个批处理范围,但我也应该创建一个排列范围

import tensorflow as tf
import numpy as np
from itertools import permutations

S = 3
B = 5
L = 10

input = tf.constant(np.random.randn(B, S, L))

perms = list(permutations(range(S))) # ex with 3: [0, 1, 2], [0, 2 ,1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])

batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
indicies = tf.concat([batch_range, perms], axis=3)

permutations = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indicies) # 
# I get a [ B, P, S, S, L] instead of the desired [B, P, S, L]

我在下面贴了一个可能的“解决方案”,但我认为这个仍然存在问题。我对它进行了测试,如果B>1不太好。

我刚刚找到了一个答案,我想,如果你认为我错了,或者有更简单的方法,请纠正我:

import tensorflow as tf
import numpy as np
from itertools import permutations

S = 3
B = 5
L = 10

input = tf.constant(np.random.randn(B, S, L))

perms = list(permutations(range(S))) # ex with 3: [0, 1, 2], [0, 2 ,1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])

batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
perm_range = tf.tile(tf.reshape(tf.range(length_perm, dtype=tf.int32), shape=[1, length_perm, 1, 1]), [B, 1, S, 1])
indicies = tf.concat([batch_range, perm_range, perms], axis=3)

permutations = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indicies) # 
print permutations