Python 将自定义损耗添加到自动编码器重建损耗
这是我在这里的第一篇文章,所以我希望它符合指导原则,对除我之外的其他人来说也是有趣的 我正在构建一个CNN自动编码器,它将固定大小的矩阵作为输入矩阵,目的是获得它们的低维表示(我在这里称它们为哈希)。当矩阵相似时,我想使这些散列相似。因为我的数据中只有一些是有标签的,所以我想把损失函数变成两个独立函数的组合。一部分是自动编码器的重建错误(该部分工作正常)。另一部分,我希望它用于标记的数据。因为我将有三个不同的类,所以我希望在每个批处理上都有这三个类,以计算属于同一个类的散列值之间的距离(我在实现这一点上遇到了困难) 我迄今为止的努力:Python 将自定义损耗添加到自动编码器重建损耗,python,tensorflow,autoencoder,Python,Tensorflow,Autoencoder,这是我在这里的第一篇文章,所以我希望它符合指导原则,对除我之外的其他人来说也是有趣的 我正在构建一个CNN自动编码器,它将固定大小的矩阵作为输入矩阵,目的是获得它们的低维表示(我在这里称它们为哈希)。当矩阵相似时,我想使这些散列相似。因为我的数据中只有一些是有标签的,所以我想把损失函数变成两个独立函数的组合。一部分是自动编码器的重建错误(该部分工作正常)。另一部分,我希望它用于标记的数据。因为我将有三个不同的类,所以我希望在每个批处理上都有这三个类,以计算属于同一个类的散列值之间的距离(我在实现
X = tf.placeholder(shape=[None, 512, 128, 1], dtype=tf.float32)
class1_indices = tf.placeholder(shape=[None], dtype=tf.int32)
class2_indices = tf.placeholder(shape=[None], dtype=tf.int32)
hashes, reconstructed_output = self.conv_net(X, weights, biases_enc, biases_dec, keep_prob)
class1_hashes = tf.gather(hashes, class1_indices)
class1_cost = self.calculate_within_class_loss(class1_hashes)
class2_hashes = tf.gather(hashes, class2_indices)
class2_cost = self.calculate_within_class_loss(class2_hashes)
loss_all = tf.reduce_sum(tf.square(reconstructed_output - X))
loss_labeled = class1_cost + class2_cost
loss_op = loss_all + loss_labeled
optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)
其中,类损失中的calclulate是我创建的一个单独的函数。我目前实现它只是因为一个类的第一个散列与同一批中该类的其他散列不同,但是,我对我当前的实现不满意,而且它看起来不起作用
def calculate_within_class_loss(self, hash_values):
first_hash = tf.slice(hash_values, [0, 0], [1, 256])
total_loss = tf.foldl(lambda d, e: d + tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(e, first_hash)))), hash_values, initializer=0.0)
return total_loss
因此,我有两个问题:
感谢您的时间和帮助:)在示例代码中,您正在计算点之间的欧氏距离之和 为此,您必须循环整个数据集,进行
O(n^2*m)
计算,并进行O(n^2*m)
空间,即Tensorflow图操作。
这里,n
是向量的数量,m
是散列的大小,即256
但是,如果可以将对象更改为以下内容:
然后,可以使用并重写与相同的计算
其中,mu_k
是集群的k
th坐标的平均值
这将允许您计算O(n*m)
time和O(n*m)
Tensorflow操作中的值
如果您认为这种变化(即从欧几里德距离到平方欧几里德距离)不会对损失函数产生不利影响,那么这将是一种方法。类的形状是什么?\u hash?对不起。我在这里只发布了我认为与问题相关的部分代码。由于散列有一个形状[None,256],类?\u散列也有形状[?,256],这取决于我在当前批处理中拥有该类的多少个索引。