为了账号安全,请及时绑定邮箱和手机立即绑定

tensorflow中使用save和restore保存和恢复模型

我们在训练模型过程中,有时训练一段时间后,往往想要在验证集上验证一下,模型是否存在过拟合,然后视验证情况,再选择继续训练还是修改模型参数。这时tensorflow提供的Saver类,就能很好的帮助到我们。
当我们保存一个模型到指定路径后,还目录下将会出现四种类型的文件:

checkpoint: 具有最近检查点列表的协议缓存区
.data: 保存模型中的变量
.index: 标识检查点
.meta: 保存模型中计算图的结构信息

1、tf.train.Saver( )

首先需要在程序中定义一个saver操作,该定义在会话结构之外。

import tensorflow as tf
...
saver = tf.train.Saver()
...with tf.Session() as sess:
    ...

这样一个saver操作就定义好了。tf.train.Saver( )有几个我们平时常用到的参数,具体如下:

max_to_keep: 设置保存最近的检查点文件的个数,例如max_to_keep=4,就是只保存最新的四个模型。
keep_checkpoint_every_n_hours: 设置每隔多长时间保存一次模型。
savable_variables: 可以设置将要保存的tensor。如tf.train.Saver([w1, w2]),就是只保存w1和w2。如果不指定任何想要保存的tensor,saver默认保存所有的tensor。

2、saver.save( )

在使用tf.train.Saver( )创建了saver操作之后,我们就可以在一个会话中保存我们的模型。

...with tf.Session() as sess:
    ...    for epoch in range(10):
        ...
        saver.save(sess, model_path, global_step=epoch, write_meta_graph=False)
        ...

使用上面代码中的saver.save( )就可以按照我们的要求保存模型。其中参数说明如下:

sess: 会话对象
model_path: 模型保存的路径
global_step=epoch: 可选,在我们保存的文件名字中,加上迭代次数,以方便我们区分保                   存的文件是经过多少次的训练迭代。如global_step = 2,则我们保存的文件名字为-2.data-00000-of-00001,-2.index,-2.meta。
write_meta_graph: 可选,False: 只保存一次.meta文件;True:根据我们设置的保存次   数,保存多次.meta文件。这里对这个参数加一点说明:因为模型一旦建立好之后,计算图的结构就确定了,所以每次保存的.meta文件都是一样的,有时为了节省存储空间,我们选择只保存一次.meta文件。

3、saver.restore( )

在保存了一个模型之后,我们使用saver.restore( )来恢复模型。恢复操作也需要在session会话中。我们可以创建一个新的会话:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model_path/-2.meta') # 以.meta文件名为-2.meta为例
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    ...

我们首先要通过saver = tf.train.import_meta_graph(‘model_path/-2.meta’)加载模型的计算图结构,然后通过saver.restore(sess, tf.train.latest_checkpoint(‘model_path’))来恢复我们保存的所有变量和操作。其中tf.train.latest_checkpoint(‘model_path’)是从最近的检查点中恢复模型。

以上就是保存和恢复全部模型的操作。在实际进行模型优化时,有时我们会对原来的模型进行修改,如增加网络的深度,重新定义精确度指标等。这时,我们就可以通过变量或操作的名字来加载指定的变量或操作。

with tf.Session() as sess:
    graph = tf.get_default_graph()    # 加载网络权重变量w1和w2,"w1:0"中,weight1为定义w1变量时指定的名字,当此tensor没有重复时,后面加上0
    w1 = graph.get_tensor_by_name("weight1:0")
    w2 = graph.get_tensor_by_name("weight2:0")    # 恢复网络中的第七全连接层,fully_connected7为定义fc7时指定的名字,当此tensor没有重复时,后面加上0
    fc7 = graph.get_tensor_by_name("fully_connected7:0")

加载到指定的变量后,我们就可以在其基础上,对原来的模型进行修改。
下面是我在github上对save和restore验证的代码地址:
https://github.com/Demohai/my_tensorflow_learn/tree/master/save_and_restore_models

原文出处


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消