
Tensorflow: IoU per class


I am trying to do semantic segmentation with deeplab. I want to compute the IoU for each class individually, not the mean IoU.

In , I tried to get the confusion matrix instead of the mean IoU:

miou, cmat = tf.metrics.mean_iou(...)
metric_map['cmat'] = cmat
But it did not work.
I would appreciate any suggestion on how to work around this.

You can use _streaming_confusion_matrix from tensorflow.python.ops.metrics_impl to get the confusion matrix. Essentially, it works the same way as other running metrics such as mean_iou: calling it gives you two ops, one that returns the total confusion matrix and an update op that cumulatively updates the confusion matrix.
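
As a rough sketch of how this could be wired up in TF 1.x graph mode (a minimal sketch only; _streaming_confusion_matrix is a private API, so its exact signature may differ between TensorFlow versions, and the placeholder shapes are illustrative):

# Minimal sketch, assuming TF 1.x graph mode and flattened label/prediction
# vectors per batch. _streaming_confusion_matrix is a private API, so the
# return values and signature may vary between versions.
import tensorflow as tf
from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix

num_classes = 21
labels = tf.placeholder(tf.int64, shape=[None])       # ground-truth class ids
predictions = tf.placeholder(tf.int64, shape=[None])  # predicted class ids

# Like mean_iou, this returns the accumulated confusion matrix and an update op.
total_cm, update_op = _streaming_confusion_matrix(labels, predictions, num_classes)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # the accumulator is a local variable
    # Run update_op once per evaluation batch:
    #   sess.run(update_op, feed_dict={labels: ..., predictions: ...})
    cm = sess.run(total_cm)
    # Per-class IoU for class c can then be read off the matrix:
    #   iou_c = cm[c, c] / (cm[c, :].sum() + cm[:, c].sum() - cm[c, c])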


With the confusion matrix, you should now be able to compute the IoU of a single class. I implemented a class-specific IoU metric based on the MeanIoU class:

class ClassIoU(tf.keras.metrics.MeanIoU):
    """Computes the class-specific Intersection-Over-Union metric.

    IoU is defined as follows:
      IoU = true_positive / (true_positive + false_positive + false_negative).
    The predictions are accumulated in a confusion matrix, weighted by
    `sample_weight`, and the metric is then calculated from it.

    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Args:
      class_idx: The index of the class of interest.
      one_hot: Indicates whether the inputs are one-hot vectors, as in
        CategoricalCrossentropy, or class indices, as in
        SparseCategoricalCrossentropy or MeanIoU.
      num_classes: The possible number of labels the prediction task can have.
        This value must be provided, since a confusion matrix of dimension
        [num_classes, num_classes] will be allocated.
      name: (Optional) string name of the metric instance.
      dtype: (Optional) data type of the metric result.
    """

    def __init__(self, class_idx, one_hot, num_classes, name=None, dtype=None):
        super().__init__(num_classes, name, dtype)
        self.one_hot = one_hot
        self.class_idx = class_idx

    def result(self):
        # Column sums, row sums and diagonal of the accumulated confusion matrix.
        sum_over_row = tf.cast(
            tf.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
        sum_over_col = tf.cast(
            tf.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
        true_positives = tf.cast(
            tf.linalg.diag_part(self.total_cm), dtype=self._dtype)

        # sum_over_row + sum_over_col =
        #     2 * true_positives + false_positives + false_negatives.
        denominator = sum_over_row[self.class_idx] + sum_over_col[self.class_idx] \
            - true_positives[self.class_idx]

        # If the class never appears in either the labels or the predictions,
        # the denominator is 0 and the IoU is reported as 0.
        num_valid_entries = tf.reduce_sum(
            tf.cast(tf.not_equal(denominator, 0), dtype=self._dtype))

        iou = tf.math.divide_no_nan(true_positives[self.class_idx], denominator)

        return tf.math.divide_no_nan(
            tf.reduce_sum(iou, name='mean_iou'), num_valid_entries)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # MeanIoU expects class indices, so one-hot inputs are converted via argmax.
        if self.one_hot:
            return super().update_state(
                tf.argmax(y_true, axis=-1), tf.argmax(y_pred, axis=-1), sample_weight)
        else:
            return super().update_state(y_true, y_pred, sample_weight)
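
For reference, a minimal usage sketch in TF 2.x eager mode (the toy data below is hypothetical: four pixels, three classes, one-hot encoded):

import tensorflow as tf

# Hypothetical toy data: four pixels, three classes, one-hot encoded.
y_true = tf.one_hot([0, 1, 2, 1], depth=3)
y_pred = tf.one_hot([0, 1, 1, 1], depth=3)

iou_class1 = ClassIoU(class_idx=1, one_hot=True, num_classes=3)
iou_class1.update_state(y_true, y_pred)
print(iou_class1.result().numpy())  # ~0.667: class 1 has 2 TP, 1 FP, 0 FN

# The metric can also be passed to model.compile() like any other Keras metric,
# e.g. metrics=[ClassIoU(c, one_hot=True, num_classes=3) for c in range(3)]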