Python Tensorflow在构建和训练RNN时使用过多内存
我在tensorflow中构建了一个非标准RNN,当我尝试训练它时,它使用了太多内存。仅构建网络就需要1GB的内存,训练时内存使用量高达5GB。它也很慢 网络使用200个浮点数作为其内部状态,但在每个步骤中,只有100个浮点数用作网络的输入(选择这些浮点数的规则不是网络的一部分),因此我通过从张量中获取100个值,然后使用tf.stack重新组合它们来对此进行建模。然后使用另一个tf.stack将网络的输出与未用作输入的100个值合并Python Tensorflow在构建和训练RNN时使用过多内存,python,tensorflow,optimization,recurrent-neural-network,Python,Tensorflow,Optimization,Recurrent Neural Network,我在tensorflow中构建了一个非标准RNN,当我尝试训练它时,它使用了太多内存。仅构建网络就需要1GB的内存,训练时内存使用量高达5GB。它也很慢 网络使用200个浮点数作为其内部状态,但在每个步骤中,只有100个浮点数用作网络的输入(选择这些浮点数的规则不是网络的一部分),因此我通过从张量中获取100个值,然后使用tf.stack重新组合它们来对此进行建模。然后使用另一个tf.stack将网络的输出与未用作输入的100个值合并 import tensorflow as tf from r
import tensorflow as tf
from random import sample, random
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial, dtype=tf.float32)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial, dtype=tf.float32)
def run():
sess = tf.Session()
W0 = weight_variable((100, 200))
B0 = bias_variable([200])
W1 = weight_variable((200, 200))
B1 = bias_variable([200])
W2 = weight_variable((200, 100))
B2 = bias_variable([100])
memory = tf.constant(0, tf.float32, (200,))
outputs = []
correct_outputs = []
for i in range(40):
indexes = sample(range(200), 100)
memory_selection = [memory[z] for z in indexes]
S = tf.stack(memory_selection, axis=0)
S = tf.stack([S], axis=0)
Input = tf.nn.relu(tf.matmul(S, W0) + B0)
Hidden = tf.nn.relu(tf.matmul(Input, W1) + B1)
Output = tf.nn.relu(tf.matmul(Hidden, W2) + B2)
memory_output = []
used = 0
for z in range(200):
if z in indexes:
memory_output.append(Output[0][used])
used += 1
else:
memory_output.append(memory[z])
memory = tf.stack(memory_output, axis=0)
outputs.append(Output)
correct_outputs.append([random() for _ in range(100)])
print("Network Built")
outputs = tf.stack(outputs, axis=0)
correct_outputs = tf.constant(correct_outputs)
loss = tf.nn.l2_loss(outputs-correct_outputs)
optimiser = tf.contrib.optimizer_v2.AdamOptimizer(0.0001)
train = optimiser.minimize(loss)
sess.run(tf.global_variables_initializer())
print("Variables Initialized")
sess.run(train)
print("Trained")
run()
我最终得到了一个40步长的展开网络。
堆栈和取消堆栈是否会导致内存问题?
大部分时间都花在渐变上
编辑:我已经用PyTorch重写了上面的代码,与tensorflow所花的分钟相比,它在不到一秒钟的时间内运行得更快