Python 在Model.fit中使用带有交错的tf.data.Dataset时,模型训练被卡住

Python 在Model.fit中使用带有交错的tf.data.Dataset时,模型训练被卡住,python,multithreading,tensorflow,multiprocessing,conv-neural-network

其目的是使用从目录获取数据的生成器进行线程安全的模型训练

我用tf.data.Dataset包装了ImageDataGenerator.flow_from_directory(),然后交错(interleave)多个实例

#%% Create data generators

from random_eraser import get_random_eraser

def get_gens(size=299,
             bs_per_gpu=128,
             eraser=False,
             get_classes=False):
    """Build train/val/test input pipelines from image directories.

    Wraps Keras ``ImageDataGenerator.flow_from_directory`` iterators in
    ``tf.data.Dataset.interleave`` so several dataset branches pull
    batches in parallel.

    Relies on module-level names not visible in this block:
    ``num_gpus``, ``train_dir``, ``test_dir``, ``img_shape``, ``tf``,
    ``ImageDataGenerator`` -- TODO confirm they are defined before calling.

    Args:
        size: nominal image size. NOTE(review): unused in this body --
            ``flow_from_directory`` receives the global ``img_shape`` instead.
        bs_per_gpu: per-GPU batch size; the global batch is bs_per_gpu * num_gpus.
        eraser: if True, apply the random-eraser augmentation to training images.
        get_classes: if True, return only the train/val class-label arrays
            instead of the three datasets.

    Returns:
        ``(train_generator, val_generator, test_generator)`` as
        ``tf.data.Dataset`` objects, or ``(train classes, val classes)``
        when ``get_classes`` is True.
    """
    
    # Global batch size across all GPUs/replicas.
    bs = bs_per_gpu * num_gpus 
    
    func_eraser = None
    if eraser:
        # Random-erasing augmentation, applied pixel-wise with 50% probability.
        func_eraser = get_random_eraser(p=0.5,
                                        s_l=0.01,
                                        s_h=0.05,
                                        r_1=0.3,
                                        r_2=1/0.3,
                                        pixel_level=True)

    # Heavy augmentation for training; 30% of train_dir is held out for validation.
    train_datagen = ImageDataGenerator(preprocessing_function=func_eraser,
                                       rescale=1/255.,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       fill_mode='constant',
                                       cval=0.0,
                                       horizontal_flip=True,
                                       rotation_range=45,
                                       brightness_range=[0.5,1.5],
                                       zoom_range=[0.8,1.0],
                                       validation_split=0.3)

    # Validation: rescaling only; same split fraction so the subsets are disjoint.
    val_datagen = ImageDataGenerator(rescale=1/255.,
                                     validation_split=0.3)

    test_datagen = ImageDataGenerator(rescale=1/255.)

    train_generator = train_datagen.flow_from_directory(train_dir,
                                                        target_size=img_shape,
                                                        batch_size=bs,
                                                        seed=1337,
                                                        subset='training')

    # shuffle=False keeps sample order stable so predictions align with .classes.
    val_generator = val_datagen.flow_from_directory(train_dir,
                                                    target_size=img_shape,
                                                    batch_size=bs,
                                                    shuffle=False,
                                                    seed=1337,
                                                    subset='validation')

    test_generator = test_datagen.flow_from_directory(test_dir,
                                                      target_size=img_shape,
                                                      batch_size=bs,
                                                      shuffle=False,
                                                      classes=['test'])
    
    if get_classes:
        # Early exit: caller only wants the ground-truth label arrays.
        return train_generator.classes, val_generator.classes
    
    def multithread_gen(gen, cores):
        """Interleave `cores` branches of Dataset.from_generator(gen)."""
        
        # set up tf generator
        Dataset = tf.data.Dataset
        ds = Dataset.from_tensor_slices([str(x) for x in range(cores)])
        # NOTE(review): the lambda ignores `x`, so all `cores` branches call
        # the same `gen` and therefore iterate the SAME underlying Keras
        # iterator object -- 31 parallel consumers of one generator whose
        # thread-safety is unverified. This shared-state contention is a
        # plausible cause of fit() stalling -- confirm.
        ds = ds.interleave(lambda x: Dataset.from_generator(gen,
                                                            output_types=(tf.float32, tf.float32)),
                           cycle_length=cores,
                           block_length=1,
                           num_parallel_calls=cores)
        # NOTE(review): even if re-enabled, this call discards its result;
        # it would need `ds = ds.prefetch(...)` to have any effect.
        #ds.prefetch(buffer_size=AUTOTUNE) # or 10?
        return ds
    
    # NOTE(review): these lambdas are late-binding closures over local names
    # that are reassigned on the same line -- by the time from_generator
    # invokes the callable, `train_generator` (etc.) already refers to the
    # wrapping Dataset, not the Keras iterator. Verify this is intended.
    train_generator = multithread_gen(lambda: train_generator,
                                      cores=31) 
    
    val_generator = multithread_gen(lambda: val_generator,
                                    cores=31)

    test_generator = multithread_gen(lambda: test_generator,
                                     cores=31)
    
    return train_generator, val_generator, test_generator

#train_generator, val_generator, test_generator = get_gens(size, bs_per_gpu)

# Mirror the model across all visible GPUs (synchronous data parallelism).
strategy = tf.distribute.MirroredStrategy()
num_replicas = strategy.num_replicas_in_sync
print('Number of devices: {}'.format(num_replicas))
然后我用以下函数训练模型:

def train_model(epochs):
    """Train the module-level `model` on `train_generator` for `epochs` epochs.

    `train_generator` here is a tf.data.Dataset, so Keras' own input
    parallelism must NOT be used: `workers` / `use_multiprocessing` are
    honored only for Python-generator / `keras.utils.Sequence` inputs, and
    `use_multiprocessing=True` with a non-picklable generator is a
    documented cause of `fit()` hanging (the observed symptom: epoch banner
    prints, then 0% CPU/GPU forever). Parallelism belongs in the Dataset
    pipeline itself (interleave / num_parallel_calls / prefetch).

    Args:
        epochs: number of training epochs.

    Returns:
        The `keras.callbacks.History` object returned by `model.fit`.
    """
    model_history = model.fit(train_generator,  # consider validation_data=val_generator
                              epochs=epochs,
                              callbacks=callbacks,
                              # steps_per_epoch is required: the interleaved
                              # Dataset has no inferable cardinality.
                              steps_per_epoch=int(num_train/bs))
                              #class_weight=weights_dict,

    return model_history
但在打印以下内容后,模型被卡住:

纪元1/20

INFO:tensorflow:batch_all_reduce: 156 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 156 all-reduces with algorithm = nccl, num_packs = 1

与此同时,CPU核心和GPU几乎没有被使用,占用率一直是0%