Python: reducing the memory consumption of a frozen TensorFlow model

Tags: python, memory, tensorflow, tensorflow-serving

I used the tutorial code to compare the memory saved by serving a frozen model (.pb file) instead of checkpoint files. The theory (and, AFAIK, the implementation) is that checkpoint files carry many variables that are only needed while training the model, whereas a frozen model converts them (the weights and biases) into constants, so memory consumption should be lower. However, when I compared memory consumption, the difference was only about 50 MB (660 MB vs. 610 MB). So I wonder: what is the point of freezing a model, if the in-memory size while serving it is not much smaller than recreating the model from the checkpoint files? I have pasted some code below at the point where I measured. I am not sure the code gives the whole picture, but it is entirely inspired by the (decode method).
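For reference, I froze the model following the tutorial; the freezing step is roughly this minimal TF 1.x sketch (the checkpoint paths and the output node name below are hypothetical and must match your own graph):

import tensorflow as tf

saver = tf.train.import_meta_graph('save/NER_SAVE/model.ckpt.meta')  # hypothetical path
with tf.Session() as sess:
    saver.restore(sess, 'save/NER_SAVE/model.ckpt')  # hypothetical path
    # Replace every Variable with a Const holding its trained value, keeping
    # only the subgraph needed to compute the listed output nodes.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['output_node'])  # 'output_node' is a placeholder name
    with tf.gfile.GFile('save/NER_SAVE/frozen_model.pb', 'wb') as f:
        f.write(frozen.SerializeToString())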

Note that the code below is my version, where I load the frozen graph, create the input data, and call run directly. I also tried to pin down the point at which memory blows up, and from the output it clearly happens right after sess.run. Any suggestions on what else I can do to reduce memory consumption? My final setup will involve at least a few more models of this kind running together. Many thanks.
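The load_graph helper is not shown below; it is essentially the standard frozen-graph loader from the tutorials (note the name='prefix' argument, which is why every tensor name in my code starts with prefix/):

import tensorflow as tf

def load_graph(frozen_graph_filename):
    # Read the serialized GraphDef from disk.
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import it into a fresh graph; all op names get a 'prefix/' namespace.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
    return graph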

import logging
import subprocess
import sys

import numpy as np
import tensorflow as tf
from nltk import pos_tag, word_tokenize

# data_utils, _buckets, en_vocab and loc_get_batch come from the
# seq2seq tutorial code and are elided here.

def decode_NER():
    graph = load_graph('save/NER_SAVE/frozen_model.pb')
    for op in graph.get_operations():
        print(op.name)
    # Grab the input placeholders by name; the graph was imported under 'prefix/'.
    encoder_ph = [graph.get_tensor_by_name('prefix/encoder%d:0' % i)
                  for i in range(5)]
    decoder_ph = [graph.get_tensor_by_name('prefix/decoder%d:0' % i)
                  for i in range(10)]
    # The output logits: one AttnOutputProjection/BiasAdd tensor per decoder step.
    attn = ('prefix/model_with_buckets/embedding_attention_seq2seq/'
            'embedding_attention_decoder/attention_decoder/AttnOutputProjection')
    y = [graph.get_tensor_by_name(attn + '/BiasAdd:0')]
    y += [graph.get_tensor_by_name('%s_%d/BiasAdd:0' % (attn, i))
          for i in range(1, 10)]
    with tf.Session(graph=graph) as sess:
        # ... (vocabulary loading and reading the input sentence are elided) ...
        while sentence:
            # POS-tag the sentence; the model consumes the tag sequence.
            tagged = pos_tag(word_tokenize(sentence))
            poss = ' '.join(tag for _, tag in tagged)
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(poss), en_vocab)
            print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
            print(subprocess.run(['free', '-mh'], stdout=subprocess.PIPE))
            print(sys.getsizeof(sess))
            # Which bucket does it belong to? (for/else: the else branch
            # only runs when no bucket was large enough.)
            bucket_id = len(_buckets) - 1
            for i, bucket in enumerate(_buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                logging.warning("Sentence truncated: %s", sentence)

            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = loc_get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)

            print(encoder_inputs)
            # Feed each encoder/decoder placeholder its matching input slice.
            feed = {ph: encoder_inputs[i] for i, ph in enumerate(encoder_ph)}
            feed.update({ph: decoder_inputs[i] for i, ph in enumerate(decoder_ph)})
            output_logits = sess.run(y, feed_dict=feed)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            print('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
            print(subprocess.run(['free', '-mh'], stdout=subprocess.PIPE))
            print(sys.getsizeof(sess))
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            print(outputs)
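One caveat about my own measurement: sys.getsizeof(sess) only reports the shallow size of the Python Session wrapper object, not the memory owned by the underlying runtime, so the free -mh output is the number that actually matters. A per-process alternative would be something like this sketch (standard library only, Unix; ru_maxrss is in kilobytes on Linux):

import resource

# Peak resident set size of the current process so far.
peak_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print('peak RSS: %.1f MB' % (peak_kb / 1024.0))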