为了账号安全,请及时绑定邮箱和手机立即绑定

深度网络模型持久化

标签:
人工智能

一、持久化概述

在 torch 中,以下对象可以持久化到硬盘,并可以通过相应的方法把这些对象持久化到内存中:

  • Tensor
  • Variable
  • nn.Module
  • Optimizer

上述对象本质上最后都是保存为 Tensor。并且 Tensor 的保存和加载非常简单,使用 t.save 和 t.load 即可。

在 save/load 时可指定使用的 pickle 模块,在 load 时还可以把 GPU tensor 映射到 CPU 或者其他 GPU 上。

我们可以通过 t.save(obj, file_name) 保存任意可序列化的对象,然后通过 obj=t.load(file_name) 方法加载保存的数据。

对于 Module 和 Optimizer 对象,建议保存为对应的 state_dict,而不是直接保存整个 Module/Optimizer 对象。Optimizer 对象保存的是参数和动量信息,通过加载之前的动量信息,能够很有效地减少模型震荡。

二、tensor 对象的保存和加载

import torch as t

a = t.Tensor(3, 4)
if t.cuda.is_available():
    a = a.cuda(1)  # 把 a 转为 GPU1 上的 tensor
    t.save(a, 'a.pth')

    # 加载为 b,存储于 GPU1 上(因为保存时 tensor 就在 GPU1 上)
    b = t.load('a.pth')

    # 加载为 c,存储于 CPU
    c = t.load('a.pth', map_location=lambda storage, loc: storage)

    # 加载为 d,存储于 GPU0 上
    d = t.load('a.pth', map_location={'cuda:1': 'cuda:0'})

三、Module 对象的保存和加载

t.set_default_tensor_type('torch.FloatTensor')
from torchvision.models import AlexNet

model = AlexNet()
# module 的 state_dict 是一个字典
model.state_dict().keys()

t.save(model.state_dict(), 'alexnet.pth')
model.load_state_dict(t.load('alexnet.pth'))

四、Optimizer 对象的保存和加载

optimizer = t.optim.Adam(model.parameters(), lr=0.1)
t.save(optimizer.state_dict(), 'optimizer.pth')
optimizer.load_state_dict(t.load('optimizer.pth'))

五、所有对象集合的保存和加载

all_data = dict(optimizer=optimizer.state_dict(),
                model=model.state_dict(),
                info=u'模型和优化器的所有参数')
t.save(all_data, 'all.pth')

all_data = t.load('all.pth')
all_data.keys()

dict_keys([‘optimizer’, ‘model’, ‘info’])
六、总结
本章介绍了 torch 的很多工具模块,主要涉及数据加载、可视化和 GPU 加速相关的内容,合理地使用这些模块可以极大地提升我们的编码效率。

作者:二十三岁的有德
原文出处:https://www.cnblogs.com/nickchen121/p/14723819.html

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
Web前端工程师
手记
粉丝
14
获赞与收藏
47

关注作者,订阅最新文章

阅读免费教程

  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消