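1. Saving the model
Saving a model, loosely called freezing it, means persisting the model's graph structure together with its weights so that both can be restored later.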
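To make the "graph structure plus weights" idea concrete, here is a minimal sketch that saves a one-variable graph with tf.train.Saver and lists the files it writes. The temporary directory, the checkpoint prefix "demo", and the `tensorflow.compat.v1` import (used so the sketch also runs on a TensorFlow 2.x install) are assumptions for illustration, not part of the original demos.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or TF 2.x

tf.disable_v2_behavior()

# Hypothetical scratch directory for this illustration
save_dir = tempfile.mkdtemp()

graph = tf.Graph()
with graph.as_default():
    tf.Variable(2.0, name="bias")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes the graph structure and the variable values under save_dir
        saver.save(sess, os.path.join(save_dir, "demo"))

saved_files = sorted(os.listdir(save_dir))
print(saved_files)
```

The `.meta` file should hold the graph structure, the `.index` and `.data-*` files the variable values, and `checkpoint` a small record of the most recent save.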
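2. Loading the model
When restoring a model, use get_tensor_by_name to look up tensors in the restored graph (for example its placeholders), then feed values to them.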
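The lookup itself does not depend on a checkpoint, so it can be tried on a small in-memory graph. The sketch below assumes two hypothetical op names, "x" and "doubled", and uses the `tensorflow.compat.v1` import so that it also runs on TensorFlow 2.x.

```python
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or TF 2.x

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, name="x")
    tf.multiply(x, 2.0, name="doubled")

# A tensor name is "<op name>:<output index>", so "x:0" is output 0 of op "x"
x_again = graph.get_tensor_by_name("x:0")
doubled_again = graph.get_tensor_by_name("doubled:0")

with tf.Session(graph=graph) as sess:
    result = sess.run(doubled_again, feed_dict={x_again: 21.0})
print(result)  # 42.0
```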
3. Code examples
Demo 1: training only, without saving a model
#-*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np

# Placeholders for a batch of 10-feature inputs and their 0/1 labels
with tf.variable_scope('Placeholder'):
    inputs_placeholder = tf.placeholder(tf.float32, name='inputs_placeholder', shape=[None, 10])
    labels_placeholder = tf.placeholder(tf.float32, name='labels_placeholder', shape=[None, 1])

# Two parallel ReLU units whose outputs are averaged
with tf.variable_scope('NN'):
    W1 = tf.get_variable('W1', shape=[10, 1], initializer=tf.random_normal_initializer(stddev=1e-1))
    b1 = tf.get_variable('b1', shape=[1], initializer=tf.constant_initializer(0.1))
    W2 = tf.get_variable('W2', shape=[10, 1], initializer=tf.random_normal_initializer(stddev=1e-1))
    b2 = tf.get_variable('b2', shape=[1], initializer=tf.constant_initializer(0.1))
    a1 = tf.nn.relu(tf.matmul(inputs_placeholder, W1) + b1)
    a2 = tf.nn.relu(tf.matmul(inputs_placeholder, W2) + b2)
    y = tf.div(tf.add(a1, a2), 2)

# Sum of squared errors
with tf.variable_scope('Loss'):
    loss = tf.reduce_sum(tf.square(y - labels_placeholder) / 2)

# A prediction counts as positive when y > 0.5
with tf.variable_scope('Accuracy'):
    predictions = tf.greater(y, 0.5, name='predictions')
    correct_predictions = tf.equal(predictions, tf.cast(labels_placeholder, tf.bool), name='correct_predictions')
    accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

adam = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = adam.minimize(loss)
#generate_data
inputs = np.random.choice(10,size = [10000,10])
labels = (np.sum(inputs,axis = 1) > 45).reshape(-1,1).astype(np.float32)
print('inputs.shape:',inputs.shape)
print('labels.shape:',labels.shape)
test_inputs = np.random.choice(10,size = [100,10])
test_labels = (np.sum(test_inputs,axis = 1) > 45).reshape(-1,1).astype(np.float32)
print('test_inputs.shape:',test_inputs.shape)
print('test_labels.shape:',test_labels.shape)
batch_size = 32
epochs = 10
batches = []
# Full batches of batch_size examples each
for i in range(len(inputs) // batch_size):
    batch = [inputs[batch_size * i:batch_size * i + batch_size], labels[batch_size * i:batch_size * i + batch_size]]
    batches.append(list(batch))
# Remaining examples, if any, form a smaller final batch
if (i + 1) * batch_size < len(inputs):
    batch = [inputs[batch_size * (i + 1):], labels[batch_size * (i + 1):]]
    batches.append(list(batch))
print("Number of batches: %d" % len(batches))
# Each batch is [inputs, labels], so measure the inputs array rather than the pair
print("Size of full batch: %d" % len(batches[0][0]))
print("Size of final batch: %d" % len(batches[-1][0]))
global_count = 0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(epochs):
        for batch in batches:
            train_loss, _ = sess.run(
                [loss, train_op],
                feed_dict={
                    inputs_placeholder: batch[0],
                    labels_placeholder: batch[1]
                })
            # Report test accuracy every 100 training steps
            if global_count % 100 == 0:
                acc = sess.run(accuracy, feed_dict={
                    inputs_placeholder: test_inputs,
                    labels_placeholder: test_labels
                })
                print('accuracy: %f' % acc)
            global_count += 1
    # Final evaluation on the test set
    acc = sess.run(accuracy, feed_dict={
        inputs_placeholder: test_inputs,
        labels_placeholder: test_labels
    })
    print("final accuracy: %f" % acc)
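The batching loop in Demo 1 can also be written as a small helper. This is only a sketch of one possible refactoring; `make_batches` and the toy arrays are hypothetical names introduced here, not part of the original code.

```python
import numpy as np


def make_batches(inputs, labels, batch_size):
    """Split paired arrays into consecutive batches; the last one may be smaller."""
    return [
        (inputs[start:start + batch_size], labels[start:start + batch_size])
        for start in range(0, len(inputs), batch_size)
    ]


# Toy data: 10 examples with 10 features each
demo_inputs = np.arange(100).reshape(10, 10)
demo_labels = np.arange(10).reshape(-1, 1)
demo_batches = make_batches(demo_inputs, demo_labels, batch_size=4)
print([len(x) for x, _ in demo_batches])  # [4, 4, 2]
```

Stepping through the start indices with `range(0, len(inputs), batch_size)` covers the trailing partial batch without a separate branch.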
Demo 2: basic saving and loading of a model
Saving the model:
import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print (sess.run(w4,feed_dict))
#Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2
#Now, save the graph
saver.save(sess, './model/tiny_model',global_step=1000)
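The name arguments in the code above (name="w1", name="w2", name="bias", name="op_to_restore") must not be omitted. They are the key to restoring the model, because the loading code looks these tensors up by name.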
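A quick way to see why the names matter is to compare a named placeholder with an unnamed one. The sketch below is illustrative only and assumes the `tensorflow.compat.v1` import so that it also runs on TensorFlow 2.x.

```python
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or TF 2.x

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    named = tf.placeholder(tf.float32, name="w1")
    unnamed = tf.placeholder(tf.float32)  # TensorFlow picks a default name

print(named.name, unnamed.name)  # w1:0 Placeholder:0

# All op names in the graph, useful for checking what can be looked up later
op_names = [op.name for op in graph.get_operations()]
print(op_names)
```

An auto-generated name such as "Placeholder" still works, but it is easy to confuse once a graph contains several unnamed ops, which is why explicit names are safer.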
Loading the model:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('./model/tiny_model-1000.meta')
# The checkpoint was written to ./model, so look for the latest checkpoint there
saver.restore(sess, tf.train.latest_checkpoint('./model'))
# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print (sess.run(op_to_restore,feed_dict))
#This will print 60.0, which is calculated as (13 + 17) * 2
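ckpt files are inconvenient for moving a model between environments. For example, a model trained on Windows may fail to load on Linux, because the paths recorded in the checkpoint files no longer match. For deployment and production use, consider saving the model as a pb file instead; the approach in this article is only suitable for getting started.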
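One common way to produce a pb file in TensorFlow 1.x is to convert the variables to constants and serialize the resulting GraphDef. The sketch below rebuilds the Demo 2 graph, freezes it, and reloads it; the temporary directory, the file name "frozen_model.pb", and the `tensorflow.compat.v1` import are assumptions for illustration, and this is not necessarily the workflow the author had in mind.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or TF 2.x

tf.disable_v2_behavior()

# Hypothetical output location for this illustration
pb_path = os.path.join(tempfile.mkdtemp(), "frozen_model.pb")

# Build the Demo 2 graph and fold the variable into the graph as a constant
build_graph = tf.Graph()
with build_graph.as_default():
    w1 = tf.placeholder(tf.float32, name="w1")
    w2 = tf.placeholder(tf.float32, name="w2")
    bias = tf.Variable(2.0, name="bias")
    tf.multiply(tf.add(w1, w2), bias, name="op_to_restore")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        frozen_def = tf.graph_util.convert_variables_to_constants(
            sess, build_graph.as_graph_def(), ["op_to_restore"])

with tf.gfile.GFile(pb_path, "wb") as f:
    f.write(frozen_def.SerializeToString())

# Load the single pb file into a fresh graph; no checkpoint files are needed
load_graph = tf.Graph()
with load_graph.as_default():
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")

with tf.Session(graph=load_graph) as sess:
    frozen_result = sess.run(
        "op_to_restore:0", feed_dict={"w1:0": 13.0, "w2:0": 17.0})
print(frozen_result)  # 60.0
```

Because the weights live inside the GraphDef as constants, the single file can be copied to another machine without any path bookkeeping.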
Also note:
“Note that when the network is saved, values of the placeholders are not saved.”
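In practice this means every placeholder has to be fed again after a restore. A minimal in-memory sketch, using a hypothetical "doubled" op and assuming the `tensorflow.compat.v1` import, shows what happens when the feed is missing:

```python
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or TF 2.x

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    w1 = tf.placeholder(tf.float32, name="w1")
    doubled = tf.multiply(w1, 2.0, name="doubled")

with tf.Session(graph=graph) as sess:
    try:
        sess.run(doubled)  # no feed_dict: the placeholder has no value
        missing_feed_failed = False
    except tf.errors.InvalidArgumentError:
        missing_feed_failed = True
    fed_value = sess.run(doubled, feed_dict={w1: 3.0})

print(missing_feed_failed, fed_value)  # True 6.0
```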