TensorFlow custom training loop does not end after the first epoch, and the progress bar runs forever


I am trying to write a custom TensorFlow training loop that includes some TensorBoard utilities.

Here is the full code:

import tensorflow as tf
from pathlib import Path
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import layers
import cv2
from tqdm import tqdm
from os import listdir
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from random import shuffle, choice, uniform

from os.path import isdir, dirname, abspath, join
from os import makedirs
from tensorflow.keras.callbacks import (ModelCheckpoint, TensorBoard,
                                        EarlyStopping, LearningRateScheduler)

import io
from natsort import natsorted
from tensorflow.keras import backend as K
from tensorflow.keras import Sequential,Model

from tensorflow.keras.applications import (DenseNet201, InceptionV3, MobileNetV2,
                                           ResNet101, Xception, EfficientNetB7,VGG19, NASNetLarge)
from tensorflow.keras.applications import (densenet, inception_v3, mobilenet_v2,
                                           resnet, xception, efficientnet, vgg19, nasnet)

from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.layers.experimental.preprocessing import Rescaling, Resizing
from tensorflow.keras.utils import Progbar


ROOT = '/content/drive/MyDrive'
data_path = 'cropped/'
train_path = data_path + 'train'
val_path = data_path + 'val'

labels = {v:k for k, v in enumerate(listdir(train_path))}

models = {
    'densenet': DenseNet201,
    'xception': Xception,
    'inceptionv3': InceptionV3,
    'efficientnetb7': EfficientNetB7,
    'vgg19': VGG19,
    'nasnetlarge': NASNetLarge,
    'mobilenetv2': MobileNetV2,
    'resnet': ResNet101
}


preprocess_pipeline = {
    'densenet': densenet.preprocess_input,
    'xception': xception.preprocess_input,
    'inceptionv3': inception_v3.preprocess_input,
    'efficientnetb7': efficientnet.preprocess_input,
    'vgg19': vgg19.preprocess_input,
    'nasnetlarge': nasnet.preprocess_input,
    'mobilenetv2': mobilenet_v2.preprocess_input,
    'resnet': resnet.preprocess_input
}


def configure_for_performance(ds, buffer_size, batch_size):
    """
    Configures caching and prefetching
    """
    ds = ds.cache()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=buffer_size)
    return ds


def generator(tfrecord_file, batch_size, n_data, validation_ratio, reshuffle_each_iteration=False):
    """
    Returns training and validation generators with infinite repeat.
    """
    reader = tf.data.TFRecordDataset(filenames=[tfrecord_file])
    # NOTE: Dataset.shuffle() returns a new dataset; the result must be assigned back.
    reader = reader.shuffle(n_data, reshuffle_each_iteration=reshuffle_each_iteration)
    AUTOTUNE = tf.data.experimental.AUTOTUNE

    val_size = int(n_data * validation_ratio)
    train_ds = reader.skip(val_size)
    val_ds = reader.take(val_size)

    # Parsing data from tfrecord format.
    train_ds = train_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
    
    # Some data augmentation.
    train_ds = train_ds.map(_augment_function, num_parallel_calls=AUTOTUNE)
    train_ds = configure_for_performance(train_ds, AUTOTUNE, batch_size).repeat()

    val_ds = val_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
    val_ds = val_ds.map(_augment_function, num_parallel_calls=AUTOTUNE)
    val_ds = configure_for_performance(val_ds, AUTOTUNE, batch_size).repeat() # Is this repeat() the reason behind the issue?
    return train_ds, val_ds

def create_model(optimizer, name='densenet', include_compile=True):
    base_model = models[name](include_top=False, weights='imagenet')
    x = GlobalAveragePooling2D()(base_model.layers[-1].output)
    x = Dense(1024, activation='relu')(x)
    output = Dense(12, activation='softmax')(x)
    model = Model(base_model.inputs, output)

    if include_compile:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

    return model
Now, let's create a model and initialize everything:

n_data = len(list(Path(data_path).rglob('*.jpg'))) # Find out how many images are there
validation_ratio = 0.2
val_size = int(n_data * validation_ratio) # Find out validation image size.
train_size = n_data - val_size # And train images size
batch_size = 64
n_epochs = 5

# Tfrecord of images
filename = '/content/drive/MyDrive/cropped_data.tfrecord'

train_ds, val_ds = generator(filename,
                            batch_size=batch_size,
                            n_data=n_data,
                            validation_ratio=validation_ratio,
                            reshuffle_each_iteration=True)

# Tensorboard initialization
model_name = 'xception'

path_to_run = "runs/run_1"
tb_train_path = join(path_to_run, 'logs','train')
tb_test_path = join(path_to_run, 'logs', 'test')

train_writer = tf.summary.create_file_writer(tb_train_path)
test_writer = tf.summary.create_file_writer(tb_test_path)
train_step = test_step = 0

blocks_to_train = []
lr = 1e-4

optimizer = SGD(lr=lr, decay=1e-6, momentum=0.9, nesterov=True)
loss_fn = tf.keras.losses.CategoricalCrossentropy()  # the model ends in softmax, so from_logits stays False
acc_metric = tf.keras.metrics.CategoricalAccuracy()  # tracks categorical accuracy across batches

# Create the xception model
model = create_model(optimizer, name=model_name, include_compile=False)

metrics = {'acc': 0.0, 'loss': 0.0, 'val_acc': 0.0, 'val_loss': 0.0, 'lr': lr}
Here is the training and validation loop:

for epoch in range(n_epochs):
    # Iterate through the training set
    progress_bar = Progbar(train_size, stateful_metrics=list(metrics.keys()))

    for batch_idx, (x, y) in enumerate(train_ds):
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            loss = loss_fn(y, y_pred)

        gradients = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        acc_metric.update_state(y, y_pred)
        train_step += 1
        progress_bar.update(batch_idx*batch_size, values=[('acc',acc_metric.result()),
                                       ('loss', loss)])

    with train_writer.as_default():
        tf.summary.scalar("Loss", loss, step=epoch)
        tf.summary.scalar(
            "Accuracy", acc_metric.result(), step=epoch
        )

    # reset accuracy between epochs (and before the validation pass)

    acc_metric.reset_states()


    for batch_idx, (x,y) in enumerate(val_ds):
        y_pred = model(x, training=False)
        loss = loss_fn(y, y_pred)
        acc_metric.update_state(y,
                                y_pred)
        # confusion += get_confusion_matrix(y, y_pred, class_names=list(labels.keys()))  # part of the removed TensorBoard utilities

    with test_writer.as_default():
        tf.summary.scalar("Loss", loss, step=epoch)
        tf.summary.scalar("Accuracy", acc_metric.result(), step=epoch)

    progress_bar.update(train_size, values=[('val_acc', acc_metric.result()), ('val_loss', loss)])

    # reset accuracy after the validation pass
    acc_metric.reset_states()
I trimmed the code above and removed some of the TensorBoard utilities. The code starts training, but it does not stop at the end of the predefined number of epochs: the progress bar keeps running forever and never gets to the point of displaying the validation metrics.
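For context, here is a minimal standalone sketch (not part of my project, assuming TF 2.x eager execution) showing why a dataset with an unbounded repeat() never stops iterating:

import tensorflow as tf

ds = tf.data.Dataset.range(4).batch(2).repeat()  # repeat() with no count -> infinite dataset
for i, batch in enumerate(ds.take(3)):           # take(n) bounds iteration to exactly n batches
    print(i, batch.numpy())                      # prints 3 batches, then the loop ends
# Without take(), this for-loop would never terminate.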

Could you also help me build a progress bar that behaves exactly like the one of keras.fit?

Thanks.

I found the (silly) reason behind the never-ending training:

The data consists of train_size training samples and val_size validation samples, counted in samples rather than in batches. For example, the training data consists of 4886 samples, which corresponds to 76 batches (with a batch size of 64).

When I used for batch_idx, (x, y) in enumerate(train_gen):, there were only 76 batches in total, but in the loop I was mistakenly iterating over 4886 batches.
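To make the batch arithmetic explicit, a small sketch with the numbers from above:

train_size = 4886                            # training samples
batch_size = 64
steps_per_epoch = train_size // batch_size   # 4886 // 64 = 76 full batches
print(steps_per_epoch)                       # 76; the trailing partial batch of 22 samples is skipped by the take() below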

I rewrote the lines as follows:

for epoch in range(n_epochs):
    # Iterate through the training set
    progress_bar = Progbar(train_size, stateful_metrics=list(metrics.keys()))

    train_gen = train_ds.take(train_size//batch_size) # This line

    for batch_idx, (x, y) in enumerate(train_gen):
        .....

    val_gen = val_ds.take(val_size//batch_size)

    for batch_idx, (x,y) in enumerate(val_gen):
        .....
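Building on that fix, here is a hedged sketch (reusing the variable names from above) of driving the Progbar in steps instead of samples, so that it finishes exactly at the epoch boundary, just like keras.fit does:

from tensorflow.keras.utils import Progbar

steps_per_epoch = train_size // batch_size

for epoch in range(n_epochs):
    progress_bar = Progbar(steps_per_epoch, stateful_metrics=list(metrics.keys()))
    train_gen = train_ds.take(steps_per_epoch)  # bound the infinite (repeated) dataset

    for batch_idx, (x, y) in enumerate(train_gen):
        # ... forward pass, loss, gradients as in the original loop ...
        progress_bar.update(batch_idx + 1, values=[('acc', acc_metric.result()),
                                                   ('loss', loss)])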

As for the progress bar, a simple change is to use tqdm's trange instead of range (see the sketch below). My main problem, though, was that the data generator never finished the training part and never started the validation part.

@Masoudmasoumimomoghadam Hi, which dataset are you using? Even though you posted your solution, I would be interested in reproducing it.

@Jared Here is a link to what I did recently, but the code is far from flawless.
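As a rough illustration of the trange suggestion from the comments (a sketch only; train_ds, n_epochs, and steps_per_epoch as defined above):

from tqdm import trange

train_iter = iter(train_ds)  # the iterator is infinite because of repeat()
for epoch in range(n_epochs):
    for step in trange(steps_per_epoch):
        x, y = next(train_iter)
        # ... training step goes here ...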