TensorFlow (de)convolutional recurrent LSTM autoencoder - error jumps around


I am trying to build a convolutional LSTM autoencoder in TensorFlow that can also predict the future and the past. It works to some extent, but the error sometimes jumps back up, so essentially it never converges.

The model works as follows:

For each time step of the LSTM, the encoder starts with a 64x64 frame of a 20-frame bouncing MNIST video. Each stacked LSTM layer halves the spatial resolution and increases the depth through a 2x2 convolution with stride 2, so the shapes go 32x32x3 -> ... -> 1x96 (a minimal shape check is sketched below). In parallel, the LSTM performs a 3x3 convolution with stride 1 on its state, and the two results are concatenated to form the new state. The decoder mirrors this with transposed convolutions to return to the original format, and the squared error is then computed.
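
As a minimal illustrative shape check (a dummy zero frame and zero kernel, not part of the actual model), a 2x2 convolution with stride 2 and VALID padding halves the spatial resolution:

import tensorflow as tf

frame = tf.zeros([1, 64, 64, 1])     #one 64x64 single-channel frame
kernel = tf.zeros([2, 2, 1, 3])      #2x2 kernel, 1 input channel, 3 output features
halved = tf.nn.conv2d(frame, kernel, [1, 2, 2, 1], 'VALID')
print(halved.shape)                  #(1, 32, 32, 3)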

The error starts at around 2700 and takes about 20 hours to drop to roughly 1700. At that point it jumps up again, sometimes back to 2300 or even to absurd values like 440300. This happens so often that I cannot really get it any lower. Also at that point, the model can usually pinpoint where the digits should be, but the output is too blurry to actually make them out.

I have tried different learning rates and optimizers, so if anyone knows why these jumps appear, I would be glad to hear it:

Below is a plot of the loss across the epochs:


Typically this happens when the magnitude of the gradients becomes very large at some point and causes the network parameters to change a lot. To verify that this is indeed what is going on, you can produce the same kind of plot for the gradient magnitudes and see whether they jump right before the loss does. Assuming that is the case, the classic approach is to use gradient clipping (or to go all the way and use natural gradient).
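
A minimal sketch of how both suggestions could be wired into the question's existing RMSProp setup (the clip_norm of 5.0 is an arbitrary assumption, and opt, grad_norm and clipped_grads are names introduced here for illustration):

opt = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE)
grads_and_vars = opt.compute_gradients(error)
grads, variables = zip(*grads_and_vars)

#global gradient norm; fetch it in session.run alongside the loss and plot both
grad_norm = tf.global_norm(grads)

#rescale the gradients whenever their global norm exceeds clip_norm
clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)
train_fn = opt.apply_gradients(zip(clipped_grads, variables))

#inside the training loop, something like:
#err, gnorm, _ = session.run([error, grad_norm, train_fn], feed_dict)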

import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

#based on code by loliverhennigh (Github)
class ConvCell(tf.contrib.rnn.RNNCell):
    count = 0   #exists only to remove issues with variable scope
    def __init__(self, shape, num_features, transpose = False):
        self.shape = shape 
        self.num_features = num_features
        self._state_is_tuple = True
        self._transpose = transpose
        ConvCell.count+=1
        self.count = ConvCell.count

    @property
    def state_size(self):
        return (tf.contrib.rnn.LSTMStateTuple(self.shape[0:4],self.shape[0:4]))

    @property
    def output_size(self):
        return tf.TensorShape(self.shape[1:4])

#here comes the actual conv lstm implementation; if transpose is True, it performs a deconvolution on the input
    def __call__(self, inputs, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__+str(self.count)): 
            c, h = state
            state_shape = h.shape
            input_shape = inputs.shape

            #filter variables and convolutions on data coming from the same cell, a time step previous
            h_filters = tf.get_variable("h_filters",[3,3,state_shape[3],self.num_features])
            h_filters_gates = tf.get_variable("h_filters_gates",[3,3,state_shape[3],3])
            h_partial = tf.nn.conv2d(h,h_filters,[1,1,1,1],'SAME')
            h_partial_gates = tf.nn.conv2d(h,h_filters_gates,[1,1,1,1],'SAME')

            c_filters = tf.get_variable("c_filters",[3,3,state_shape[3],3])
            c_partial = tf.nn.conv2d(c,c_filters,[1,1,1,1],'SAME')

            #filters and convolutions/deconvolutions on data coming from the cell input
            if self._transpose:
                x_filters = tf.get_variable("x_filters",[2,2,self.num_features,input_shape[3]])
                x_filters_gates = tf.get_variable("x_filters_gates",[2,2,3,input_shape[3]])
                x_partial = tf.nn.conv2d_transpose(inputs,x_filters,[int(state_shape[0]),int(state_shape[1]),int(state_shape[2]),self.num_features],[1,2,2,1],'VALID')
                x_partial_gates = tf.nn.conv2d_transpose(inputs,x_filters_gates,[int(state_shape[0]),int(state_shape[1]),int(state_shape[2]),3],[1,2,2,1],'VALID')
            else:
                x_filters = tf.get_variable("x_filters",[2,2,input_shape[3],self.num_features])
                x_filters_gates = tf.get_variable("x_filters_gates",[2,2,input_shape[3],3])
                x_partial = tf.nn.conv2d(inputs,x_filters,[1,2,2,1],'VALID')
                x_partial_gates = tf.nn.conv2d(inputs,x_filters_gates,[1,2,2,1],'VALID')

            #some more lstm gate business
            gate_bias = tf.get_variable("gate_bias",[1,1,1,3])
            h_bias = tf.get_variable("h_bias",[1,1,1,self.num_features*2])

            gates = h_partial_gates + x_partial_gates + c_partial + gate_bias

            i,f,o = tf.split(gates,3,axis=3)

            #concatenate the units coming from the spatial and the temporal dimension to build a unified state
            concat = tf.concat([h_partial,x_partial],3) + h_bias

            new_c = tf.nn.relu(concat)*tf.sigmoid(i)+c*tf.sigmoid(f)
            new_h = new_c * tf.sigmoid(o)

            new_state = tf.contrib.rnn.LSTMStateTuple(new_c,new_h)
            return new_h, new_state #it's redundant, but that's how tensorflow likes it, apparently


#global variables               
LEARNING_RATE = 0.005
ITERATIONS_PER_EPOCH = 80
BATCH_SIZE = 75
TEST = False    #manual switch to go from training to testing

if TEST:
    BATCH_SIZE = 1

inputs  = tf.placeholder(tf.float32, (20, BATCH_SIZE, 64, 64,1))    


shape0 = [BATCH_SIZE,64,64,2]
shape1 = [BATCH_SIZE,32,32,6]
shape2 = [BATCH_SIZE,16,16,12]
shape3 = [BATCH_SIZE,8,8,24]
shape4 = [BATCH_SIZE,4,4,48]
shape5 = [BATCH_SIZE,2,2,96]
shape6 = [BATCH_SIZE,1,1,192]

#apparently tf.multirnncell has very specific requirements for the initial states oO
initial_state1 = (tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape1),tf.zeros(shape1)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape2),tf.zeros(shape2)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape3),tf.zeros(shape3)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape4),tf.zeros(shape4)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape5),tf.zeros(shape5)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape6),tf.zeros(shape6)))
initial_state2 = (tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape5),tf.zeros(shape5)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape4),tf.zeros(shape4)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape3),tf.zeros(shape3)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape2),tf.zeros(shape2)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape1),tf.zeros(shape1)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape0),tf.zeros(shape0)))

#encoding part of the autoencoder graph
cell1 = ConvCell(shape1,3)
cell2 = ConvCell(shape2,6)
cell3 = ConvCell(shape3,12)
cell4 = ConvCell(shape4,24)
cell5 = ConvCell(shape5,48)
cell6 = ConvCell(shape6,96)

mcell = tf.contrib.rnn.MultiRNNCell([cell1,cell2,cell3,cell4,cell5,cell6])

rnn_outputs, rnn_states = tf.nn.dynamic_rnn(mcell, inputs[0:20,:,:,:],initial_state=initial_state1,dtype=tf.float32, time_major=True)


#decoding part of the autoencoder graph, forward block and backwards block
cell9a = ConvCell(shape5,48,transpose = True)
cell10a = ConvCell(shape4,24,transpose = True)
cell11a = ConvCell(shape3,12,transpose = True)
cell12a = ConvCell(shape2,6,transpose = True)
cell13a = ConvCell(shape1,3,transpose = True)
cell14a = ConvCell(shape0,1,transpose = True)

mcella = tf.contrib.rnn.MultiRNNCell([cell9a,cell10a,cell11a,cell12a,cell13a,cell14a])

cell9b = ConvCell(shape5,48,transpose = True)
cell10b = ConvCell(shape4,24,transpose = True)
cell11b= ConvCell(shape3,12,transpose = True)
cell12b = ConvCell(shape2,6,transpose = True)
cell13b = ConvCell(shape1,3,transpose = True)
cell14b = ConvCell(shape0,1,transpose = True)

mcellb = tf.contrib.rnn.MultiRNNCell([cell9b,cell10b,cell11b,cell12b,cell13b,cell14b])

def PredictionLayer(rnn_outputs,viewPoint = 11, reverse = False):

    predLength = viewPoint-2 if reverse else 20-viewPoint
    #vision is the input for the decoder: the encoder output at the chosen time step, padded with zero frames
    vision = tf.concat([rnn_outputs[viewPoint-1:viewPoint,:,:,:],tf.zeros([predLength,BATCH_SIZE,1,1,192])],0)

    if reverse:
        rnn_outputs2, rnn_states = tf.nn.dynamic_rnn(mcellb, vision, initial_state = initial_state2, time_major=True)
    else:
        rnn_outputs2, rnn_states = tf.nn.dynamic_rnn(mcella, vision, initial_state = initial_state2, time_major=True)


    mean = tf.reduce_mean(rnn_outputs2,4)

    if TEST:
        return mean

    if reverse:
        return tf.reduce_sum(tf.square(mean-inputs[viewPoint-2::-1,:,:,:,0]))
    else:
        return tf.reduce_sum(tf.square(mean-inputs[viewPoint-1:20,:,:,:,0]))



if TEST:
    mean = tf.concat([PredictionLayer(rnn_outputs,11,True)[::-1,:,:,:],PredictionLayer(rnn_outputs,11)],0)
else:   #training part of the graph
    error = tf.zeros([1])
    for i in range(8,15): #range size of 7 or less works, 9 or more does not, no idea why
        error += PredictionLayer(rnn_outputs, i)
        error += PredictionLayer(rnn_outputs, i, True)

    train_fn = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE).minimize(error)



################################################################################
##                           TRAINING LOOP                                    ##
################################################################################
#code based on siemanko/tf_lstm.py (Github)

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
saver = tf.train.Saver(restore_sequentially=True, allow_empty=True,)
session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
session.run(tf.global_variables_initializer())
vids = np.load("mnist_test_seq.npy") #20/10000/64/64 , moving mnist dataset from http://www.cs.toronto.edu/~nitish/unsupervised_video/
vids = vids[:,0:6000,:,:]   #training set
saver.restore(session,tf.train.latest_checkpoint('./conv_lstm_multiples_v2/'))
#saver.restore(session,'.\conv_lstm_multiples\iteration-74')


for epoch in range(1000):
    if TEST:
        break
    epoch_error = 0

    #randomize batches each epoch
    vids = np.swapaxes(vids,0,1)
    np.random.shuffle(vids)
    vids = np.swapaxes(vids,0,1)


    for i in range(ITERATIONS_PER_EPOCH):
        #running the graph and feeding data
        err,_ = session.run([error, train_fn], {inputs: np.expand_dims(vids[:,i*BATCH_SIZE:(i+1)*BATCH_SIZE,:,:],axis=4)})

        print(err)
        epoch_error += err

    #training error each epoch and regular saving
    epoch_error /= (ITERATIONS_PER_EPOCH*BATCH_SIZE*4096*20*7)
    if (epoch+1) % 5 == 0:
        saver.save(session,'./conv_lstm_multiples_v2/iteration',global_step=epoch)
        print("saved")
    print("Epoch %d, train error: %f" % (epoch, epoch_error))

#testing
plt.ion()
f, axarr = plt.subplots(2)
vids = np.load("mnist_test_seq.npy")

for i in range(6000,10000):
    img = session.run([mean], {inputs: np.expand_dims(vids[:,i:i+1,:,:],axis=4)})
    for j in range(20):
        axarr[0].imshow(img[0][j,0,:,:])
        axarr[1].imshow(vids[j,i,:,:])
        plt.show()
        plt.pause(0.1)