为了账号安全,请及时绑定邮箱和手机立即绑定

TensorFlow的数据导入方法。

标签:
机器学习

简介

本文介绍TensorFlow的第二种数据导入方法。

为了保持高效,这种方法稍显繁琐。分为如下几个步骤: 
- 把所有样本写入二进制文件(只执行一次) 
- 创建Tensor,从二进制文件读取一个样本 
- 创建Tensor,从二进制文件随机读取一个mini-batch 
- 把mini-batchTensor传入网络作为输入节点。

二进制文件

使用tf.python_io.TFRecordWriter创建一个专门存储tensorflow数据的writer,扩展名为’.tfrecord’。 
该文件中依次存储着序列化的tf.train.Example类型的样本。

writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')for i in range(0, 10):    # 创建样本example
    # ...
    serialized = example.SerializeToString()   # 序列化
    writer.write(serialized)    # 写入文件writer.close()

每一个examplefeature成员变量是一个dict,存储一个样本的不同部分(例如图像像素+类标)。以下例子的样本中包含三个键a,b,c

    # 创建样本example
    a_data = 0.618 + i         # float
    b_data = [2016 + i, 2017+i]     # int64
    c_data = numpy.array([[0, 1, 2],[3, 4, 5]]) + i    # bytes
    c_data = c_data.astype(numpy.uint8)
    c_raw = c.tostring()             # 转化成字符串

    example = tf.train.Example(
        features=tf.train.Features(
            feature={                'a': tf.train.Feature(
                    float_list=tf.train.FloatList(value=[a_data])   # 方括号表示输入为list
                ),                'b': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=b_data)    # b_data本身就是列表
                ),                'c': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[c_raw])
                )
            }
        )
    )

dict成员的值部分接受三种类型数据: 
tf.train.FloatList:列表每个元素为float。例如a。 
tf.train.Int64List:列表每个元素为int64。例如b。 
tf.train.BytesList:列表每个元素为string。例如c

第三种类型尤其适合图像样本。注意在转成字符串之前要设定为uint8类型。

读取一个样本

接下来,我们定义一个函数,创建“从文件中读一个样本”操作,返回结果Tensor

def read_single_sample(filename):
    # 读取样本example的每个成员a,b,c
    # ...
    return a, b, c

首先创建读文件队列,使用tf.TFRecordReader从文件队列读入一个序列化的样本。

    # 读取样本example的每个成员a,b,c
    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)    # 不限定读取数量
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

如果样本量很大,可以分成若干文件,把文件名列表传入tf.train.string_input_producer。 
和刚才的writer不同,这个reader是符号化的,只有在sess中run才会执行。

接下来解析符号化的样本

    # get feature from serialized example
    features = tf.parse_single_example(
        serialized_example,
        features={            'a': tf.FixedLenFeature([], tf.float32),    #0D, 标量
            'b': tf.FixedLenFeature([2], tf.int64),   # 1D,长度为2
            'c': tf.FixedLenFeature([], tf.string)  # 0D, 标量
        }
    )
    a = features['a']
    b = features['b']
    c_raw = features['c']
    c = tf.decode_raw(c_raw, tf.uint8)
    c = tf.reshape(c, [2, 3])

对于BytesList,要重新进行解码,把string类型的0维Tensor变成uint8类型的1维Tensor

读取mini-batch

使用tf.train.shuffle_batch将前述a,b,c随机化,获得mini-batchTensor

a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=2, capacity=200, min_after_dequeue=100, num_threads=2)

使用

创建一个session并初始化:

# sesssess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)

由于使用了读文件队列,所以要start_queue_runners

每一次运行,会随机生成一个mini-batch样本:

a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])

这样的mini-batch可以作为网络的输入节点使用。

总结

如果想进一步了解例子中的队列机制,请参看这篇文章

本文参考了以下示例: 
https://github.com/mnuke/tf-slim-mnist 
https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/ 
https://github.com/tensorflow/tensorflow/tree/r0.11/tensorflow/models/image/cifar10

完整代码如下:

import tensorflow as tfimport numpydef write_binary():
    writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')    for i in range(0, 2):
        a = 0.618 + i
        b = [2016 + i, 2017+i]
        c = numpy.array([[0, 1, 2],[3, 4, 5]]) + i
        c = c.astype(numpy.uint8)
        c_raw = c.tostring()

        example = tf.train.Example(
            features=tf.train.Features(
                feature={                    'a': tf.train.Feature(
                        float_list=tf.train.FloatList(value=[a])
                    ),                    'b': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=b)
                    ),                    'c': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[c_raw])
                    )
                }
            )
        )
        serialized = example.SerializeToString()
        writer.write(serialized)

    writer.close()def read_single_sample(filename):
    # output file name string to a queue
    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)    # create a reader from file queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)    # get feature from serialized example

    features = tf.parse_single_example(
        serialized_example,
        features={            'a': tf.FixedLenFeature([], tf.float32),            'b': tf.FixedLenFeature([2], tf.int64),            'c': tf.FixedLenFeature([], tf.string)
        }
    )

    a = features['a']

    b = features['b']

    c_raw = features['c']
    c = tf.decode_raw(c_raw, tf.uint8)
    c = tf.reshape(c, [2, 3])    return a, b, c#-----main function-----if 1:
    write_binary()else:    # create tensor
    a, b, c = read_single_sample('/tmp/data.tfrecord')
    a_batch, b_batch, c_batch = tf.train.shuffle_batch([a, b, c], batch_size=3, capacity=200, min_after_dequeue=100, num_threads=2)

    queues = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)    # sess
    sess = tf.Session()
    init = tf.initialize_all_variables()
    sess.run(init)

    tf.train.start_queue_runners(sess=sess)
    a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
    print(a_val, b_val, c_val)
    a_val, b_val, c_val = sess.run([a_batch, b_batch, c_batch])
    print(a_val, b_val, c_val)

原文出处

点击查看更多内容
1人点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消