Python Keras variable()内存泄漏

Python Keras variable()内存泄漏,python,tensorflow,memory-leaks,keras,Python,Tensorflow,Memory Leaks,Keras,我对Keras和tensorflow都是新手,有一个问题。我正在使用一些损失函数(主要是二进制交叉熵和均方误差)来计算预测后的损失。因为Keras只接受它自己的变量类型,所以我创建了一个,并将其作为参数提供。此场景在循环中执行(使用睡眠),如下所示: 获取合适的数据->预测->计算损失->归还 因为我有多个模型都遵循这种模式,所以我创建了tensorflow图和会话以防止冲突(同样,在导出模型的权重时,我遇到了单个图和会话的问题,因此我必须为每个模型创建不同的图和会话) 然而,现在内存正在不可控

我对Keras和tensorflow都是新手,有一个问题。我正在使用一些损失函数(主要是二进制交叉熵和均方误差)来计算预测后的损失。因为Keras只接受它自己的变量类型,所以我创建了一个,并将其作为参数提供。此场景在循环中执行(使用睡眠),如下所示:

获取合适的数据->预测->计算损失->归还

因为我有多个模型都遵循这种模式,所以我创建了tensorflow图和会话以防止冲突(同样,在导出模型的权重时,我遇到了单个图和会话的问题,因此我必须为每个模型创建不同的图和会话)

然而,现在内存正在不可控制地增长,在几次迭代中从几个MiB增加到700MiB。我知道Keras的clear_session()和gc.collect(),我在每次迭代结束时都使用它们,但问题仍然存在。这里我提供了一个代码片段,它不是项目中的实际代码。为了隔离问题,我创建了单独的脚本:

import tensorflow as tf

from keras import backend as K
from keras.losses import binary_crossentropy, mean_squared_error

from time import time, sleep
import gc
from numpy.random import rand

from os import getpid
from psutil import Process

from csv import DictWriter
from keras import backend as K

this_process = Process(getpid())

graph = tf.Graph()
sess = tf.Session(graph=graph)

cnt = 0
max_c = 500

with open('/home/quark/Desktop/python-test/leak-7.csv', 'a') as file:
    writer = DictWriter(file, fieldnames=['time', 'mem'])
    writer.writeheader()

    while cnt < max_c:  
        with graph.as_default(), sess.as_default():         
            y_true = K.variable(rand(36, 6))
            y_pred = K.variable(rand(36, 6))

            rec_loss = K.eval(binary_crossentropy(y_true, y_pred))
            val_loss = K.eval(mean_squared_error(y_true, y_pred))

            writer.writerow({
                'time': int(time()),
                'mem': this_process.memory_info().rss
            })

        K.clear_session()
        gc.collect()

        cnt += 1
        print(max_c - cnt)
        sleep(0.1)
将tensorflow导入为tf
从keras导入后端为K
从keras.com导入二进制交叉熵,均方误差
从时间导入时间,睡眠
导入gc
从numpy.random导入rand
从操作系统导入getpid
从psutil导入过程
从csv导入DictWriter
从keras导入后端为K
此进程=进程(getpid())
graph=tf.graph()
sess=tf.Session(graph=graph)
cnt=0
最大值c=500
打开('/home/quark/Desktop/python test/leak-7.csv',a')作为文件:
writer=DictWriter(文件,字段名=['time','mem'])
writer.writeheader()
当cnt
此外,我还添加了内存使用情况图:


非常感谢您的帮助。

我刚刚删除了带有
语句的
(可能是一些tf代码),没有发现任何泄漏。我认为keras会话和tf默认会话之间存在差异。因此,您没有使用
K.clear_session()
清除正确的会话。可能使用
tf.reset\u default\u graph()
也可以

while True: 
    y_true = K.variable(rand(36, 6))
    y_pred = K.variable(rand(36, 6))

    val_loss = K.eval(binary_crossentropy(y_true, y_pred))
    rec_loss = K.eval(mean_squared_error(y_true, y_pred))

    K.clear_session()
    gc.collect()

    sleep(0.1)

最后,我做的是从
where
语句中删除
K.variable()。通过这种方式,变量是默认图形的一部分,稍后由
K清除。clear_session()

您可以添加所需的导入吗?我相信你混合了tf和keras命令。是的,我们可以运行一个完整的示例就好了。我已经更新了代码。我知道用
语句删除
可以解决问题,但我恐怕不能忽略它,因为我有不同的图。无论如何,谢谢,我将尝试使用
tf.reset\u default\u graph()
。然后您只需重置相应的图形,而不仅仅是默认图形