How to save an image classification model and use it on Android

Tags: android, tensorflow, machine-learning, keras, classification

How can I save an image classification model as a .pb file together with its label.txt using Keras and TensorFlow, so that both files can be used on Android? I have starter code, but it only saves the .pb file, not label.txt.

I have done the whole thing, but I still don't have label.txt. Here is the code:

import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import keras
import keras.backend as k
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout, Activation
from keras.optimizers import Adam
from keras.callbacks import TensorBoard
from keras.layers.core import Lambda
import tensorflow as tf
from tensorflow.python.framework import graph_util
import os

print(keras.__version__)
print(tf.__version__)
train_df = pd.read_csv('fashionmnist/fashion-mnist_train.csv',sep=',')
test_df = pd.read_csv('fashionmnist/fashion-mnist_test.csv',sep=',')


train_data =np.array(train_df,dtype = 'float32')
test_data = np.array(test_df,dtype = 'float32')
x_train = train_data[:,1:]/255
y_train = train_data[:,0]
x_test = test_data[:,1:]/255   # use the test set here, not the training data
y_test = test_data[:,0]
x_train,x_validate,y_train,y_validate=train_test_split(x_train,y_train,test_size = 0.2,random_state = 12345)
image = x_train[50,:].reshape((28,28))
plt.imshow(image)
plt.show()

image_rows =28
image_cols= 28
batch_size =100
image_shape =(image_rows,image_cols,1)



x_train = x_train.reshape(x_train.shape[0],*image_shape)
x_test = x_test.reshape(x_test.shape[0],*image_shape)
x_validate = x_validate.reshape(x_validate.shape[0],*image_shape)


def build_network(is_training=True):
    model = Sequential()
    model.add(Conv2D(32, (3, 3), input_shape=image_shape,  padding='same',name="1_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3), padding='same',name="2_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="1_pool"))

    model.add(Conv2D(64, (3, 3), padding='same',name="3_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(64,(3, 3), padding='same',name="4_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="2_pool"))

    model.add(Conv2D(128,(3, 3),padding='same',name="5_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(128, (3, 3),padding='same',name="6_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="3_pool"))

    model.add(Conv2D(256,(3, 3), padding='same',name="7_conv"))
    model.add(Activation('relu'))
    model.add(Conv2D(256, (3, 3), padding='same',name="8_conv"))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2),name="4_pool"))

    model.add(Flatten())
    model.add(Dense(512,name="fc_1"))
    model.add(Activation('relu'))


    if (is_training):
        #model.add(Dense(512, activation='relu'))
        #model.add(Dropout(0.5, name="drop_1"))
        model.add(Lambda(lambda x:k.dropout(x,level=0.5),name="drop_1"))



    model.add(Dense(10,name="fc_2"))
    model.add(Activation('softmax',name="class_result"))
    #model.summary()
    return model


# reset the default graph and create the session used for training
tf.reset_default_graph()
sess = tf.Session()
k.set_session(sess)
model = build_network()

history_dict = {}
model.compile(loss='sparse_categorical_crossentropy',optimizer = Adam(),metrics=['accuracy'])




class TFCheckpointCallback(keras.callbacks.Callback):
    """Saves a TensorFlow checkpoint at the end of every epoch."""

    def __init__(self, saver, sess):
        self.saver = saver
        self.sess = sess

    def on_epoch_end(self, epoch, logs=None):
        self.saver.save(self.sess, 'fMnist/ckpt', global_step=epoch)


tf_saver= tf.train.Saver(max_to_keep=2)
checkpoint_callback= TFCheckpointCallback(tf_saver,sess)
%time
tf_graph=sess.graph
tf.train.write_graph(tf_graph.as_graph_def(),'freeze','fm_graph.pdtxt',as_text=True)
%time
history = model.fit(x_train,
                    y_train,
                    batch_size=batch_size,
                    epochs=50,
                    callbacks=[checkpoint_callback],
                    shuffle=True,
                    verbose=1,
                    validation_data=(x_validate,y_validate)
                   )

sess.close()


model_folder='fMnist/'
def prepare_graph_for_freezing(model_folder):
    model=build_network(is_training=False)
    checkpoint=tf.train.get_checkpoint_state(model_folder)
    input_checkpoint=checkpoint.model_checkpoint_path
    saver=tf.train.Saver()
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        k.set_session(sess)
        saver.restore(sess,input_checkpoint)
        tf.gfile.MakeDirs(model_folder+'freeze')
        saver.save(sess,model_folder + 'freeze/ckpt',global_step=0)


def freeze_graph(model_folder):
    checkpoint =tf.train.get_checkpoint_state(model_folder)
    print(model_folder+'freeze/')
    input_checkpoint = checkpoint.model_checkpoint_path
    absolut_model_folder="/".join(input_checkpoint.split('/')[:-1])
    output_graph=absolut_model_folder + "/fm_freazen_model.pb"
    print(output_graph)
    output_node_name = "class_result/Softmax"
    clear_devices = True
    new_saver=  tf.train.import_meta_graph(input_checkpoint + '.meta',clear_devices=clear_devices)

    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()


    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess2:
        print(input_checkpoint)
        new_saver.restore(sess2,input_checkpoint)

        output_graph_def=graph_util.convert_variables_to_constants(
        sess2,
        input_graph_def,
        output_node_name.split(","))

        with tf.gfile.GFile(output_graph,"wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph."% len(output_graph_def.node))
tf.reset_default_graph()
prepare_graph_for_freezing("freeze/")
freeze_graph("freeze/")
I have the checkpoints and the .pb file.


But I do not have a label.txt file.

For image classification on Android, I suggest you use TensorFlow Lite instead of using the protocol buffer (.pb) file directly.

First, you need to convert the Keras model (.h5) into a TensorFlow Lite model (.tflite):

converter = tf.lite.TFLiteConverter.from_keras_model_file( 'model.h5' )
converter.post_training_quantize = True
tflite_buffer = converter.convert()
open( 'tflite_model.tflite' , 'wb' ).write( tflite_buffer )

The model is then ready to be loaded on Android. To check its input and output dtype and shape, refer to this file.
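Two small gaps are worth noting here, sketched below under the assumption that you reuse the model object from the training script above and TensorFlow 1.x: the converter needs an .h5 file, which that script never writes, and the input/output dtype and shape can also be checked directly from Python.

# Save the trained Keras model so the TFLite converter has an .h5 file to read
# (run this before the conversion snippet above).
model.save('model.h5')

# After conversion, inspect the .tflite model's input/output dtype and shape.
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path='tflite_model.tflite')
interpreter.allocate_tensors()
print(interpreter.get_input_details())   # expect shape [1, 28, 28, 1], dtype float32
print(interpreter.get_output_details())  # expect shape [1, 10], dtype float32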

Now, on Android, first add the TensorFlow Lite dependency in build.gradle:

dependencies {
...
   implementation 'org.tensorflow:tensorflow-lite:1.13.1'
...
}
Now we load the model as a MappedByteBuffer object:

@Throws(IOException::class)
private fun loadModelFile(): MappedByteBuffer {
    val MODEL_ASSETS_PATH = "model.tflite"
    val assetFileDescriptor = assets.openFd(MODEL_ASSETS_PATH)
    val fileInputStream = FileInputStream(assetFileDescriptor.getFileDescriptor())
    val fileChannel = fileInputStream.getChannel()
    val startOffset = assetFileDescriptor.getStartOffset()
    val declaredLength = assetFileDescriptor.getDeclaredLength()
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}

Using the Interpreter.run() method, we produce an inference given some input. See this file, which contains methods for resizing the Bitmap with Bitmap.createScaledBitmap and converting it to a float[][]:
val interpreter = Interpreter( loadModelFile() )
val inputs : Array<FloatArray> = arrayOf( some_input_image )  // preprocessed image as a FloatArray
val outputs : Array<FloatArray> = arrayOf( floatArrayOf( 0.0f , 0.0f ) )
interpreter.run( inputs , outputs )
val output = outputs[0]
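Before wiring the model into the Android code above, it can also be worth sanity-checking the converted .tflite file with the Python interpreter. A minimal sketch, assuming x_test from the training script and the placeholder file name tflite_model.tflite:

import numpy as np
import tensorflow as tf

# Load the converted model and query its input/output tensors.
interpreter = tf.lite.Interpreter(model_path='tflite_model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a single 28x28x1 test image and read back the 10 class scores.
sample = x_test[:1].astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
scores = interpreter.get_tensor(output_details[0]['index'])[0]
print('predicted class index:', int(np.argmax(scores)))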
  • Also, try hosting the custom model on Firebase with Firebase MLKit.

  • I have created many apps that use TF to classify images and text.

  • For labels.txt, you can put the file in the app's assets folder and read it from there. (A minimal sketch of writing such a file follows this list.)

  • Sir, thank you for the excellent explanation, but the question is how to obtain the label.txt file for this model (how to write that text file). I have 178 classes; I trained the network on image data organized into 178 folders (one per class), each of them has the correct label, and every class has 5000 images.

  • @haptome Did you find a way to get label.txt? I am finding it very difficult to obtain. Did you find an answer?

  • Yes, I got an answer, but not from this page.
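Regarding the label.txt question in the comments: training does not produce this file; you write it yourself, one class name per line, in the same order as the class indices the model predicts. A minimal sketch in Python, assuming the standard Fashion-MNIST label order for the model above (for a folder-per-class dataset such as the 178-class one mentioned, the sorted folder names would play the same role):

# label.txt is a plain text file: one class name per line, ordered by the
# integer class index the model outputs (0, 1, 2, ...).
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]
# For a folder-per-class dataset, something like sorted(os.listdir(data_dir))
# usually matches the ordering the training pipeline used for class indices.

with open('label.txt', 'w') as f:
    f.write('\n'.join(class_names))

On Android, the file can then be bundled in the assets folder and read line by line to map the arg-max output index back to a class name.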