Python 为什么塔诺不更新?
我打算解决以下问题:theano函数的输出值是类方法在执行while循环后返回的值,在该循环中更新参数:Python 为什么塔诺不更新?,python,python-3.x,deep-learning,theano,Python,Python 3.x,Deep Learning,Theano,我打算解决以下问题:theano函数的输出值是类方法在执行while循环后返回的值,在该循环中更新参数: import theano import theano.tensor as T import numpy as np import copy theano.config.exception_verbosity = 'high' class Test(object): def __init__(self): self.rate=0.01 W_val=4
import theano
import theano.tensor as T
import numpy as np
import copy
theano.config.exception_verbosity = 'high'
class Test(object):
def __init__(self):
self.rate=0.01
W_val=40.00
self.W=theano.shared(value=W_val, borrow=True)
def start(self, x, y):
for i in range(5):
z=T.mean(x*self.W/y)
gz=T.grad(z, self.W)
self.W-=self.rate*gz
return z
x_set=np.array([1.,2.,1.,2.,1.,2.,1.,2.,1.,2.])
y_set=np.array([1,2,1,2,1,2,1,2,1,2])
x_set = theano.shared(x_set, borrow=True)
y_set = theano.shared(y_set, borrow=True)
y_set=T.cast(y_set, 'int32')
batch_size=2
x = T.dvector('x')
y = T.ivector('y')
index = T.lscalar()
test = Test()
cost=test.start(x,y)
train = theano.function(
inputs=[index],
outputs=cost,
givens={
x: x_set[index * batch_size: (index + 1) * batch_size],
y: y_set[index * batch_size: (index + 1) * batch_size]
}
)
for i in range(5):
result=train(i)
print(result)
这是打印的结果:
39.96000000089407
39.96000000089407
39.96000000089407
39.96000000089407
39.96000000089407
现在平均值的梯度(x*W/y)等于1(因为x和y总是有相同的值)。所以我第一次应该是39.95,比39.90等等。。。
为什么我总是有相同的结果
谢谢在我的朋友帕斯卡的帮助下,我得到了这个结果。解决方案是创建其他符号变量:
class Test(object):
def __init__(self):
self.rate=0.01
W_val=40.00
self.W=theano.shared(value=W_val, borrow=True)
def start(self, x, y):
new_W=self.W
for i in range(5):
z=T.mean(x*new_W/y)
gz=T.grad(z, new_W)
new_W-=self.rate*gz
return z, (self.W, new_W)
并修改theano函数:
test = Test()
cost, updates=test.start(x,y)
train = theano.function(
inputs=[index],
outputs=cost,
updates=[updates],
givens={
x: x_set[index * batch_size: (index + 1) * batch_size],
y: y_set[index * batch_size: (index + 1) * batch_size]
}
)
输出:
39.96000000089407
39.91000000201166
39.860000003129244
39.81000000424683
39.76000000536442
在你的开始循环中。在每次迭代中,你都会覆盖你的z变量,这难道不是原因吗?@PepaKorbel:我必须覆盖z,但我更新了W的值,所以z的值必须改变,不是吗?哦,是的,我仔细检查了一下,你正在使用z变量更新权重。那一定在某个地方else@PepaKorbel是的,但是在哪里??我快疯了…你的x和y张量是一样的?