Tensorflow 2.0错误,自定义损失函数

Tensorflow 2.0错误,自定义损失函数,tensorflow,keras,deep-learning,tensorflow2.0,loss-function,Tensorflow,Keras,Deep Learning,Tensorflow2.0,Loss Function,我定义了一个损失函数,但当我运行model.fit()时,我遇到了一个无法解决的错误 这是我的损失函数: def l2_loss_eye_distance_normalised(y_actual, y_pred): #this loss function expects an exact, specific indexing order in the labels # the squared, ground truth, interocular distance eye

我定义了一个损失函数,但当我运行model.fit()时,我遇到了一个无法解决的错误

这是我的损失函数:

def l2_loss_eye_distance_normalised(y_actual, y_pred):
    #this loss function expects an exact, specific indexing order in the labels

    # the squared, ground truth, interocular distance
    eye_dst_vector = np.array([y_actual[0] - y_actual[2], y_actual[1] - y_actual[3]])
    denominator = kb.sum(kb.square(eye_dst_vector))

    #the squared distance between predition vector and ground truth vector
    dst_vec = y_actual - y_pred
    numerator = kb.sum(kb.square(dst_vec))

    return tf.cast(numerator / denominator, dtype="float32")
在main()的后面部分,我使用以下代码编译模型:

model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=l2_loss_eye_distance_normalised, #custom loss
        metrics=['mae', 'mse']
    )
调用model.fit()时,我得到一个错误:

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    verbose = 1
)
我认为我在使用自定义损失函数时犯了一个错误,但我不明白是什么错了。 有人能帮忙吗?:)


在您的方法中,您将继续执行while循环,在该循环中,您将迭代批处理维度。然而,这是非常低效和不必要的,并且更类似于用python而不是tensorflow进行思考。相反,您应该一次对每个批次进行每个计算。请记住,丢失输出是单个标量(这就是为什么会出现错误消息“expected non tensor”),因此最终必须对批进行求和

假设您的形状是
(批次,标签)=(无,4)
(如果您有额外的中间尺寸,例如,由于序列,只需将它们添加到下面的所有索引中),则您可以执行以下操作:

left = tf.math.squared_difference(y_actual[:,0], y_actual[:,2] # shape: (Batch,)
right = tf.math.squared_difference(y_actual[:,1], y_actual[:,3] # shape: (Batch,)
denominator = left + right # shape: (Batch,)
从这里开始,所有的东西都需要成型
(批处理)

现在,如何从每个批次的
loss\u中获得最终损失取决于您的设计,但简单地累积单个损失是正常的方式:

return tf.reduce_sum(loss_for_each_batch)

我不清楚你们的预期产量是多少。您的y_实际和标签的形状是什么?它们的含义是什么?还有:有没有像MNIST这样的标准数据集,我们可以用它来复制(如果有,请也添加您的模型)<代码>[y_-actual[0]-y_-actual[2],y_-actual[1]-y_-actual[3]
不做您认为它做的事情(您正在为批处理维度编制索引)添加独立代码以再现错误。有关详细信息,请参阅。@runDOSrun问题似乎正是您指出的问题。我正在索引批次,而不是第二维度。y_actual包含4个标签,表示图片中眼睛的x,y坐标[左眼,左眼,右眼,右眼]。我想做的是计算左眼和右眼之间距离的平方(用y_实际值)。为了做到这一点,我试图减去两个向量,并计算得到的向量的范数。现在的问题是---我如何索引特定的标签而不是批次?非常感谢@runDOSrun!此解决方案有效且超级精益。谢谢
return tf.reduce_sum(loss_for_each_batch)