Python 如何在函数内部分布和定义的tensorflow图中添加操作?

Python 如何在函数内部分布和定义的tensorflow图中添加操作?,python,tensorflow,keras,Python,Tensorflow,Keras,我想定义2个类,每个类使用不同的图。因此,作为第一步,我为类模型定义了一个图 如果删除带有self.graph.as_default():的行,并使用默认图形运行,即带有tf.Session()as sess:的,下面的代码段将起作用 但是我想在一个图中定义它,这样我就可以用一个新的图添加另一个类,并使这两个图并行或顺序运行 我是tensorflow的新手,所以仍然不确定下面向图中添加操作的方法是否正确 import functools import tensorflow as tf tf.re

我想定义2个类,每个类使用不同的图。因此,作为第一步,我为
类模型定义了一个图

如果删除带有self.graph.as_default():
的行
,并使用默认图形运行,即带有tf.Session()as sess:
,下面的代码段将起作用

但是我想在一个图中定义它,这样我就可以用一个新的图添加另一个类,并使这两个图并行或顺序运行

我是tensorflow的新手,所以仍然不确定下面向图中添加操作的方法是否正确

import functools
import tensorflow as tf
tf.reset_default_graph()
from tensorflow.examples.tutorials.mnist import input_data


def doublewrap(function):
    """
    A decorator decorator, allowing to use the decorator to be used without
    parentheses if not arguments are provided. All arguments must be optional.
    """
    @functools.wraps(function)
    def decorator(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
            return function(args[0])
        else:
            return lambda wrapee: function(wrapee, *args, **kwargs)
    return decorator


@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    """
    A decorator for functions that define TensorFlow operations. The wrapped
    function will only be executed once. Subsequent calls to it will directly
    return the result so that operations are added to the graph only once.
    The operations added by the function live within a tf.variable_scope(). If
    this decorator is used with arguments, they will be forwarded to the
    variable scope. The scope name defaults to the name of the wrapped
    function.
    """
    attribute = '_cache_' + function.__name__
    name = scope or function.__name__
    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(name, *args, **kwargs):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return decorator


class Model:

    def __init__(self, image, label):
        self.graph = tf.Graph()
        self.image = image
        self.label = label
        self.prediction
        self.optimize
        self.error

    @define_scope(initializer=tf.contrib.slim.xavier_initializer())
    def prediction(self):
        with self.graph.as_default():
            x = self.image
            x = tf.contrib.slim.fully_connected(x, 200)
            x = tf.contrib.slim.fully_connected(x, 200)
            x = tf.contrib.slim.fully_connected(x, 10, tf.nn.softmax)
        return x

    @define_scope
    def optimize(self):
        with self.graph.as_default():
            current_error=self.error
            logprob = tf.log(self.prediction + 1e-12) *(1-current_error)  #Here changed ????????????

            cross_entropy = -tf.reduce_sum(self.label * logprob)
            optimizer = tf.train.RMSPropOptimizer(0.03)
            trainop = optimizer.minimize(cross_entropy) 
        return trainop

    @define_scope
    def error(self):
        with self.graph.as_default():
            mistakes = tf.not_equal(
                tf.argmax(self.label, 1), tf.argmax(self.prediction, 1))
            me = tf.reduce_mean(tf.cast(mistakes, tf.float32))
        return me

    # @define_scope
    # def accuracy(self):
    #     accuracy = tf.reduce_sum()


def main():
    mnist = input_data.read_data_sets('../../MNIST_data/', one_hot=True)
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])
    model = Model(image, label)

    with tf.Session(graph=model.graph) as sess:
        sess.run(tf.initialize_all_variables())

        for _ in range(10):
          images, labels = mnist.test.images, mnist.test.labels
          error = sess.run(model.error, {image: images, label: labels})
          print('Test error {:6.2f}%'.format(100 * error))
          for _ in range(60):
            images, labels = mnist.train.next_batch(100)
            sess.run(model.optimize, {image: images, label: labels})


if __name__ == '__main__':
    main()
当执行上述代码时,我收到以下错误消息

ValueError:Tensor(“error/Const:0”,shape=(1,),dtype=int32)必须为 从与张量相同的图形(“优化/转换:0”,形状=(?,), dtype=32)


您需要在main中定义新图形,并在其中定义占位符。完成后,可以将图形作为参数传递给类

class Model:

    def __init__(self, graph, image, label):
        self.graph = graph
        self.image = image
        self.label = label
        self.prediction
        self.optimize
        self.error

    @define_scope(initializer=tf.contrib.slim.xavier_initializer())
    def prediction(self):
        with self.graph.as_default():
            x = self.image
            x = tf.contrib.slim.fully_connected(x, 200)
            x = tf.contrib.slim.fully_connected(x, 200)
            x = tf.contrib.slim.fully_connected(x, 10, tf.nn.softmax)
            return x

    @define_scope
    def optimize(self):
        with self.graph.as_default():
            current_error=self.error
            logprob = tf.log(self.prediction + 1e-12) *(1-current_error)  #Here changed ????????????

            cross_entropy = -tf.reduce_sum(self.label * logprob)
            optimizer = tf.train.RMSPropOptimizer(0.03)
            trainop = optimizer.minimize(cross_entropy) 
            return trainop

    @define_scope
    def error(self):
        with self.graph.as_default():
            mistakes = tf.not_equal(
                tf.argmax(self.label, 1), tf.argmax(self.prediction, 1))
            me = tf.reduce_mean(tf.cast(mistakes, tf.float32))
            return me

    # @define_scope
    # def accuracy(self):
    #     accuracy = tf.reduce_sum()


def main():
    mnist = input_data.read_data_sets('../../MNIST_data/', one_hot=True)
    graph1 = tf.Graph()
    with graph1.as_default():
        image = tf.placeholder(tf.float32, [None, 784])
        label = tf.placeholder(tf.float32, [None, 10])
        model = Model(graph1, image, label)

        with tf.Session(graph=graph1) as sess:
            sess.run(tf.initialize_all_variables())

            for _ in range(10):
              images, labels = mnist.test.images, mnist.test.labels
              error = sess.run(model.error, {image: images, label: labels})
              print('Test error {:6.2f}%'.format(100 * error))
              for _ in range(60):
                images, labels = mnist.train.next_batch(100)
                sess.run(model.optimize, {image: images, label: labels})


if __name__ == '__main__':
    main()