Tensorflow 在网络中使用张量作为索引
我的网络有多个输入,其中一个输入是一个索引,在网络中用于索引到其他张量 我在使用张量作为索引时遇到了问题Tensorflow 在网络中使用张量作为索引,tensorflow,machine-learning,keras,Tensorflow,Machine Learning,Keras,我的网络有多个输入,其中一个输入是一个索引,在网络中用于索引到其他张量 我在使用张量作为索引时遇到了问题 class MemoryLayer(tf.keras.layers.Layer): def __init__(self, memory_size, k, **kwargs): super().__init__(kwargs) self.memory_size = memory_size self.k = k def build(self,input_shap
class MemoryLayer(tf.keras.layers.Layer):
def __init__(self, memory_size, k, **kwargs):
super().__init__(kwargs)
self.memory_size = memory_size
self.k = k
def build(self,input_shape):
# Set up the memory_var
# Shape of input is [(1,3,6), (1,3)]
def call(self, input):
for i in range(3):
statement = input[0][0,i]
cluster = input[1][0,i]
old_sub_mem = self.memory_var[cluster, :-1] #Error here
# Here should be a bunch of stuff I removed because its not relevant
return tf.expand_dims(self.memory_var, axis=0)
我得到一个
TypeError
,说
不是有效的索引。我尝试在input[1]
上调用.numpy()
,但这不起作用,因为张量没有形状。从数据中输入的集群应该是一个数字。默认情况下,层的输入是tf.float32
。然而,要索引张量,需要整数。您可以将层的输入强制转换为整数,也可以指定该层的输入应为整数类型
铸造
指定输入的类型
我在该示例中使用了函数API:
inp_statement = tf.keras.Input(shape=(3,6))
inp_cluster = tf.keras.Input(shape=(3,), dtype=tf.int32)
memory = MemoryLayer(memory_size=10)([inp1,inp2])
注意:我不太明白您想要实现什么,但是调用tf.gather
或tf.gather\n
可能会优化这个for循环
inp_statement = tf.keras.Input(shape=(3,6))
inp_cluster = tf.keras.Input(shape=(3,), dtype=tf.int32)
memory = MemoryLayer(memory_size=10)([inp1,inp2])