Tensorflow 在EfficientNets上的转移学习如何用于灰度图像?

Tensorflow 在EfficientNets上的转移学习如何用于灰度图像?,tensorflow,neural-network,artificial-intelligence,transfer-learning,Tensorflow,Neural Network,Artificial Intelligence,Transfer Learning,我的问题更多的是关于算法是如何工作的。我已经成功地为灰度图像实现了高效的网络集成和模型化,现在我想了解它的工作原理 这里最重要的方面是灰度及其1通道。当我把channels=1放进去时,算法不起作用,因为如果我理解正确的话,它是在3通道图像上制作的。当我把channels=3放进去时,它工作得非常好 所以我的问题是,当我放置channels=3并使用channels=1为模型提供预处理图像时,为什么它继续工作 效率代码netb5 # Variable assignments num_classe

我的问题更多的是关于算法是如何工作的。我已经成功地为灰度图像实现了高效的网络集成和模型化,现在我想了解它的工作原理

这里最重要的方面是灰度及其1通道。当我把
channels=1
放进去时,算法不起作用,因为如果我理解正确的话,它是在3通道图像上制作的。当我把
channels=3
放进去时,它工作得非常好

所以我的问题是,当我放置
channels=3
并使用
channels=1
为模型提供预处理图像时,为什么它继续工作

效率代码netb5

# Variable assignments
num_classes = 9
img_height = 84
img_width = 112
channels = 3
batch_size = 32

# Make the input layer
new_input = Input(shape=(img_height, img_width, channels),
                  name='image_input')

# Download and use EfficientNetB5
tmp = tf.keras.applications.EfficientNetB5(include_top=False,
                                           weights='imagenet',
                                           input_tensor=new_input,
                                           pooling='max')
model = Sequential()
model.add(tmp)  # adding EfficientNetB5
model.add(Flatten())
...
灰度预处理代码

data_generator = ImageDataGenerator(
        validation_split=0.2)

train_generator = data_generator.flow_from_directory(
        train_path,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        color_mode="grayscale", ###################################
        class_mode="categorical",
        subset="training")

这很有趣。如果训练仍然使用
通道=3
,即使输入是灰度的,我会检查
训练生成器的批次形状(可能会打印几个批次以获得感觉)。下面是一个快速检查批处理形状的代码片段。(plotImages()在Tensorflow文档中提供)


这很有趣。如果训练仍然使用
通道=3
,即使输入是灰度的,我会检查
训练生成器的批次形状(可能会打印几个批次以获得感觉)。下面是一个快速检查批处理形状的代码片段。(plotImages()在Tensorflow文档中提供)


你解决过这个问题吗?我也在想同样的事情。我得到了一个关于输入形状的警告,因为我只有1个通道(灰度图像),而不是3个通道,但它仍然训练模型。@erotavlas noo,我在其他地方发布了这个问题,但没有人回答它:/你解决过这个问题吗?我也在想同样的事情。我得到一个关于输入形状的警告,因为我只有1个通道(灰度图像),而不是3个,但它仍然训练模型。@erotavlas noo,我在其他地方发布了这个问题,但没有人回答它:/
imgs,labels = next(train_generator)
print('Batch shape: ',imgs.shape)
plotImages(imgs,labels)