为了账号安全,请及时绑定邮箱和手机立即绑定

TensorFlow:如何保存/恢复模型?

TensorFlow:如何保存/恢复模型?

扬帆大鱼 2019-06-06 10:56:38
TensorFlow:如何保存/恢复模型?当你在坦索弗洛训练了一个模特之后:你是如何拯救受过训练的模特的?稍后如何恢复这个保存的模型?
查看完整描述

3 回答

?
拉风的咖菲猫

TA贡献1995条经验 获得超2个赞

我正在改进我的答案,为保存和恢复模型添加更多的细节。


在(和之后)TensorFlow版本0.11:


保存模型:


import tensorflow as tf


#Prepare to feed input, i.e. feed_dict and placeholders

w1 = tf.placeholder("float", name="w1")

w2 = tf.placeholder("float", name="w2")

b1= tf.Variable(2.0,name="bias")

feed_dict ={w1:4,w2:8}


#Define a test operation that we will restore

w3 = tf.add(w1,w2)

w4 = tf.multiply(w3,b1,name="op_to_restore")

sess = tf.Session()

sess.run(tf.global_variables_initializer())


#Create a saver object which will save all the variables

saver = tf.train.Saver()


#Run the operation by feeding input

print sess.run(w4,feed_dict)

#Prints 24 which is sum of (w1+w2)*b1 


#Now, save the graph

saver.save(sess, 'my_test_model',global_step=1000)

恢复模型:


import tensorflow as tf


sess=tf.Session()    

#First let's load meta graph and restore weights

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

saver.restore(sess,tf.train.latest_checkpoint('./'))



# Access saved Variables directly

print(sess.run('bias:0'))

# This will print 2, which is the value of bias that we saved



# Now, let's access and create placeholders variables and

# create feed-dict to feed new data


graph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")

w2 = graph.get_tensor_by_name("w2:0")

feed_dict ={w1:13.0,w2:17.0}


#Now, access the op that you want to run. 

op_to_restore = graph.get_tensor_by_name("op_to_restore:0")


print sess.run(op_to_restore,feed_dict)

#This will print 60 which is calculated 

这个和一些更高级的用例在这里已经解释得很好了。


保存和恢复TensorFlow模型的快速完整教程


查看完整回答
反对 回复 2019-06-06
  • 3 回答
  • 0 关注
  • 1879 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信