如何在tensorflow中使用集合来保存我自己的对象

如何在tensorflow中使用集合来保存我自己的对象,tensorflow,Tensorflow,我想使用tf.add_to_collection()来保存我自己的对象,以便以后轻松获取它们。 以下是代码段: class Model(object): def __init__(self, scope, is_training=True): 将对象添加到集合: for i in xrange(num_gpus): with tf.device("/gpu:%d"%i): with tf.name_scope("tower_%d"%i) as scope:

我想使用tf.add_to_collection()来保存我自己的对象,以便以后轻松获取它们。 以下是代码段:

class Model(object):
    def __init__(self, scope, is_training=True):
将对象添加到集合:

for i in xrange(num_gpus):
    with tf.device("/gpu:%d"%i):
        with tf.name_scope("tower_%d"%i) as scope:
            m = Model.Model(scope)
            tf.add_to_collection("train_model", m)
models = tf.get_collection("train_model")
从集合中获取对象:

for i in xrange(num_gpus):
    with tf.device("/gpu:%d"%i):
        with tf.name_scope("tower_%d"%i) as scope:
            m = Model.Model(scope)
            tf.add_to_collection("train_model", m)
models = tf.get_collection("train_model")
代码工作正常,但我收到一个警告:

WARNING:tensorflow:Error encountered when serializing train_model.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'Model' object has no attribute 'name
如何避免此警告?

此警告(可能)在您调用时生成,它试图写出一个“元图”,表示
tf.Graph
的内容,包括所有图形集合的内容

避免警告的最简单方法是在调用
saver.save()时传递
write\u meta\u graph=False
。但是,这会使您以后无法导入元图

如果要保存元图并避免警告,则需要实现必要的钩子(
到_proto
和来自_proto
),以序列化格式将
模型
对象序列化为协议缓冲区。说明了如何做到这一点,但基本思路如下:

  • 定义描述
    模型
    对象内容的协议缓冲区(
    ModelProto

  • 定义一个将
    模型
    序列化为
    ModelProto
    ModelProto()
    函数:

    def model_to_proto(model):
        ret = ModelProto()
        # Set fields of `ret` from `model`.
        return ret
    
  • 从proto()定义一个反序列化
    ModelProto
    并返回
    model
    model\u函数:

    def model_from_proto(model_proto):
        # Construct a `Model` from the fields of `model_proto`.
        return Model(...)
    
  • “train\u model”
    集合注册您的功能。这当前使用一个未记录的函数,称为: