tfrecord is a file format that TensorFlow built on top of the protobuf framework for persisting training data.
protobuf
Consider the following simple protobuf message definition:
syntax = "proto3";
message Msg{
map<string, int64> val = 1;
};
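Here the field_number of val is set to 1; we will locate it again in the serialized file below. Running protoc --proto_path=. --python_out=. demo.proto
generates demo_pb2.py, after which a Python script can encode and decode data in the Msg format. The serialization code used here is: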
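import sys

import demo_pb2

# Map data to store in Msg.val.
v = {
    "test": 300,
}
msg = demo_pb2.Msg(val=v)
# Write the serialized message to the file named on the command line.
with open(sys.argv[1], "wb") as fo:
    fo.write(msg.SerializeToString())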
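This code persists the data v to the file given on the command line. The encoding method SerializeToString() is provided by the generated demo_pb2.py; its decoding counterpart is ParseFromString(). After running python write.py data, we can inspect the binary layout of data: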
# xxd -b data
0000000: 00001010 00001001 00001010 00000100 01110100 01100101 ....te
0000006: 01110011 01110100 00010000 10101100 00000010 st...
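In the pb wire format, every field is stored as a key-value pair. The key combines the field's field_number with a wire_type as (field_number << 3) | wire_type, encoded as a varint, which fits in one byte whenever field_number is below 16. Fields whose wire_type is length-delimited also carry a length prefix (called the payload below); it is a varint too and fits in one byte for lengths below 128. Let us walk through the dump byte by byte:
[00001010]Msg's key: reversing the formula gives field_number == 1, wire_type == 2 (Length-delimited)
[00001001]Msg's payload: 9 bytes
[00001010, 00000010(end)]Msg's value (first and last byte shown), a single map entry:
{
{
[00001010]Msg.map<>.string's key: field_number == 1, wire_type == 2
[00000100]Msg.map<>.string's payload: 4 bytes
[01110100 01100101 01110011 01110100]Msg.map<>.string's value: test
}
{
[00010000]Msg.map<>.int64's key: field_number == 2, wire_type == 0 (varint)
[10101100, 00000010(end)]Msg.map<>.int64's value: 300 (varint: 0101100 + (0000010 << 7) = 44 + 256)
}
}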
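The byte-by-byte reading above can be reproduced with a few lines of plain Python. This is a minimal sketch rather than the protobuf library's own decoder: read_varint and read_key are hypothetical helper names introduced here, and the data bytes are copied by hand from the xxd dump.

```python
def read_varint(buf, pos):
    """Decode one base-128 varint starting at buf[pos]; return (value, next_pos)."""
    value = 0
    shift = 0
    while True:
        byte = buf[pos]
        pos += 1
        value |= (byte & 0x7F) << shift  # the low 7 bits carry data
        if byte & 0x80 == 0:             # high bit clear: this was the last byte
            return value, pos
        shift += 7


def read_key(buf, pos):
    """Split a field key into (field_number, wire_type, next_pos)."""
    key, pos = read_varint(buf, pos)
    return key >> 3, key & 0x07, pos


# The 11 bytes shown in the xxd dump, copied by hand.
data = bytes([0x0A, 0x09, 0x0A, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0xAC, 0x02])

# Outer field: Msg.val, one length-delimited map entry.
field, wire, pos = read_key(data, 0)
length, pos = read_varint(data, pos)
entry = data[pos:pos + length]

# Inside the map entry: the string key (field 1), then the int64 value (field 2).
k_field, k_wire, p = read_key(entry, 0)
k_len, p = read_varint(entry, p)
map_key = entry[p:p + k_len].decode("utf-8")
p += k_len
v_field, v_wire, p = read_key(entry, p)
map_value, p = read_varint(entry, p)

print(field, wire, length)  # 1 2 9
print(map_key, map_value)   # test 300
```

The helper treats the key and the length as varints rather than as single bytes, so it keeps working for field numbers above 15 and lengths above 127.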
Eg.
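With the example above in mind, it is clear that the pb layout depends only on the proto definition. Next, let us see how tfrecord makes use of pb and what its binary form should look like. tfrecord uses the proto structure defined by example and feature as its pb message, corresponding to tensorflow/core/example/example.proto and feature.proto in the TensorFlow source. An excerpt: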
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};
//example.proto
message Example {
Features features = 1;
};
Expanding the nested structures gives:
example_key
example_payload
example_value
{
Features_key
Features_payload
Features_value
{
{
String_key
String_payload
String_value
}
{
Feature_key
Feature_payload
Feature_value
{
Int64List_key
Int64List_payload
Int64List_value
{
int64_key
int64_payload
int64_value
}
}
}
}
Features_key
Features_payload
Features_value
{
...
}
}
As before, the recursive structure above (message->Features->feature->Int64List->value) is essential for the binary analysis that follows. This time, the following script writes the data and produces the serialized file:
import tensorflow as tf
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
import sys

feature = {
    'masked_lm_positions': feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=[2, 10, 0])),
    'masked_lm_weights': feature_pb2.Feature(float_list=feature_pb2.FloatList(value=[1., 1., 0.])),
    'next_sentence_labels': feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=[1])),
}
f = feature_pb2.Features(feature=feature)
example = example_pb2.Example(features=f)
# Write the serialized Example to the file named on the command line.
with open(sys.argv[1], "wb") as fo:
    fo.write(example.SerializeToString())
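The resulting binary file is shown below (|| marks the boundary between consecutive map entries; it is not part of the xxd output):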
0000000: 00001010 01100110 00001010 00100101 00001010 00010001 .f.%..
0000006: 01101101 01100001 01110011 01101011 01100101 01100100 masked
000000c: 01011111 01101100 01101101 01011111 01110111 01100101 _lm_we
0000012: 01101001 01100111 01101000 01110100 01110011 00010010 ights.
0000018: 00010000 00010010 00001110 00001010 00001100 00000000 ......
000001e: 00000000 10000000 00111111 00000000 00000000 10000000 ..?...
0000024: 00111111 00000000 00000000 00000000 00000000 ||00001010 ?.....
000002a: 00011110 00001010 00010011 01101101 01100001 01110011 ...mas
0000030: 01101011 01100101 01100100 01011111 01101100 01101101 ked_lm
0000036: 01011111 01110000 01101111 01110011 01101001 01110100 _posit
000003c: 01101001 01101111 01101110 01110011 00010010 00000111 ions..
0000042: 00011010 00000101 00001010 00000011 00000010 00001010 ......
0000048: 00000000 ||00001010 00011101 00001010 00010100 01101110 .....n
000004e: 01100101 01111000 01110100 01011111 01110011 01100101 ext_se
0000054: 01101110 01110100 01100101 01101110 01100011 01100101 ntence
000005a: 01011111 01101100 01100001 01100010 01100101 01101100 _label
0000060: 01110011 00010010 00000101 00011010 00000011 00001010 s.....
0000066: 00000001 00000001
In the same way, the binary can be decoded by following the structure defined in the proto files. The notation matches the proto tree shown earlier.
[00001010]Example's key: field_number==1, wire_type=2 (Length-delimited)
[01100110]Example's payload: 102 bytes
[00001010, 00000001(end)]Example's value:
{
[00001010]Example.Features's key
[00100101]Example.Features's payload: 37 bytes
[00001010, 00000000]Example.Features's value:
{
{
[00001010]Example.Features.string's key
[00010001]Example.Features.string's payload: 17 bytes
[01101101, 01110011]Example.Features.string's value: masked_lm_weights
}
{
[00010010]Example.Features.feature's key
[00010000]Example.Features.feature's payload: 16 bytes
[00010010, 00000000]Example.Features.feature's value:
{
[00010010]Example.Features.feature.FloatList's key
[00001110]Example.Features.feature.FloatList's payload: 14 bytes
[00001010, 00000000]Example.Features.feature.FloatList's value:
{
[00001010]Example.Features.feature.FloatList.float's key
[00001100]Example.Features.feature.FloatList.float's payload: 12 bytes
[00000000, 00000000]Example.Features.feature.FloatList.float's value, e.g. [00000000 00000000 10000000 00111111] represents the float 1.0 (little-endian IEEE 754, 0x3f800000)
}
}
}
}
[00001010]Example.Features's key
[00011110]Example.Features's payload: 30 bytes
[00001010, 00000000]Example.Features's value:
{
[00001010]Example.Features.string's key
[00010011]Example.Features.string's payload: 19 bytes
[01101101, 01110011]Example.Features.string's value: masked_lm_positions
...
}
}
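The walk through the dump can also be automated. The sketch below is a hand-written reader for this particular file, not TensorFlow's parser. The helper names (read_varint, fields, parse_feature, parse_example) are hypothetical, and the byte string is transcribed from the dump above. It handles only what the dump contains: wire types 0 and 2, float_list and int64_list features, and non-negative integers.

```python
import struct


def read_varint(buf, pos):
    """Decode one base-128 varint; return (value, next_pos)."""
    value, shift = 0, 0
    while True:
        byte = buf[pos]
        pos += 1
        value |= (byte & 0x7F) << shift
        if byte & 0x80 == 0:
            return value, pos
        shift += 7


def fields(buf):
    """Yield (field_number, wire_type, value) for each field in buf."""
    pos = 0
    while pos < len(buf):
        key, pos = read_varint(buf, pos)
        number, wire = key >> 3, key & 0x07
        if wire == 0:                      # varint
            value, pos = read_varint(buf, pos)
        elif wire == 2:                    # length-delimited
            length, pos = read_varint(buf, pos)
            value = buf[pos:pos + length]
            pos += length
        else:
            raise ValueError("wire type %d not handled in this sketch" % wire)
        yield number, wire, value


def parse_feature(buf):
    """Decode a Feature holding a float_list (2) or an int64_list (3)."""
    kind, _, body = next(fields(buf))      # which member of the oneof is set
    _, _, packed = next(fields(body))      # the packed repeated `value` field
    if kind == 2:                          # 4-byte little-endian floats
        return list(struct.unpack("<%df" % (len(packed) // 4), packed))
    if kind == 3:                          # back-to-back varints
        values, pos = [], 0
        while pos < len(packed):
            v, pos = read_varint(packed, pos)
            values.append(v)
        return values
    raise ValueError("feature kind %d not handled in this sketch" % kind)


def parse_example(buf):
    """Return {feature name: list of values} for a serialized Example."""
    _, _, features = next(fields(buf))     # Example.features
    result = {}
    for _, _, entry in fields(features):   # one map entry per feature
        name = value = None
        for number, _, payload in fields(entry):
            if number == 1:
                name = payload.decode("utf-8")
            elif number == 2:
                value = parse_feature(payload)
        result[name] = value
    return result


# The 104 bytes of the dump, transcribed by hand.
data = (
    b"\x0a\x66"
    b"\x0a\x25\x0a\x11masked_lm_weights\x12\x10\x12\x0e\x0a\x0c"
    b"\x00\x00\x80\x3f\x00\x00\x80\x3f\x00\x00\x00\x00"
    b"\x0a\x1e\x0a\x13masked_lm_positions\x12\x07\x1a\x05\x0a\x03\x02\x0a\x00"
    b"\x0a\x1d\x0a\x14next_sentence_labels\x12\x05\x1a\x03\x0a\x01\x01"
)

parsed = parse_example(data)
print(parsed)
# {'masked_lm_weights': [1.0, 1.0, 0.0], 'masked_lm_positions': [2, 10, 0], 'next_sentence_labels': [1]}
```

Because the reader follows the same key, payload, value pattern at every level, the nesting in the annotated tree corresponds directly to the nesting of the fields() calls.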
tfrecord
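On top of pb, tfrecord wraps each record with some checksum information in the format shown below. Combined with the tfrecord APIs, sample = TFRecordDataset().make_one_shot_iterator().get_next() and Example = tf.parse_single_example(sample, features), we can infer the following: fetching a sample means extracting the pb-parsable fields from the tfrecord file one by one, that is, "skipping over" the fields in tfrecord that are not native pb, while parse is a fairly plain use of the pb API to decode each field according to the Example format.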
{
Uint64_t example_len
Uint32_t example_len_crc32
Protocol_buffer_data //protocol buffer API可解析部分
Uint32_t pb_data_crc32
Uint64_t example_len
…
}
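The framing above can be exercised without TensorFlow. The following is a sketch under two assumptions that go beyond the layout shown here and come from TensorFlow's own description of its record format: integers are little-endian, and both checksums are a "masked" CRC32C (Castagnoli polynomial) rather than a plain CRC32. The helper names crc32c, masked_crc, write_record and read_records are hypothetical.

```python
import io
import struct


def _make_table():
    """Build the byte-wise lookup table for CRC32C (reflected poly 0x82F63B78)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _make_table()


def crc32c(data):
    """Plain CRC32C of a bytes object."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Assumed TFRecord masking: rotate right by 15 bits, then add a constant."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_record(stream, payload):
    """Append one framed record: length, length CRC, data, data CRC."""
    header = struct.pack("<Q", len(payload))
    stream.write(header)
    stream.write(struct.pack("<I", masked_crc(header)))
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc(payload)))


def read_records(stream):
    """Yield each payload, skipping the framing fields and checking both CRCs."""
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("length checksum mismatch")
        payload = stream.read(length)
        (data_crc,) = struct.unpack("<I", stream.read(4))
        if data_crc != masked_crc(payload):
            raise ValueError("data checksum mismatch")
        yield payload


# Round trip through an in-memory buffer instead of a real file.
buffer = io.BytesIO()
write_record(buffer, b"first")
write_record(buffer, b"second record")
buffer.seek(0)
records = list(read_records(buffer))
print(records)  # [b'first', b'second record']
```

In a real file each yielded payload would be a serialized Example, which is the part that the pb API (or tf.parse_single_example) then decodes.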
Ref.
Protocol Buffer Encoding
tensor-flow notes
Tfrecords and tf.Example