Python Tensorflow-断言失败:[预测必须为>;=0][条件x>;=y未保持元素状态:]

Python Tensorflow-断言失败:[预测必须为>;=0][条件x>;=y未保持元素状态:],python,tensorflow,Python,Tensorflow,我正在使用Tensorflow 2.5培训一个模型,以完成多标签任务。 在大多数情况下,在epoch=5-10时验证模型时会出现一些问题。 我不知道这意味着什么,我真的需要你的帮助 def get_metrics(): NUM_CLASSES = config.data.num_classes metrics = [tfa.metrics.F1Score(num_classes=NUM_CLASSES, average='micro', threshold=0.5, name=&

我正在使用Tensorflow 2.5培训一个模型,以完成多标签任务。 在大多数情况下,在epoch=5-10时验证模型时会出现一些问题。 我不知道这意味着什么,我真的需要你的帮助

def get_metrics():
    NUM_CLASSES = config.data.num_classes
    metrics = [tfa.metrics.F1Score(num_classes=NUM_CLASSES, average='micro', threshold=0.5, name="f1_micro"),
               tfa.metrics.F1Score(num_classes=NUM_CLASSES, average='macro', threshold=0.5, name="f1_macro"),
               tfa.metrics.F1Score(num_classes=NUM_CLASSES, average='weighted', threshold=0.5, name="f1_weighted"),
               tfa.metrics.HammingLoss(mode='multilabel', threshold=0.5),
               tf.keras.metrics.RecallAtPrecision(0.8),
               tf.keras.metrics.PrecisionAtRecall(0.8),
               tf.keras.metrics.AUC(multi_label=True, curve='PR', name='auc_pr')]
    metrics += [IndividualAUC(multi_label=False, curve='PR',class_idx=i, name='auc_pr_{}'.format(i)) for i in range(NUM_CLASSES)]
    return metrics    

def loss_for_missing_label(y_true,y_pred):
    mask = tf.cast(y_true>=0.0,'float32')
    loss_func = tfa.losses.SigmoidFocalCrossEntropy(gamma=1.5, alpha=0.25)
    return loss_func(y_true*mask, y_pred*mask)

model.compile(
    optimizer=optimizer,
    loss=loss_for_missing_label,
    metrics=get_metrics(),
)
model.fit(
      get_dataset(training=True),
      epochs=config.train.epochs,
      steps_per_epoch=steps_per_epoch,
      validation_data=get_dataset(training=False),
      validation_steps=num_eval_images // config.eval.batch_size,
      callbacks=[ckpt_callback, tb_callback, rstr_callback],
      verbose=2 if strategy == 'tpu' else 1,
  )

我得到这样的错误

tensorflow.python.framework.errors_impl.InvalidArgumentError: 3 root error(s) found.
  (0) Invalid argument:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (mul_5:0) = ] [[0.182861328 0.032409668 0.0966796875...]...] [y (Cast_11/x:0) = ] [0]
         [[{{node assert_greater_equal/Assert/AssertGuard/else/_5/assert_greater_equal/Assert/AssertGuard/Assert}}]]
         [[replica_1/assert_greater_equal_16/Assert/AssertGuard/else/_1338/replica_1/assert_greater_equal_16/Assert/AssertGuard/Assert/data_2/_4928]]
  (1) Invalid argument:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (mul_5:0) = ] [[0.182861328 0.032409668 0.0966796875...]...] [y (Cast_11/x:0) = ] [0]
         [[{{node assert_greater_equal/Assert/AssertGuard/else/_5/assert_greater_equal/Assert/AssertGuard/Assert}}]]
         [[AddN_12/_477]]
  (2) Invalid argument:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (mul_5:0) = ] [[0.182861328 0.032409668 0.0966796875...]...] [y (Cast_11/x:0) = ] [0]
         [[{{node assert_greater_equal/Assert/AssertGuard/else/_5/assert_greater_equal/Assert/AssertGuard/Assert}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_244419]

Function call stack:
test_function -> test_function -> test_function

可能是由于
y_pred
具有负值引起的,但这不应导致错误,因为损失函数中的sigmoid会挤压[0,1]中的值。尝试打印
y_pred
值,并检查出现错误时是否有异常值(检查
nan
s,可能还有负值)。