Jeremy

Tensorflow训练数据读取方法(TFRecords)

我最开始学TF的时候,数据一般都是直接预加载给model的,并因此还犯过一个低级的错误,那就是把数据也保存到checkpoint文件当中,那时候还很好奇,怎么checkpoint文件会这么大;后来我就开始使用feed_dict的方式来供给数据,虽然比预加载灵活性要好些,但是使用feed_dict其实效率并不好;因此,现在终于得开始投到TFRecords的怀抱中了。

看过官方教程的同学们,基本上都知道,TF有三种数据供给方式,分别为:

  1. 预加载数据;
  2. 使用feed_dict给model供给数据;
  3. 使用输入管道(Queue)从文件中读取数据;

前面两种方法是最容易应用的方法,基本上不用多讲,大家都知道。而第三种方法则需要一些时间去了解,因此本文也主要记录这方面的内容,并且也主要关注TFRecords相关的内容。


在开始讲输入管道读取数据之前,我们会先介绍两个背景知识:TFRecord 和 队列与线程。

TFRecord文件

TFRecords 是 Tensorflow 默认的数据格式,它是一种二进制文件,其中包含了序列化的tf.train.Exmaple的Protobuf结构化信息.

关于TFRecords格式的官方介绍如下:

A TFRecords file contains a sequence of strings with CRC hashes. Each record has the format

1
2
3
4
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data

and the records are concatenated together to produce the file. The CRC32s are described here, and the mask of a CRC is

1
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul

当然,上面这段官方介绍,其实我们没有必要去关心,我们只需要记住,TFRecords是一个存储着基于Protobuf结构化信息的二进制文件。

现在我们需要重点关注2个问题:

  1. 怎么把数据转换成TFRecords呢?
  2. 怎么读取TFRecords文件,并将里面的数据解析出来呢?

怎么把数据转换成TFRecords呢?

如何将数据(比如说图像数据)转换成TFRecords,这里面主要涉及到两个函数,一个是tf.python_io.TFRecordWriter(),另一个是tf.train.Example()。

这里面我们先来看一个官方的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def convert_to(data_set, name):
"""Converts a dataset to tfrecords."""
images = data_set.images
labels = data_set.labels
num_examples = data_set.num_examples
filename = os.path.join(FLAGS.directory, name + '.tfrecords')
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw = images[index].tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()

在上面的例子中,我们通过指定新TFRecords的名字filename来创建一个TFRecords的存储实例writer,而在TFRecords内部,我们则是通过由一个个tf.train.Example组成的。

怎么读取TFRecords文件,并将里面的数据解析出来呢?

我们在知道了如何生成TFRecords后并生成了TFRecords后,就需要考虑去如何解析TFRecords文件了。

