Warning: file_get_contents(/data/phpspider/zhask/data//catemap/8/mysql/69.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Deep learning 如何使用keras微调inception v3来进行多类分类?_Deep Learning_Classification_Keras - Fatal编程技术网

Deep learning 如何使用keras微调inception v3来进行多类分类?

Deep learning 如何使用keras微调inception v3来进行多类分类?,deep-learning,classification,keras,Deep Learning,Classification,Keras,我想使用Keras使用Kaggle.com上的猫狗数据集进行两类图像分类。 但是我对param“class_mode”有一些问题,如下代码所示。 如果我使用“二进制”模式,准确率约为95%,但如果我使用“分类”模式,准确率异常低,仅高于50% 二进制模式意味着最后一层只有一个输出,并使用sigmoid激活进行分类。样本的标签仅为一个整数 “分类”是指最后一层中的两个输出,并使用softmax激活进行分类。样本的标签是一种热格式,例如(1,0),(0,1) 我认为这两种方法应该有相似的结果。有人知

我想使用Keras使用Kaggle.com上的猫狗数据集进行两类图像分类。 但是我对param“class_mode”有一些问题,如下代码所示。 如果我使用“二进制”模式,准确率约为95%,但如果我使用“分类”模式,准确率异常低,仅高于50%

二进制模式意味着最后一层只有一个输出,并使用sigmoid激活进行分类。样本的标签仅为一个整数

“分类”是指最后一层中的两个输出,并使用softmax激活进行分类。样本的标签是一种热格式,例如(1,0),(0,1)

我认为这两种方法应该有相似的结果。有人知道差异的原因吗?非常感谢

导入操作系统
导入系统
导入glob
导入argparse
将matplotlib.pyplot作为plt导入
来自keras导入版本__
从keras.applications.inception\u v3导入InceptionV3,预处理\u输入
从keras.models导入模型
从keras.layers导入稠密、全局平均池2D
从keras.preprocessing.image导入ImageDataGenerator
从keras.optimizers导入新加坡元
在这里设置一些参数

IM_WIDTH,IM_HEIGHT=299299#用于接收的固定大小v3
NB_时代=1
蝙蝠大小=32
FC_大小=1024
NB_IV3_层_至_冻结=172
损耗模式=“二进制交叉熵”
def get_nb_文件(目录):
“”“通过递归搜索目录获取文件数”“”
如果不存在os.path.exists(目录):
返回0
cnt=0
对于r、dir和os.walk(目录)中的文件:
对于目录中的dr:
cnt+=len(glob.glob(os.path.join(r,dr+“/*”))
返回cnt
转移并学习,将重量保持在

def设置到传输学习(型号、基本型号):
“”“冻结所有层并编译模型”“”
对于基本模型层中的层:
layer.trainable=错误
compile(optimizer='rmsprop',loss=loss_mode,metrics=['accurity'])
添加最后一层进行两类分类

def添加新的最后一层(基本模型、nb类):
“”“将最后一层添加到convnet
Args:
基本型:不包括顶部的keras型
nb#U类:#类
返回:
具有最后一层的新keras模型
"""
x=基本模型输出
x=全局平均池2D()(x)
x=密集(FC_大小,激活='relu')(x)#新FC层,随机初始
如果args.class_mode==“binary”:
预测=密集(1,激活='sigmoid')(x)#新的softmax层
其他:
预测=密集(nb_类,激活='softmax')(x)#新的softmax层
模型=模型(输入=基本模型。输入,输出=预测)
回归模型
冻结底部NB_IV3_层并重新培训剩余的顶部层, 和微调重量

def设置到微调(型号):
“”“冻结底部NB_IV3_层,并重新训练剩余的顶部层。”。
注:NB_IV3_层对应于inceptionv3拱门中的前2个起始块
Args:
模型:keras模型
"""
对于模型中的层。层[:NB_IV3_layers_TO_FREEZE]:
layer.trainable=错误
对于模型中的层。层[NB_IV3_层到_冻结:]:
layer.trainable=True
compile(optimizer=“rmsprop”,loss=loss\u模式,metrics=['accurity'])
#compile(优化器=SGD(lr=0.0001,动量=0.9),loss='classifical\u crossentropy',metrics=['accurity'])
def序列(args):
“”“使用迁移学习和微调在新数据集上训练网络”“”
nb\u train\u samples=获取nb\u文件(args.train\u dir)
nb_classes=len(glob.glob(args.train_dir+“/*”))
nb_val_samples=获取nb_文件(args.val_dir)
nb_epoch=int(参数nb_epoch)
批次大小=整数(参数为批次大小)
打印(“nb_类:{}”。格式(nb_类))
数据准备

train_datagen=ImageDataGenerator(
预处理_函数=预处理_输入,
旋转范围=30,
宽度\偏移\范围=0.2,
高度\位移\范围=0.2,
剪切范围=0.2,
缩放范围=0.2,
水平翻转=真
)
test_datagen=ImageDataGenerator(
预处理_函数=预处理_输入,
旋转范围=30,
宽度\偏移\范围=0.2,
高度\位移\范围=0.2,
剪切范围=0.2,
缩放范围=0.2,
水平翻转=真
)
train_generator=来自目录的train_datagen.flow_(
args.train_dir,
目标屏幕大小=(屏幕宽度、屏幕高度),
批次大小=批次大小,
#class_mode='binary'
class\u mode=args.class\u mode
)
验证\u生成器=来自\u目录的测试\u datagen.flow\u(
args.val_dir,
目标屏幕大小=(屏幕宽度、屏幕高度),
批次大小=批次大小,
#class_mode='binary'
class\u mode=args.class\u mode
)
设置模型

base_model=InceptionV3(weights='imagenet',include_top=False)#include_top=False排除最终FC层
模型=添加新的最后一层(基本模型,nb类)
迁移学习

setup\u to\u transfer\u learn(模型、基本模型)
#model.summary()
历史\u tl=model.fit\u生成器(
列车发电机,
纪元=nb_纪元,
每个历元的步长=nb\U序列样本//BAT\U尺寸,
验证数据=验证生成器,
验证步骤=nb\U val\U样本//BAT\U大小)
微调

设置到微调(型号)
历史记录\u ft=model.fit\u生成器(
列车发电机,
每个历元的步长=nb\U序列样本//BAT\U尺寸,
纪元=nb_纪元,
验证数据=验证生成器,
验证步骤=nb\U val\U样本//BAT\U大小)
model.save(args.output\u model\u文件)
如果args.plot:
绘图训练(历史)
def plot_培训(历史记录):
acc=历史。历史['acc']
val_acc=历史。历史['val_acc']
损失=历史。历史['loss']
val_loss=历史。历史['val_loss']
历元=范围(len(acc))
plt.绘图(时代,附件,右)
plt.绘图(时代,val_acc,'r')
产品名称(“培训和验证准确性”)
plt.图()
plt.绘图(年代、损失、r.)
plt.绘图(时代、价值损失、“r-”)
产品名称(“培训和验证损失”)
plt.show()
主要职能

如果名称=“\uuuuu main\uuuuuuuu”:
a=argparse.ArgumentParser()
A.
keras.optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0)