简介
本文介绍TensorFlow的第二种数据导入方法。
为了保持高效,这种方法稍显繁琐。分为如下几个步骤:
- 把所有样本写入二进制文件(只执行一次)
- 创建Tensor
,从二进制文件读取一个样本
- 创建Tensor
,从二进制文件随机读取一个mini-batch
- 把mini-batchTensor
传入网络作为输入节点。
二进制文件
使用tf.python_io.TFRecordWriter
创建一个专门存储tensorflow数据的writer
,扩展名为’.tfrecord’。
该文件中依次存储着序列化的tf.train.Example
类型的样本。
writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')for i in range(0, 10): # 创建样本example # ... serialized = example.SerializeToString() # 序列化 writer.write(serialized) # 写入文件writer.close()
每一个example
的feature
成员变量是一个dict
,存储一个样本的不同部分(例如图像像素+类标)。以下例子的样本中包含三个键a,b,c
:
# 创建样本example a_data = 0.618 + i # float b_data = [2016 + i, 2017+i] # int64 c_data = numpy.array([[0, 1, 2],[3, 4, 5]]) + i # bytes c_data = c_data.astype(numpy.uint8) c_raw = c.tostring() # 转化成字符串 example = tf.train.Example( features=tf.train.Features( feature={ 'a': tf.train.Feature( float_list=tf.train.FloatList(value=[a_data]) # 方括号表示输入为list ), 'b': tf.train.Feature( int64_list=tf.train.Int64List(value=b_data) # b_data本身就是列表 ), 'c': tf.train.Feature( bytes_list=tf.train.BytesList(value=[c_raw]) ) } ) )
dict
成员的值部分接受三种类型数据:
- tf.train.FloatList
:列表每个元素为float。例如a
。
- tf.train.Int64List
:列表每个元素为int64。例如b
。
- tf.train.BytesList
:列表每个元素为string。例如c
。
第三种类型尤其适合图像样本。注意在转成字符串之前要设定为uint8
类型。
读取一个样本
接下来,我们定义一个函数,创建“从文件中读一个样本”操作,返回结果Tensor
。
def read_single_sample(filename): # 读取样本example的每个成员a,b,c # ... return a, b, c
首先创建读文件队列,使用tf.TFRecordReader
从文件队列读入一个序列化的样本。
# 读取样本example的每个成员a,b,c filename_queue = tf.train.string_input_producer([filename], num_epochs=None) # 不限定读取数量 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue)
如果样本量很大,可以分成若干文件,把文件名列表传入tf.train.string_input_producer
。
和刚才的writer不同,这个reader是符号化的,只有在sess中run才会执行。
接下来解析符号化的样本
# get feature from serialized example features = tf.parse_single_example( serialized_example, features={ 'a': tf.FixedLenFeature([], tf.float32), #0D, 标量 'b': tf.FixedLenFeature([2], tf.int64), # 1D,长度为2 'c': tf.FixedLenFeature([], tf.string) # 0D, 标量 } ) a = features['a'] b = features['b'] c_raw = features['c'] c = tf.decode_raw(c_raw, tf.uint8) c = tf.reshape(c, [2, 3])
对于BytesList
,要重新进行解码,把string
类型的0维Tensor
变成uint8
类型的1维Tensor
。
读取mini-batch
使用tf.train.shuffle_batch
将前述a,b,c
随机化,获得mini-batchTensor
:
a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=2, capacity=200, min_after_dequeue=100, num_threads=2)
使用
创建一个session
并初始化:
# sesssess = tf.Session() init = tf.initialize_all_variables() sess.run(init) tf.train.start_queue_runners(sess=sess)
由于使用了读文件队列,所以要start_queue_runners
。
每一次运行,会随机生成一个mini-batch样本:
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch]) a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
这样的mini-batch可以作为网络的输入节点使用。
总结
如果想进一步了解例子中的队列机制,请参看这篇文章。
本文参考了以下示例:
https://github.com/mnuke/tf-slim-mnist
https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/
https://github.com/tensorflow/tensorflow/tree/r0.11/tensorflow/models/image/cifar10
完整代码如下:
import tensorflow as tfimport numpydef write_binary(): writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord') for i in range(0, 2): a = 0.618 + i b = [2016 + i, 2017+i] c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i c = c.astype(numpy.uint8) c_raw = c.tostring() example = tf.train.Example( features=tf.train.Features( feature={ 'a': tf.train.Feature( float_list=tf.train.FloatList(value=[a]) ), 'b': tf.train.Feature( int64_list=tf.train.Int64List(value=b) ), 'c': tf.train.Feature( bytes_list=tf.train.BytesList(value=[c_raw]) ) } ) ) serialized = example.SerializeToString() writer.write(serialized) writer.close()def read_single_sample(filename): # output file name string to a queue filename_queue = tf.train.string_input_producer([filename], num_epochs=None) # create a reader from file queue reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) # get feature from serialized example features = tf.parse_single_example( serialized_example, features={ 'a': tf.FixedLenFeature([], tf.float32), 'b': tf.FixedLenFeature([2], tf.int64), 'c': tf.FixedLenFeature([], tf.string) } ) a = features['a'] b = features['b'] c_raw = features['c'] c = tf.decode_raw(c_raw, tf.uint8) c = tf.reshape(c, [2, 3]) return a, b, c#-----main function-----if 1: write_binary()else: # create tensor a, b, c = read_single_sample('/tmp/data.tfrecord') a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=3, capacity=200, min_after_dequeue=100, num_threads=2) queues = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS) # sess sess = tf.Session() init = tf.initialize_all_variables() sess.run(init) tf.train.start_queue_runners(sess=sess) a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch]) print(a_val, b_val, c_val) a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch]) print(a_val, b_val, c_val)
共同学习,写下你的评论
评论加载中...
作者其他优质文章