TensorFlow officially recommends TFRecord for reading and writing data because it is more efficient. TFRecord is a binary storage format serialized with protocol buffers (pb). To make reads efficient, the serialized records are stored in a set of files that can be read linearly (roughly 100 MB-200 MB per file, according to the official docs). tf.train.Example is the Message defined in the pb protocol; the full read/write code is shown below.
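
Before the full script, here is a minimal sketch of the message structure itself: a tf.train.Example wraps a Features map whose values are one of BytesList, FloatList, or Int64List, and it round-trips like any ordinary protobuf. The feature names idx/score/data here are only illustrative.

import tensorflow as tf

# Build a single Example: each feature is one of the three list types.
example = tf.train.Example(features=tf.train.Features(feature={
    'idx':   tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
    'score': tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
    'data':  tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"1,2,3"])),
}))

# Round-trip through the wire format, exactly like any other protobuf message.
serialized = example.SerializeToString()
restored = tf.train.Example.FromString(serialized)
print(restored.features.feature['idx'].int64_list.value[0])   # 0
print(restored.features.feature['data'].bytes_list.value[0])  # b'1,2,3'

The full script below writes a text file into a TFRecord and reads it back: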

#coding:utf8
import tensorflow as tf

def _parse_record_train_dat(example_proto):
    # Feature spec: keys and dtypes must match what was written into the file.
    features = {
        'idx'       : tf.FixedLenFeature([], tf.int64),
        'fields_num': tf.FixedLenFeature([], tf.int64),
        'data'      : tf.FixedLenFeature([], tf.string)
    }
    parsed_feats = tf.parse_single_example(example_proto, features = features)
    return parsed_feats

def gen_tfrecord_train_dat(input, output):
    writer = tf.python_io.TFRecordWriter(output)
    # data = tf.read_file(input)
    fin = open(input)
    lines = fin.readlines()
    fin.close()
    fields_num = len(lines[0].split(","))

    for i, line in enumerate(lines):
        line = line.strip("\r\n")
        # The keys here must not be mistyped: they must match the parsing feature spec above.
        example = tf.train.Example(features = tf.train.Features(feature = {
            'idx'       : tf.train.Feature(int64_list = tf.train.Int64List(value = [i])),
            'fields_num': tf.train.Feature(int64_list = tf.train.Int64List(value = [fields_num])),
            'data'      : tf.train.Feature(bytes_list = tf.train.BytesList(value = [line.encode("utf-8")]))
        }))
        writer.write(example.SerializeToString())
    writer.close()

def write_to_local(ary, path):
    fout = open(path, "a+")
    for ele in ary:
        # 'data' comes back as bytes, so decode before writing text.
        fout.write(ele.decode("utf-8") + "\n")
    fout.close()

def read_tfrecord_train_dat(input, output):
    dataset = tf.data.TFRecordDataset(input)
    dataset = dataset.map(_parse_record_train_dat)
    iterator = dataset.make_one_shot_iterator()
    # Build get_next() once, outside the loop, so the graph does not grow per step.
    next_element = iterator.get_next()
    rows = []

    with tf.Session() as sess:
        try:
            while True:
                example = sess.run(next_element)
                idx = example['idx']
                fields_num = example['fields_num']
                data = example['data']
                rows.append(data)
                print("--------")
                print(idx)
                print(fields_num)
                print(data)
        except tf.errors.OutOfRangeError:
            # Raised when the one-shot iterator is exhausted.
            print("OutOfRangeError")

    write_to_local(rows, output)

if __name__ == "__main__":
    print(tf.__version__)
    gen_tfrecord_train_dat("data/test.dat", "data/test.dat.tfrecord")
    read_tfrecord_train_dat("data/test.dat.tfrecord", "data/test.dat.tfrecord.recover")
    print("success!")

The code above generates the TFRecord file with plain Python, but in practice TFRecord files are usually generated with Spark; refer to ecosystem for the details (a sketch follows below). The code above is not optimal; it is only meant to illustrate how TFRecord works.
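
As a rough illustration, assuming the spark-tensorflow-connector from the tensorflow/ecosystem repository is on the Spark classpath, writing a DataFrame out as tf.train.Example records looks roughly like this; the input path, output path, and the "tfrecords" format name follow that connector's README and are assumptions here, not something defined by the code above.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("gen_tfrecord").getOrCreate()

# Hypothetical input: a CSV file with an arbitrary schema.
df = spark.read.csv("data/test.dat", inferSchema=True)

# spark-tensorflow-connector maps each DataFrame row to a tf.train.Example.
df.write.format("tfrecords") \
    .option("recordType", "Example") \
    .save("data/test.tfrecord.dir")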

References

https://www.tensorflow.org/tutorials/load_data/tfrecord