TensorFlow officially recommends TFRecord for data I/O because it is more efficient to read. TFRecord is a binary storage format whose payloads are serialized with Protocol Buffers. To make reading efficient, TFRecord serializes the data into a set of files that can be read linearly (each file around 100 MB-200 MB, per the official docs). `tf.train.Example` is the message (Message) defined in that protobuf schema; the code below shows how to use it.
```python
import tensorflow as tf


def _parse_record_train_dat(example_proto):
    features = {
        'idx': tf.FixedLenFeature([], tf.int64),
        'fields_num': tf.FixedLenFeature([], tf.int64),
        'data': tf.FixedLenFeature([], tf.string)
    }
    parsed_feats = tf.parse_single_example(example_proto, features=features)
    return parsed_feats


def gen_tfrecord_train_dat(input, output):
    writer = tf.python_io.TFRecordWriter(output)
    with open(input) as fin:
        lines = fin.readlines()
    fields_num = len(lines[0].split(","))
    for i, line in enumerate(lines):
        line = line.strip("\r\n")
        example = tf.train.Example(features=tf.train.Features(feature={
            'idx': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
            'fields_num': tf.train.Feature(int64_list=tf.train.Int64List(value=[fields_num])),
            # BytesList holds bytes, so encode the str under Python 3
            'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[line.encode("utf-8")]))
        }))
        writer.write(example.SerializeToString())
    writer.close()


def write_to_local(ary, path):
    with open(path, "a+") as fout:
        for ele in ary:
            fout.write(ele + "\n")


def read_tfrecord_train_dat(input, output):
    dataset = tf.data.TFRecordDataset(input)
    dataset = dataset.map(_parse_record_train_dat)
    iterator = dataset.make_one_shot_iterator()
    # Build the get_next op once; calling iterator.get_next() inside the
    # loop would add a new node to the graph on every iteration.
    next_example = iterator.get_next()
    rows = []
    with tf.Session() as sess:
        try:
            while True:
                example = sess.run(next_example)
                idx = example['idx']
                fields_num = example['fields_num']
                data = example['data']
                rows.append(data.decode("utf-8"))
                print("--------")
                print(idx)
                print(fields_num)
                print(data)
        except tf.errors.OutOfRangeError:
            print("OutOfRangeError")
    write_to_local(rows, output)


if __name__ == "__main__":
    print(tf.__version__)
    gen_tfrecord_train_dat("data/test.dat", "data/test.dat.tfrecord")
    read_tfrecord_train_dat("data/test.dat.tfrecord", "data/test.dat.tfrecord.recover")
    print("success!")
```
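For intuition about the binary format itself: on disk a TFRecord file is just a sequence of framed records, each stored as a uint64 length, a masked CRC-32C of the length, the payload bytes (the serialized `tf.train.Example`), and a masked CRC-32C of the payload. Below is a minimal pure-Python sketch of that framing; the CRC fields are zeroed out for brevity, so real TensorFlow readers would reject this file — it only illustrates the byte layout.

```python
import struct

def write_records(path, payloads):
    # Frame each payload as: uint64 length | uint32 CRC (zeroed) | data | uint32 CRC (zeroed)
    with open(path, "wb") as f:
        for data in payloads:
            f.write(struct.pack("<Q", len(data)))  # little-endian uint64 length
            f.write(struct.pack("<I", 0))          # masked CRC of the length (zeroed here)
            f.write(data)                          # payload, e.g. a serialized Example
            f.write(struct.pack("<I", 0))          # masked CRC of the payload (zeroed here)

def read_records(path):
    # Walk the file record by record, skipping the CRC fields
    records = []
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break
            (length,) = struct.unpack("<Q", header)
            f.read(4)                   # skip length CRC
            records.append(f.read(length))
            f.read(4)                   # skip payload CRC
    return records

write_records("demo.records", [b"hello", b"world"])
print(read_records("demo.records"))  # [b'hello', b'world']
```

This linear framing is what makes TFRecord files cheap to read sequentially: a reader never has to seek, it just consumes one length-prefixed record after another.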
The code above generates the TFRecord file with Python, but in practice TFRecord files are usually produced with Spark; see the tensorflow/ecosystem project for details. The code is not optimal and is only meant to help understand TFRecord.
Reference:
https://www.tensorflow.org/tutorials/load_data/tfrecord