Python 如何使用Keras将输出wrt和输入的梯度合并到损失函数中?

Python 如何使用Keras将输出wrt和输入的梯度合并到损失函数中?,python,tensorflow,keras,Python,Tensorflow,Keras,我已经设法将梯度加入到损失函数中,但问题是我似乎无法选择我需要的特定梯度。该模型有多个输入和一个输出,我目前不确定使用哪一个。该模型定义如下: input_shape = (5,) input_tensor = Input(shape = input_shape) l1 = Dense(400, activation = 'relu')(input_tensor) l2 = Dense(400, activation = 'relu')(l1) l3 = Dense(400, activatio

我已经设法将梯度加入到损失函数中,但问题是我似乎无法选择我需要的特定梯度。该模型有多个输入和一个输出,我目前不确定使用哪一个。该模型定义如下:

input_shape = (5,)
input_tensor = Input(shape = input_shape)
l1 = Dense(400, activation = 'relu')(input_tensor)
l2 = Dense(400, activation = 'relu')(l1)
l3 = Dense(400, activation = 'relu')(l2)
l4 = Dense(400, activation = 'relu')(l3)
output_tensor = Dense(1, activation = 'relu')(l4)
model = Model(input_tensor, output_tensor)


def custom_loss(input_tensor, output_tensor):
    def newloss(y_true, y_pred):
        mse = K.mean(K.square(y_true - y_pred))
        gradients = K.gradients(output_tensor, input_tensor)[0][:,1]
        return mse + K.maximum(-1*gradients, 0)
    return newloss


sgd = keras.optimizers.Adam(lr = 0.001, beta_1 = 0.9, beta_2 = 0.999)
model.compile(loss = custom_loss(input_tensor, output_tensor),
              optimizer = 'sgd',
              metrics = ['mae'])

epochs = 10
batch_size = 100
# Fit the model weights.
history = model.fit(x_train_bs, y_train_bs,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test_bs, y_test_bs))

根据我的理解,K.gradients函数应该产生5个导数(由于5个输入),但我无法对其进行索引以选择我特别需要的导数。我对Keras没有经验,因此任何帮助/直觉都将不胜感激

只选择你想要的输入,而不是使用渐变中的所有输入。是的,这就是我试图做的,但我不确定应该在哪个阶段做。如果我试图在custom_loss参数的第一阶段选择输入,它会给我一个错误。我目前尝试的是gradients变量,末尾有[0][:,1]索引,但这并没有像我希望的那样工作。错误消息到底是什么?当你做
K.梯度(输出张量,输入张量[0])
时会发生什么?(这就是@DanielMöller建议的)这就是我从索引中得到的:
InvalidArgumentError:compatible shapes:[100]vs[0][[node gradients/loss/densite_5_loss/newloss/weighted_loss/mul_grade/broadcastinggradientargs(定义于/opt/anaconda3/envs/geronimo_test/lib/python3.7/site packages/tensorflow_core/python/framework/ops.py:1751)]][Op:\uuuu推断\uKeras\uScratch\uGraph\u785]
当我运行您发布的代码时,我没有收到任何错误(我的keras版本是2.3.1)。你到底需要哪个梯度<代码>K。梯度(输出张量,输入张量)[0][:,i]应与输出相对于第i个输入的梯度相对应。