Pytorch 使用卷积自动编码器在照片上添加微笑时出现问题

Pytorch 使用卷积自动编码器在照片上添加微笑时出现问题,pytorch,autoencoder,Pytorch,Autoencoder,我有一个包含图像的数据集,另一个数据集的描述如下: 这里有很多照片:戴太阳镜和不戴太阳镜的人,微笑和其他特征。我想做的是能够在人们不微笑的照片中添加微笑。 我是这样开始的: smile_ids = attrs['Smiling'].sort_values(ascending=False).iloc[100:125].index.values smile_data = data[smile_ids] no_smile_ids = attrs['Smiling'].sort_values(asc

我有一个包含图像的数据集,另一个数据集的描述如下:

这里有很多照片:戴太阳镜和不戴太阳镜的人,微笑和其他特征。我想做的是能够在人们不微笑的照片中添加微笑。 我是这样开始的:

smile_ids = attrs['Smiling'].sort_values(ascending=False).iloc[100:125].index.values
smile_data = data[smile_ids]

no_smile_ids = attrs['Smiling'].sort_values(ascending=True).head(5).index.values
no_smile_data = data[no_smile_ids]

eyeglasses_ids = attrs['Eyeglasses'].sort_values(ascending=False).head(25).index.values
eyeglasses_data = data[eyeglasses_ids]

sunglasses_ids = attrs['Sunglasses'].sort_values(ascending=False).head(5).index.values
sunglasses_data = data[sunglasses_ids]
def plot_gallery(images, h, w, n_row=3, n_col=6, with_title=False, titles=[]):
plt.figure(figsize=(1.5 * n_col, 1.7 * n_row))
plt.subplots_adjust(bottom=0, left=.01, right=.99, top=.90, hspace=.35)
for i in range(n_row * n_col):
    plt.subplot(n_row, n_col, i + 1)
    try:
        plt.imshow(images[i].reshape((h, w, 3)), cmap=plt.cm.gray, vmin=-1, vmax=1, interpolation='nearest')
        if with_title:
            plt.title(titles[i])
        plt.xticks(())
        plt.yticks(())
    except:
        pass
当我打印它们时,它们很好:

plot_gallery(smile_data, IMAGE_H, IMAGE_W, n_row=5, n_col=5, with_title=True, titles=smile_ids)

Plot gallery如下所示:

smile_ids = attrs['Smiling'].sort_values(ascending=False).iloc[100:125].index.values
smile_data = data[smile_ids]

no_smile_ids = attrs['Smiling'].sort_values(ascending=True).head(5).index.values
no_smile_data = data[no_smile_ids]

eyeglasses_ids = attrs['Eyeglasses'].sort_values(ascending=False).head(25).index.values
eyeglasses_data = data[eyeglasses_ids]

sunglasses_ids = attrs['Sunglasses'].sort_values(ascending=False).head(5).index.values
sunglasses_data = data[sunglasses_ids]
def plot_gallery(images, h, w, n_row=3, n_col=6, with_title=False, titles=[]):
plt.figure(figsize=(1.5 * n_col, 1.7 * n_row))
plt.subplots_adjust(bottom=0, left=.01, right=.99, top=.90, hspace=.35)
for i in range(n_row * n_col):
    plt.subplot(n_row, n_col, i + 1)
    try:
        plt.imshow(images[i].reshape((h, w, 3)), cmap=plt.cm.gray, vmin=-1, vmax=1, interpolation='nearest')
        if with_title:
            plt.title(titles[i])
        plt.xticks(())
        plt.yticks(())
    except:
        pass
那么我会:

def to_latent(pic):
with torch.no_grad():
    inputs = torch.FloatTensor(pic.reshape(-1, 45*45*3))
    inputs = inputs.to('cpu')
    autoencoder.eval()
    output = autoencoder.encode(inputs)        
    return output

def from_latent(vec):
with torch.no_grad():
    inputs = vec.to('cpu')
    autoencoder.eval()
    output = autoencoder.decode(inputs)        
    return output
之后:

smile_latent = to_latent(smile_data).mean(axis=0)
no_smile_latent = to_latent(no_smile_data).mean(axis=0)
sunglasses_latent = to_latent(sunglasses_data).mean(axis=0)

smile_vec = smile_latent-no_smile_latent
sunglasses_vec = sunglasses_latent - smile_latent
最后:

def add_smile(ids):
for id in ids:
    pic = data[id:id+1]
    latent_vec = to_latent(pic)
    latent_vec[0] += smile_vec
    pic_output = from_latent(latent_vec)
    pic_output = pic_output.view(-1,45,45,3).cpu()
    plot_gallery([pic,pic_output], IMAGE_H, IMAGE_W, n_row=1, n_col=2)
    
def add_sunglasses(ids):
for id in ids:
    pic = data[id:id+1]
    latent_vec = to_latent(pic)
    latent_vec[0] += sunglasses_vec
    pic_output = from_latent(latent_vec)
    pic_output = pic_output.view(-1,45,45,3).cpu()
    plot_gallery([pic,pic_output], IMAGE_H, IMAGE_W, n_row=1, n_col=2)
但是当我执行这一行时,我没有得到任何面孔:

add_smile(no_smile_ids)
输出:

有人能解释一下我的错误在哪里或者为什么会发生吗?谢谢你的帮助

添加:检查PICU输出的形状:


胡乱猜测,但你似乎是在广播图像,而不是排列轴。前者会产生跨批次/通道混合信息的不良影响

pic_output = pic_output.view(-1, 45, 45, 3).cpu()
应替换为

pic_output = pic_output.permute(0, 2, 3, 1).cpu()

假设张量
picu输出
的形状已经是
(-1,3,45,45)

如果我这样做,我会得到运行时错误:运行时错误:DIM的数量在permut中不匹配你能打印出
picu输出
的形状吗(在从调用
)当然。如果我更改行-出现错误,所以我在这行之后打印了它(没有更改):pic_output=pic_output.view(-1,45,45,3).cpu()输出是:torch.Size([1,45,45,3])你能在广播之前打印它的形状吗(在调用
pic_output.view(-1,45,45,3).cpu()
之前)?成功了,谢谢!