Tensorflow 如何确定模型的最后一层进行迁移学习?

Tensorflow 如何确定模型的最后一层进行迁移学习?,tensorflow,deep-learning,ocr,transfer-learning,pre-trained-model,Tensorflow,Deep Learning,Ocr,Transfer Learning,Pre Trained Model,首先:我对深度学习和Tensorflow还不熟悉,所以很抱歉提出这些愚蠢的问题。也许有人能帮助我获得更多的理解和清晰。 我在一个OCR项目中工作,我只有4000张带点字体和白色背景的标记图像。 我决定使用此代码来解决此任务: 我在Synthtext数据集上使用了预先训练过的模型,并使用自己的数据集继续训练,但结果不是很好。 我认为一个问题可能是我的自定义数据集的类数少于预训练模型的类数。 我读过关于迁移学习的书,在那里你只训练最后一层,然后你可以使用不同数量的课程,但我不知道如何做到这一点。更准

首先:我对深度学习和Tensorflow还不熟悉,所以很抱歉提出这些愚蠢的问题。也许有人能帮助我获得更多的理解和清晰。 我在一个OCR项目中工作,我只有4000张带点字体和白色背景的标记图像。 我决定使用此代码来解决此任务:

我在Synthtext数据集上使用了预先训练过的模型,并使用自己的数据集继续训练,但结果不是很好。 我认为一个问题可能是我的自定义数据集的类数少于预训练模型的类数。 我读过关于迁移学习的书,在那里你只训练最后一层,然后你可以使用不同数量的课程,但我不知道如何做到这一点。更准确地说,我不知道如何识别图表的最后几层

以下是构建图表的代码:

    def cnn(self, rois):
        with tf.variable_scope("recog/cnn"):
            conv1 = slim.conv2d(rois, 64, 3, stride=1, padding='SAME', activation_fn=tf.nn.relu, normalizer_fn=None)
            conv1 = slim.conv2d(conv1, 64, 3, stride=1, padding='SAME', activation_fn=tf.nn.relu, normalizer_fn=None)
            pool1 = slim.max_pool2d(conv1, [2, 1], stride=[2, 1])
            conv2 = slim.conv2d(pool1, 128, 3, stride=1, padding='SAME', activation_fn=tf.nn.relu, normalizer_fn=None)
            conv2 = slim.conv2d(conv2, 128, 3, stride=1, padding='SAME', activation_fn=tf.nn.relu, normalizer_fn=None)
            pool2 = slim.max_pool2d(conv2, [2, 1], stride=[2, 1])
            conv3 = slim.conv2d(pool2, 256, 3, stride=1, padding='SAME', activation_fn=tf.nn.relu, normalizer_fn=None)
            conv3 = slim.conv2d(conv3, 256, 3, stride=1, padding='SAME', activation_fn=tf.nn.relu, normalizer_fn=None)
            pool3 = slim.max_pool2d(conv3, [2, 1], stride=[2, 1])
            return pool3
    

    def bilstm(self, input_feature, seq_len):
        with tf.variable_scope("recog/rnn"):
            lstm_fw_cell = rnn.LSTMCell(self.rnn_hidden_num)
            lstm_fw_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_fw_cell, input_keep_prob=self.keepProb, output_keep_prob=self.keepProb)
            lstm_bw_cell = rnn.LSTMCell(self.rnn_hidden_num)
            lstm_bw_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_bw_cell, input_keep_prob=self.keepProb, output_keep_prob=self.keepProb)
            # infer_output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, input_feature, seq_len, dtype=tf.float32)
            # infer_output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, input_feature, sequence_length=seq_len, time_major=True, dtype=tf.float32)
            infer_output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, input_feature, sequence_length=seq_len, dtype=tf.float32)
            # stack_lstm_layer, _, _ = rnn.stack_bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, input_feature, dtype=tf.float32)
            infer_output = tf.concat(infer_output, axis=-1)
            return infer_output
            # return stack_lstm_layer

    def build_graph(self, rois, seq_len):
        num_rois = tf.shape(rois)[0]

        cnn_feature = self.cnn(rois) # N * 1 * W * C
        print cnn_feature

        cnn_feature = tf.reshape(cnn_feature, [nums, -1, 256]) # squeeze B x W x C
        cnn_feature = tf.squeeze(cnn_feature, axis=1) # N * W * C
        reshape_cnn_feature = tf.transpose(cnn_feature, (1, 0, 2))
        reshape_cnn_feature = cnn_feature

        # print "final cnn: ", reshape_cnn_feature.shape

        lstm_output = self.bilstm(reshape_cnn_feature, seq_len) # N * T * 2H

        # print "lstm_output: ", lstm_output

        logits = tf.reshape(lstm_output, [-1, self.rnn_hidden_num * 2]) # (N * T) * 2H
        
        W = tf.Variable(tf.truncated_normal([self.rnn_hidden_num * 2, self.num_classes], stddev=0.1), name="W")
        b = tf.Variable(tf.constant(0., shape=[self.num_classes]), name="b")

        logits = tf.matmul(logits, W) + b # (N * T) * Class

        logits = tf.reshape(logits, [num_rois, -1, self.num_classes])
        logits = tf.reshape(logits, [nums, -1, self.num_classes])
        logits = tf.reshape(logits, [num_rois, -1, self.num_classes])
        
        logits = tf.transpose(logits, (1, 0, 2))

        return logits
我不知道转移学习是否明智? 还有其他方法可以在我自己的数据集上使用预先训练好的模型和finetune,但要使用其他数量的字符类吗?
我真的很困惑如何解决这项任务,以及如何尝试改进结果。我已经尝试过在没有预先训练过的模型的情况下从头开始训练,但是模型过多。

您是否尝试过不同的模型来解决OCR问题?不同的模型意味着什么?你指的是不同的预训练模型吗?正如你所说,你没有通过使用该模型获得更好的分数,那么为什么不使用SOTA模型(即ABCNet)。仅供参考,对于OCR问题,您不需要训练最后一层或其他东西(与一般的计算机视觉问题不同);只需使用预先训练好的重量进行初始化,并以较少的时间跑。好吧,你不能只重新训练OCR模型的最后一层。任何OCR模型的最后一层基本上都是识别文本的识别分支。它与检测分支相结合。要对特定类型的文本进行微调,您需要对synth文本数据集进行培训,并对特定目标数据集进行微调。现在,所选模型已经对synth文本进行了培训,并对基准数据集(即total text、icdar等)进行了微调,我们所能做的就是将他们预先训练好的体重初始化到我们的任务或特定的数据集。