1
2
3
4
5
6
7
8
9
10
11
12
13
def read_and_decode(_filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(_filename_queue)
features = tf.parse_single_example(serialized_example,
features={"label": tf.FixedLenFeature([], tf.int64),
"im_raw": tf.FixedLenFeature([], tf.string)})
im = tf.decode_raw(features["im_raw"], tf.uint8)
im = tf.reshape(im, [224, 224, 3])
im = tf.cast(im, tf.float32) * (1. / 128) - 0.5
label = tf.cast(features["label"], tf.int64)
return im, label

如上面的例子所示,读取TFRecords的方法其实也很简单,那就是通过tf.TFRecordWriter()和tf.parse_single_example()来具体实现。
当然,还有一点上面的例子没说,那就是_filename_queue是如何生成的,关于_filename_queue的生成,主要是通过tf.train.string_input_producer()来实现的。
函数tf.train.string_input_producer()的输入就是TFRecords的路径列表。


以下的内容只是官方教程(中文版)的转发,只要认真地看完了官方教程,那么基本上可以算掌握了。

线程与队列

我们使用TFRecords的一个重要原因就是,希望可以使用它支持多线程的特性来加快训练样本的供给。

正如TF中的其他组件一样,队列Queue就是TF图中的节点。这是一种有状态的节点,就像变量一样:其他节点可以修改它的内容。具体来说,其他节点可以把新元素插入队列后端rear,也可以把队列前端front的元素删除。

队列使用概述

队列,如FIFOQueue和RandomShuffleQueue,在TF的张量异步计算时非常重要。

例如,一个典型的输入结构:使用一个RandomShuffleQueue来作为模型训练的输入:

  • 多个线程准备训练样本,并且把这些样本推入队列;
  • 一个训练线程执行一个训练操作,此操作会从队列中移除最小批次的样本(mini-batch);

在TF中提供了两个类来帮助多线程的实现:

  • tf.Coordinator
  • tf.QueueRunner

从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。

Coordinator

Coordinator类用来帮助多个线程协同工作,多个线程同步终止。 其主要方法有:

  • should_stop():如果线程应该停止则返回True。
  • request_stop(): 请求该线程停止。
  • join():等待被指定的线程终止。

首先创建一个Coordinator对象,然后建立一些使用Coordinator对象的线程。这些线程通常一直循环运行,一直到should_stop()返回True时停止。 任何线程都可以决定计算什么时候应该停止。它只需要调用request_stop(),同时其他线程的should_stop()将会返回True,然后都停下来。

QueueRunner

QueueRunner类会创建一组线程, 这些线程可以重复的执行Enquene操作, 他们使用同一个Coordinator来处理线程同步终止。此外,一个QueueRunner会运行一个closer thread,当Coordinator收到异常报告时,这个closer thread会自动关闭队列。

您可以使用一个queue runner,来实现上述结构。

首先建立一个TensorFlow图表,这个图表使用队列来输入样本。增加处理样本并将样本推入队列中的操作。增加training操作来移除队列中的样本。


从文件读取数据

一个典型的文件读取管线会包含下面这些步骤:

  1. 文件名列表
  2. 可配置的 文件名乱序(shuffling)
  3. 可配置的 最大训练迭代数(epoch limit)
  4. 文件名队列
  5. 针对输入文件格式的阅读器
  6. 纪录解析器
  7. 可配置的预处理器
  8. 样本队列

文件名, 乱序(shuffling), 和最大训练迭代数(epoch limits)

可以使用字符串张量(比如[“file0”, “file1”], [(“file%d” % i) for i in range(2)], [(“file%d” % i) for i in range(2)]) 或者tf.train.match_filenames_once 函数来产生文件名列表。

将文件名列表交给tf.train.string_input_producer 函数 string_input_producer()来生成一个先入先出的队列,文件阅读器会需要它来读取数据。

string_input_producer提供的可配置参数来设置文件名乱序和最大的训练迭代数,QueueRunner会为每次迭代(epoch)将所有的文件名加入文件名队列中,如果shuffle=True的话,会对文件名进行乱序处理。这一过程是比较均匀的,因此它可以产生均衡的文件名队列。

这个QueueRunner的工作线程是独立于文件阅读器的线程,因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。

文件格式

根据你的文件格式,选择对应的文件阅读器,然后将文件名队列提供给阅读器的read方法。阅读器的read方法会输出一个key来表征输入的文件和其中的纪录(对于调试非常有用),同时得到一个字符串标量,这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

CSV 文件

从CSV文件中读取数据,需要使用TextLineReader和decode_csv操作,如下面的例子所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.concat(0, [col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)

每次read的执行都会从文件中读取一行内容,decode_csv操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。

在调用run或者eval去执行read之前,你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

固定长度的记录

从二进制文件中读取固定长度纪录,可以使用tf.FixedLengthRecordReader的tf.decode_raw操作。decode_raw操作可以讲一个字符串转换为一个uint8的张量。

举例来说,the CIFAR-10 dataset的文件格式定义是:每条记录的长度都是固定的,一个字节的标签,后面是3072字节的图像数据。uint8的张量的标准操作就可以从中获取图像片并且根据需要进行重组。 例子代码可以在tensorflow/models/image/cifar10/cifar10_input.py找到,具体讲述可参见教程.

标准TensorFlow格式

另一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。

从TFRecords文件中读取数据,可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocol buffer)解析为张量。 MNIST的例子就使用了convert_to_records所构建的数据。 请参看tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py, 您也可以将这个例子跟fully_connected_feed的版本加以比较。

预处理

你可以对输入的样本进行任意的预处理,这些预处理不依赖于训练参数,你可以在tensorflow/models/image/cifar10/cifar10.py找到数据归一化,提取随机数据片,增加噪声或失真等等预处理的例子。

批处理

在数据输入管线的末端,我们需要有另一个队列来执行输入样本的训练,评价和推理。因此我们使用tf.train.shuffle_batch函数来对队列中的样本进行乱序处理.

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def read_my_file_format(filename_queue):
reader = tf.SomeReader()
key, record_string = reader.read(filename_queue)
example, label = tf.some_decoder(record_string)
processed_example = some_processing(example)
return processed_example, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example, label = read_my_file_format(filename_queue)
# min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch

如果你需要对不同文件中的样本进行更强的乱序和并行处理,可以使用tf.train.shuffle_batch_join 函数. 示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def read_my_file_format(filename_queue):
# Same as above
def input_pipeline(filenames, batch_size, read_threads, num_epochs=None):
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=num_epochs, shuffle=True)
example_list = [read_my_file_format(filename_queue)
for _ in range(read_threads)]
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch_join(
example_list, batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch

在这个例子中, 你虽然只使用了一个文件名队列, 但是TensorFlow依然能保证多个文件阅读器从同一次迭代(epoch)的不同文件中读取数据,知道这次迭代的所有文件都被开始读取为止。(通常来说一个线程来对文件名队列进行填充的效率是足够的)

另一种替代方案是:使用tf.train.shuffle_batch函数,设置num_threads的值大于1。这种方案可以保证同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件。这种方案的优点是:

  • 避免了两个不同的线程从同一个文件中读取同一个样本。
  • 避免了过多的磁盘搜索操作。

你一共需要多少个读取线程呢? 函数tf.train.shuffle_batch为TensorFlow图提供了获取文件名队列中的元素个数之和的方法。 如果你有足够多的读取线程, 文件名队列中的元素个数之和应该一直是一个略高于0的数。

创建线程并使用QueueRunner对象来预取

简单来说:使用上面列出的许多tf.train函数添加QueueRunner到你的数据流图中。
在你运行任何训练步骤之前,需要调用tf.train.start_queue_runners函数,否则数据流图将一直挂起。

tf.train.start_queue_runners 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。
这种情况下最好配合使用一个tf.train.Coordinator,这样可以在发生错误的情况下正确地关闭这些线程。
如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。推荐的代码模板如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Create the graph, etc.
init_op = tf.initialize_all_variables()
# Create a session for running operations in the Graph.
sess = tf.Session()
# Initialize the variables (like the epoch counter).
sess.run(init_op)
# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
# Run training steps or whatever
sess.run(train_op)
except tf.errors.OutOfRangeError:
print 'Done training -- epoch limit reached'
finally:
# When done, ask the threads to stop.
coord.request_stop()
# Wait for threads to finish.
coord.join(threads)
sess.close()

疑问: 这是怎么回事?

首先,我们先创建数据流图,这个数据流图由一些流水线的阶段组成,阶段间用队列连接在一起。第一阶段将生成文件名,我们读取这些文件名并且把他们排到文件名队列中。第二阶段从文件中读取数据(使用Reader),产生样本,而且把样本放在一个样本队列中。根据你的设置,实际上也可以拷贝第二阶段的样本,使得他们相互独立,这样就可以从多个文件中并行读取。在第二阶段的最后是一个排队操作,就是入队到队列中去,在下一阶段出队。因为我们是要开始运行这些入队操作的线程,所以我们的训练循环会使得样本队列中的样本不断地出队。

在tf.train中要创建这些队列和执行入队操作,就要添加tf.train.QueueRunner到一个使用tf.train.add_queue_runner函数的数据流图中。每个QueueRunner负责一个阶段,处理那些需要在线程中运行的入队操作的列表。一旦数据流图构造成功,tf.train.start_queue_runners函数就会要求数据流图中每个QueueRunner去开始它的线程运行入队操作。

如果一切顺利的话,你现在可以执行你的训练步骤,同时队列也会被后台线程来填充。如果您设置了最大训练迭代数,在某些时候,样本出队的操作可能会得到一个tf.OutOfRangeError的错误。这其实是TensorFlow的“文件结束”(EOF) ———— 这就意味着已经达到了最大训练迭代数,已经没有更多可用的样本了。

最后一个因素是Coordinator。这是负责在收到任何关闭信号的时候,让所有的线程都知道。最常用的是在发生异常时这种情况就会呈现出来,比如说其中一个线程在运行某些操作时出现错误(或一个普通的Python异常)。

疑问: 在达到最大训练迭代数的时候如何清理关闭线程?

想象一下,你有一个模型并且设置了最大训练迭代数。这意味着,生成文件的那个线程将只会在产生OutOfRange错误之前运行许多次。该QueueRunner会捕获该错误,并且关闭文件名的队列,最后退出线程。关闭队列做了两件事情:

如果还试着对文件名队列执行入队操作时将发生错误。任何线程不应该尝试去这样做,但是当队列因为其他错误而关闭时,这就会有用了。
任何当前或将来出队操作要么成功(如果队列中还有足够的元素)或立即失败(发生OutOfRange错误)。它们不会防止等待更多的元素被添加到队列中,因为上面的一点已经保证了这种情况不会发生。
关键是,当在文件名队列被关闭时候,有可能还有许多文件名在该队列中,这样下一阶段的流水线(包括reader和其它预处理)还可以继续运行一段时间。 一旦文件名队列空了之后,如果后面的流水线还要尝试从文件名队列中取出一个文件名(例如,从一个已经处理完文件的reader中),这将会触发OutOfRange错误。在这种情况下,即使你可能有一个QueueRunner关联着多个线程。如果这不是在QueueRunner中的最后那个线程,OutOfRange错误仅仅只会使得一个线程退出。这使得其他那些正处理自己的最后一个文件的线程继续运行,直至他们完成为止。 (但如果假设你使用的是tf.train.Coordinator,其他类型的错误将导致所有线程停止)。一旦所有的reader线程触发OutOfRange错误,然后才是下一个队列,再是样本队列被关闭。

同样,样本队列中会有一些已经入队的元素,所以样本训练将一直持续直到样本队列中再没有样本为止。如果样本队列是一个RandomShuffleQueue,因为你使用了shuffle_batch 或者 shuffle_batch_join,所以通常不会出现以往那种队列中的元素会比min_after_dequeue 定义的更少的情况。 然而,一旦该队列被关闭,min_after_dequeue设置的限定值将失效,最终队列将为空。在这一点来说,当实际训练线程尝试从样本队列中取出数据时,将会触发OutOfRange错误,然后训练线程会退出。一旦所有的培训线程完成,tf.train.Coordinator.join会返回,你就可以正常退出了。

筛选记录或产生每个记录的多个样本

举个例子,有形式为[x, y, z]的样本,我们可以生成一批形式为[batch, x, y, z]的样本。 如果你想滤除这个记录(或许不需要这样的设置),那么可以设置batch的大小为0;但如果你需要每个记录产生多个样本,那么batch的值可以大于1。 然后很简单,只需调用批处理函数(比如: shuffle_batch or shuffle_batch_join)去设置enqueue_many=True就可以实现。

稀疏输入数据

SparseTensors这种数据类型使用队列来处理不是太好。如果要使用SparseTensors你就必须在批处理之后使用tf.parse_example 去解析字符串记录 (而不是在批处理之前使用 tf.parse_single_example) 。

Refs


林建民-机器视觉
Blog地址:http://www.linjm.tech/
旧博客地址:http://blog.csdn.net/linj_m