Python 一次使用多个批次计算tensorflow度量

Python 一次使用多个批次计算tensorflow度量,python,tensorflow,keras,Python,Tensorflow,Keras,我使用的是tf.keras,我想计算一个指标,我需要多批验证数据才能可靠地计算它。在计算指标之前,是否有方法累积批次 我想这样做: class MultibatchMetric(tf.keras.metrics.Metric): def __init__(self, num_batches, name="sdr_metric", **kwargs): super().__init__(name=name, **kwargs) self.

我使用的是
tf.keras
,我想计算一个指标,我需要多批验证数据才能可靠地计算它。在计算指标之前,是否有方法累积批次

我想这样做:

class MultibatchMetric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        self.batch_accumulator = []
        self.my_metric = []

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.batch_accumulator.append((y_true, y_pred))
        if len(self.batch_accumulator) >= self.num_batches:
            metric = custom_multibatch_metric_func(self.batch_accumulator)
            self.my_metric.append(metric)
            self.batch_accumulator = []

    def result(self):
        return mean(self.my_metric)

    def reset_states(self):
        self.my_metric = []
        self.batch_accumulator = []

然而,这一切都需要在tensorflow图上发生,使事情严重复杂化。

我尝试了一下您的问题,似乎使用内置的
添加权重
方法可以提供解决方案。通过为批量计数器和具有大小
(2,num\u batches*batch\u size,n\u输出)
的累加器创建状态变量。每次更新都会通过向状态变量添加填充批次来存储批次,并在计数器达到最大批次数时重置。然后,通过调用累加器状态变量上的度量,可以从累加器获得结果。我在下面添加了一个示例

class Metric(tf.keras.metrics.Metric):
    def __init__(self, num_batches, batch_size, name="sdr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.batch_accumulator = self.add_weight(name='accumulator', shape=(2, num_batches * batch_size, 2), initializer='zeros')
        self.batch_counter = self.add_weight(name='counter', shape=(), initializer='zeros')
        
    @tf.function
    def update_state(self, y_true, y_pred, sample_weight=None):
        batch_count = self.batch_counter
        batch = tf.stack([tf.cast(y_true, tf.float32), tf.cast(y_pred, tf.float32)])
        paddings = [[0, 0], [batch_count * self.batch_size, (self.num_batches - batch_count - 1) * self.batch_size], [0, 0]]
        padded_batch = tf.pad(batch, paddings)
        self.batch_accumulator.assign_add(padded_batch)            
        self.batch_counter.assign_add(1)
        if batch_count == self.num_batches:
            self.reset_states()
        
    @tf.function
    def result(self):
        if self.batch_counter == self.num_batches - 1:
            return custom_multibatch_metric_func(self.batch_accumulator)
        else:
            return 0.

    def reset_states(self):
        self.batch_counter.assign(0)
        self.batch_accumulator.assign(tf.zeros((2, self.num_batches * self.batch_size, 2)))
还有我用来验证的测试问题

# data
n = 1028
batch_size = 32
num_batches = 3
f = 4
lr = 10e-3

x = tf.random.uniform((n, f), -1, 1)
y = tf.concat([tf.reduce_sum(x, axis=-1, keepdims=True), tf.reduce_mean(x, axis=-1, keepdims=True)], axis=-1)

ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(b, drop_remainder=True)

model = tf.keras.models.Sequential([Dense(f, activation='relu'), Dense(2)])
model.compile(tf.keras.optimizers.SGD(lr), tf.keras.losses.mean_squared_error, metrics=Metric(num_batches, batch_size))
model.fit(ds, epochs=10)
两大问题。首先,调用
result
中的if语句,但根据对结果度量的要求,可以返回幂等值。在这里,我假设您只对所有结果求和,因此
0
没有影响。其次,除非数据集大小可以被批处理大小整除,否则这种方法需要删除余数


我希望这是有帮助的,尽管这无论如何都不是一个最佳的解决方案。

这太棒了!我认为你不必像在切片上调用
assign\u add
那样进行填充。为什么第一维度是2?另外,为什么在
result
中有if语句?
result
不是仅在纪元结束时调用吗?填充是一种解决方法,因为在图形模式下,似乎无法在切片上调用assign。我在添加tf.function语句之前写了这篇文章,这样可能会改变它,然后就不需要填充了。由于
update\u state
y\u true
y\u pred
作为输入,因此真实值和预测值的第一个维度均为2。至于
result
我想你也可以根据
m
的步数来调用它,所以如果你想要这种可能性,你可以。正如我所说,这是一个基本的想法,可以调整到您的规格。如果你有更多的问题,我很乐意帮助你。希望这是有帮助的。