PyTorch: Understanding "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"
Here is the problem setup. A function that converts RGB to HSV:
import torch

def rbg2hsv(img_rgb):
    batch_size, channel, height, width = img_rgb.size()
    r, g, b = img_rgb[:, 0, :, :], img_rgb[:, 1, :, :], img_rgb[:, 2, :, :]
    # t: per-pixel minimum over the channels; v: per-pixel maximum (the HSV value)
    t = torch.min(img_rgb, dim=1, keepdim=False)[0]
    v = torch.max(img_rgb, dim=1, keepdim=False)[0]
    s = (v - t) / (v + 1e-6)
    s[v == 0] = 0  # in-place write into s
    # v==r
    hr = 60 * (g - b) / (v - t + 1e-6)
    # v==g
    hg = 120 + 60 * (b - r) / (v - t + 1e-6)
    # v==b
    hb = 240 + 60 * (r - g) / (v - t + 1e-6)
    # allocate h on the same device as the input, so CPU inputs also work on a CUDA machine
    h = torch.zeros(batch_size, height, width, requires_grad=False, device=img_rgb.device)
    h = h.flatten()
    hr = hr.flatten()
    hg = hg.flatten()
    hb = hb.flatten()
    # in-place masked writes; later assignments override earlier ones where masks overlap
    h[(v == b).flatten()] = hb[(v == b).flatten()]
    h[(v == g).flatten()] = hg[(v == g).flatten()]
    h[(v == r).flatten()] = hr[(v == r).flatten()]
    h[h < 0] += 360  # wrap negative hues into [0, 360)
    h = torch.reshape(h, (batch_size, height, width))
    # stack to (3, B, H, W), then permute to (B, 3, H, W)
    img_hsv = torch.stack([h, s, v])
    img_hsv = img_hsv.permute(1, 0, 2, 3)
    return img_hsv
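Variant 2: FAIL
It fails with the error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 2, 32, 32]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).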
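The mechanism behind this message can be reproduced in a few lines. The sketch below is not taken from the question: `demo_inplace_error` is a hypothetical helper, and it reads the tensor attribute `_version`, an internal detail that may change between PyTorch releases. The idea, as far as I understand it, is that autograd records the version of each tensor it saves for the backward pass, and every in-place write increments that counter, so a mismatch at backward time raises the error.

```python
import torch

def demo_inplace_error():
    """Hypothetical helper: trigger the in-place error on a tiny tensor."""
    x = torch.rand(4, requires_grad=True)
    y = x.exp()                    # exp saves its output y for the backward pass
    version_before = y._version    # internal version counter (assumption: may change)
    y[0] = 0.0                     # in-place write; bumps y's version counter
    version_after = y._version
    try:
        y.sum().backward()         # the saved y no longer matches the recorded version
        message = None
    except RuntimeError as err:
        message = str(err)
    return version_before, version_after, message

if __name__ == "__main__":
    before, after, message = demo_inplace_error()
    print("version before:", before, "after:", after)
    print("error:", message)
```

Running the same code with `torch.autograd.set_detect_anomaly(True)` enabled additionally reports the traceback of the forward operation whose saved tensor was overwritten.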
If I change rbg2hsv to some dummy add_fn, all three variants work:
import torch
import torch.nn as nn

def add_fn(x):
    x = x + 1
    return x

img_rgb = torch.rand(2, 3, 32, 32)
img_rgb.requires_grad = True
instance_norm = nn.InstanceNorm2d(1, affine=False)
gamma = 2
beta = 1
img_hsv = add_fn(img_rgb)
# 1: OK
#img_hsv[:,2:3,:,:] = img_hsv[:,2:3,:,:] * gamma + beta
# 2: OK
img_hsv[:,2:3,:,:] = instance_norm(img_hsv[:,2:3,:,:])
# 3: OK
#img_hsv[:,2:3,:,:] = instance_norm(img_hsv[:,2:3,:,:].clone())
img_hsv.mean().backward()
print('img_rgb.grad.size()', img_rgb.grad.size())
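So, my questions are:
Why does changing rbg2hsv to add_fn make all variants work? (That is, what is wrong with rbg2hsv?)
What are the rules for when .clone() is needed and when it is not, i.e. when does an operation count as in-place and when does it not?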
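On the second question, here is a hedged sketch of one way to picture the rule; it is not a full answer. `.clone()` seems to matter when an operation saves its input for the backward pass and that input is a view of the tensor that is overwritten afterwards, because a view shares its version counter with its base. `run_variant` is a hypothetical helper, and `** 2` stands in for any operation that saves its input. Whether `nn.InstanceNorm2d` saves the view itself or an internal copy depends on implementation details (for example, whether the slice is contiguous), which this sketch does not settle.

```python
import torch

def run_variant(use_clone):
    """Hypothetical helper: return (grad_shape, error_message) for one variant."""
    x = torch.rand(2, 3, 4, 4, requires_grad=True)
    y = x + 1                      # non-leaf tensor, similar to the output of add_fn
    piece = y[:, 2:3, :, :]        # a view: shares storage and version counter with y
    if use_clone:
        piece = piece.clone()      # independent copy with its own version counter
    squared = piece ** 2           # pow saves its input (piece) for the backward pass
    y[:, 2:3, :, :] = squared      # in-place write into y
    try:
        y.mean().backward()
    except RuntimeError as err:
        return None, str(err)
    return tuple(x.grad.shape), None

if __name__ == "__main__":
    print("with clone:   ", run_variant(use_clone=True))
    print("without clone:", run_variant(use_clone=False))
```

In this sketch the cloned variant completes the backward pass, while the variant that keeps the view raises the in-place error, since the write into `y` also invalidates the saved view.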