Python 在类中处理tensorflow会话

Python 在类中处理tensorflow会话,python,machine-learning,tensorflow,deep-learning,Python,Machine Learning,Tensorflow,Deep Learning,我用tensorflow来预测神经网络的输出。我有一个类,在那里我描述了神经网络,我有一个主文件,在那里进行预测,并根据结果更新权重。然而,预测似乎非常缓慢。下面是我的代码的样子: class NNPredictor(): def __init__(self): self.input = tf.placeholder(...) ... self.output = (...) #Neural network output def pr

我用tensorflow来预测神经网络的输出。我有一个类,在那里我描述了神经网络,我有一个主文件,在那里进行预测,并根据结果更新权重。然而,预测似乎非常缓慢。下面是我的代码的样子:

class NNPredictor():
    def __init__(self):
        self.input = tf.placeholder(...)
        ...
        self.output = (...) #Neural network output
    def predict_output(self, sess, input):
        return sess.run(tf.squeeze(self.output), feed_dict = {self.input: input})
sess = tf.Session()
predictor = NNPredictor()

input = #some initial value 
for i in range(iter):
    output = predictor.predict_output(sess, input)
    input = #some function of output
以下是主文件的外观:

class NNPredictor():
    def __init__(self):
        self.input = tf.placeholder(...)
        ...
        self.output = (...) #Neural network output
    def predict_output(self, sess, input):
        return sess.run(tf.squeeze(self.output), feed_dict = {self.input: input})
sess = tf.Session()
predictor = NNPredictor()

input = #some initial value 
for i in range(iter):
    output = predictor.predict_output(sess, input)
    input = #some function of output
但是,如果我在类中使用以下函数定义:

    def predict_output(self):
        return self.output
并具有如下主文件:

sess = tf.Session()
predictor = NNPredictor()

input = #some initial value 
output_op = predictor.predict_value()
for i in range(iter):
    output = np.squeeze(sess.run(output_op, feed_dict = {predictor.input: input}))
    input = #some function of output

代码运行速度快了20-30倍。我似乎不明白这里的工作原理,我想知道最佳实践是什么。

这与Python屏蔽的底层内存访问有关。下面是一些示例代码来说明这个想法:

import time

runs = 10000000

class A:
    def __init__(self):
    self.val = 1

    def get_val(self):
    return self.val

# Using method to then call object attribute
obj = A()
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.get_val()
end = time.time()
print end - start

# Using object attribute directly
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.val
end = time.time()
print end - start

# Assign to local_var first
start = time.time()
total = 0
local_var = obj.get_val()
for i in xrange(runs):
    total += local_var
end = time.time()
print end - start
在我的计算机上,它按以下时间运行:

1.49576115608
0.656110048294
0.551875114441

具体到您的情况,您在第一种情况下调用对象方法,但在第二种情况下不调用对象方法。如果您以这种方式多次调用代码,则会有显著的性能差异。

这与Python屏蔽的底层内存访问有关。下面是一些示例代码来说明这个想法:

import time

runs = 10000000

class A:
    def __init__(self):
    self.val = 1

    def get_val(self):
    return self.val

# Using method to then call object attribute
obj = A()
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.get_val()
end = time.time()
print end - start

# Using object attribute directly
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.val
end = time.time()
print end - start

# Assign to local_var first
start = time.time()
total = 0
local_var = obj.get_val()
for i in xrange(runs):
    total += local_var
end = time.time()
print end - start
在我的计算机上,它按以下时间运行:

1.49576115608
0.656110048294
0.551875114441

具体到您的情况,您在第一种情况下调用对象方法,但在第二种情况下不调用对象方法。如果您以这种方式多次调用代码,则会有显著的性能差异。

感谢您的回答,我确实意识到调用对象方法可能会带来一定的开销,但这里的开销太大了。例如,在一次迭代中,调用对象方法需要0.03秒,而在没有对象方法的情况下,通常需要0.001-0.002s。这一切可能仅仅是因为你提到的,还是还有其他原因呢?我没有对单个对象方法调用进行计时。但是如果你的计时是正确的,它大约是20倍,不是吗?是的,但是我使用的代码有点不同。那边的方法是:
返回tf.squese(self.output)
。起初,我认为它无关紧要,但现在我意识到它很慢,因为它每次都创建一个新节点,而另一个
predict\u value()
函数也没有显示。也许您可以将所有相关函数添加到您的问题中,以便将来其他人能够理解。感谢您的回答,我确实意识到调用对象方法可能会带来一定的开销,但这里的开销太高了。例如,在一次迭代中,调用对象方法需要0.03秒,而在没有对象方法的情况下,通常需要0.001-0.002s。这一切可能仅仅是因为你提到的,还是还有其他原因呢?我没有对单个对象方法调用进行计时。但是如果你的计时是正确的,它大约是20倍,不是吗?是的,但是我使用的代码有点不同。那边的方法是:
返回tf.squese(self.output)
。起初,我认为它无关紧要,但现在我意识到它很慢,因为它每次都创建一个新节点,而另一个
predict\u value()
函数也没有显示。也许你可以在你的问题中添加所有相关的功能,以便将来其他人能够理解。