为了账号安全,请及时绑定邮箱和手机立即绑定

Keras变量()内存泄漏

Keras变量()内存泄漏

杨__羊羊 2021-08-11 16:51:40
我是 Keras 和 tensorflow 的新手,并且遇到了问题。我正在使用一些损失函数(主要是 binary_crossentropy 和 mean_squared_error)来计算预测后的损失。由于 Keras 只接受它自己的变量类型,因此我创建了一个并将其作为参数提供。此场景在循环中执行(带睡眠),如下所示:获取适当的数据 -> 预测 -> 计算丢失的数据 -> 返回它。由于我有多个遵循此模式的模型,因此我创建了 tensorflow 图和会话以防止冲突(在导出模型的权重时,我遇到了单个图和会话的问题,因此我必须为每个模型创建不同的模型)。但是,现在内存不受控制地增加,在几次迭代中从几 MiB 增加到 700MiB。我知道 Keras 的 clear_session() 和 gc.collect(),我在每次迭代结束时使用它们,但问题仍然存在。这里我提供了一个来自项目的代码片段,它不是实际的代码。我创建了单独的脚本以隔离问题:import tensorflow as tffrom keras import backend as Kfrom keras.losses import binary_crossentropy, mean_squared_errorfrom time import time, sleepimport gcfrom numpy.random import randfrom os import getpidfrom psutil import Processfrom csv import DictWriterfrom keras import backend as Kthis_process = Process(getpid())graph = tf.Graph()sess = tf.Session(graph=graph)cnt = 0max_c = 500with open('/home/quark/Desktop/python-test/leak-7.csv', 'a') as file:    writer = DictWriter(file, fieldnames=['time', 'mem'])    writer.writeheader()    while cnt < max_c:          with graph.as_default(), sess.as_default():                     y_true = K.variable(rand(36, 6))            y_pred = K.variable(rand(36, 6))            rec_loss = K.eval(binary_crossentropy(y_true, y_pred))            val_loss = K.eval(mean_squared_error(y_true, y_pred))            writer.writerow({                'time': int(time()),                'mem': this_process.memory_info().rss            })        K.clear_session()        gc.collect()        cnt += 1        print(max_c - cnt)        sleep(0.1)此外,我还添加了内存使用图: Keras memory leak任何帮助表示赞赏。
查看完整描述

2 回答

?
侃侃无极

TA贡献2051条经验 获得超10个赞

最后,我所做的是K.variable()where语句中删除代码。这样,变量是默认图形的一部分,稍后由 清除K.clear_session()


查看完整回答
反对 回复 2021-08-11
  • 2 回答
  • 0 关注
  • 177 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号