TFRecord binary parsing

tfrecord is a file format that TensorFlow builds on top of the protobuf framework for persisting training data.

protobuf

Consider the following simple protobuf message definition:

syntax = "proto3";
message Msg{
	map<string, int64> val = 1;
};

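Here we set the field_number of val to 1, and later we will pick it out of the serialized file. Running protoc --proto_path=. --python_out=. demo.proto generates demo_pb2.py, after which a Python script can encode and decode data in the Msg format. The serialization script (write.py) we use is: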

import sys
import demo_pb2  # generated by protoc from demo.proto

v = {
    "test": 300,
}
msg = demo_pb2.Msg(val=v)
with open(sys.argv[1], "wb") as fo:
    fo.write(msg.SerializeToString())

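The script simply persists the data v to the file named on the command line. The encode method SerializeToString() is provided by the Msg class generated in demo_pb2.py, and its decode counterpart is ParseFromString(). After running python write.py data, we can inspect the binary layout of data: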

# xxd -b data
0000000: 00001010 00001001 00001010 00000100 01110100 01100101  ....te
0000006: 01110011 01110100 00010000 10101100 00000010           st...

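In the pb protocol, every piece of data is stored as a key-value pair. The key combines the field's field_number with a wire_type as (field_number << 3) | wire_type, encoded as a varint, which fits in a single byte when field_number is at most 15. Data whose wire_type is length-delimited (2) additionally carries a payload: a varint giving the length of the data in bytes (one byte when the length is below 128). Let's go through these bytes one by one: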

[00001010]Msg's key: reversing (field_number << 3) | wire_type gives field_number == 1, wire_type == 2 (length-delimited)
[00001001]Msg's payload: 9 bytes
[00001010, 00000010(end)]Msg's value (first and last byte of the 9-byte payload):
{
    {
        [00001010]Msg.map<>.string's key: field_number == 1, wire_type == 2
        [00000100]Msg.map<>.string's payload: 4 bytes
        [01110100 01100101 01110011 01110100]Msg.map<>.string's value: test
    }
    {
        [00010000]Msg.map<>.int64's key: field_number == 2, wire_type == 0 (varint)
        [10101100, 00000010(end)]Msg.map<>.int64's value: 300 (low 7 bits 0101100 = 44, plus 2 << 7 = 256)
    }
}
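The two rules above, a varint key of (field_number << 3) | wire_type and a varint length for length-delimited data, are enough to rebuild these 11 bytes by hand. The sketch below is a minimal hand-written encoder for illustration, not the protobuf library; the helper names are hypothetical. It encodes the map entry as a nested message with key = 1 and value = 2, which is how protobuf represents map fields on the wire.

```python
def encode_varint(n: int) -> bytes:
    """Encode a non-negative integer as a protobuf base-128 varint."""
    out = bytearray()
    while True:
        low7 = n & 0x7F
        n >>= 7
        if n:
            out.append(low7 | 0x80)  # more bytes follow: set the MSB
        else:
            out.append(low7)         # last byte: MSB cleared
            return bytes(out)


def decode_varint(buf: bytes, pos: int = 0):
    """Decode the varint starting at buf[pos]; return (value, next_pos)."""
    result, shift = 0, 0
    while True:
        b = buf[pos]
        pos += 1
        result |= (b & 0x7F) << shift
        if not b & 0x80:
            return result, pos
        shift += 7


def encode_tag(field_number: int, wire_type: int) -> bytes:
    """The field key: (field_number << 3) | wire_type, itself a varint."""
    return encode_varint((field_number << 3) | wire_type)


def encode_len_field(field_number: int, payload: bytes) -> bytes:
    """Wire type 2 (length-delimited): key, varint length, then the payload."""
    return encode_tag(field_number, 2) + encode_varint(len(payload)) + payload


# One map entry of Msg.val, encoded as {string key = 1; int64 value = 2}.
entry = encode_len_field(1, b"test") + encode_tag(2, 0) + encode_varint(300)
msg = encode_len_field(1, entry)  # Msg.val has field_number 1

print(msg.hex(" "))  # 0a 09 0a 04 74 65 73 74 10 ac 02
```

decode_varint reverses the process: decode_varint(bytes([0b10101100, 0b00000010])) returns (300, 2), i.e. the value and the position just past the varint.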

Eg.

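With this example in mind, we can see that the pb layout depends only on the proto definition. Next, let's look at how tfrecord uses pb and what its binary should look like. tfrecord uses the Example/Feature proto structures as its pb messages, defined in tensorflow/core/example/example.proto and feature.proto in the TensorFlow source tree. Excerpt: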

message BytesList {
    repeated bytes value = 1;
}
message FloatList {
    repeated float value = 1 [packed = true];
}
message Int64List {
    repeated int64 value = 1 [packed = true];
}

// Containers for non-sequential data.
message Feature {
    // Each feature can be exactly one kind.
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

message Features {
    // Map from feature name to feature.
    map<string, Feature> feature = 1;
};

//example.proto
message Example {
    Features features = 1;
};

Expanding the nested structures gives:

example_key
example_payload
example_value
{
    Features_key
    Features_payload
    Features_value
    {
        {
            String_key
            String_payload
            String_value
        }
        {
            Feature_key
            Feature_payload
            Feature_value
            {
                Int64List_key
                Int64List_payload
                Int64List_value
                {
                    int64_key
                    int64_payload
                    int64_value
                }
            }
        }
    }
    Features_key
    Features_payload
    Features_value
    {
        ...
    }
}
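Each level of this tree is just a length-delimited field whose value is the next level down. As a minimal sketch (a hand-rolled encoder written for illustration, not the protobuf or TensorFlow API; the names len_field and encode_example_int64 are hypothetical), an Example carrying a single Int64List feature can be serialized by building the innermost bytes first and wrapping them outward:

```python
def encode_varint(n: int) -> bytes:
    """Base-128 varint; negative int64 values use their 64-bit two's complement."""
    if n < 0:
        n += 1 << 64
    out = bytearray()
    while n > 0x7F:
        out.append((n & 0x7F) | 0x80)
        n >>= 7
    out.append(n)
    return bytes(out)


def len_field(field_number: int, payload: bytes) -> bytes:
    """Length-delimited field: key (wire type 2) + varint length + payload."""
    return encode_varint((field_number << 3) | 2) + encode_varint(len(payload)) + payload


def encode_example_int64(name: str, values: list[int]) -> bytes:
    packed = b"".join(encode_varint(v) for v in values)   # int64 values, packed
    int64_list = len_field(1, packed)                     # Int64List.value = 1
    feature = len_field(3, int64_list)                    # Feature.int64_list = 3
    entry = len_field(1, name.encode()) + len_field(2, feature)  # map entry: key = 1, value = 2
    features = len_field(1, entry)                        # Features.feature = 1
    return len_field(1, features)                         # Example.features = 1


print(encode_example_int64("next_sentence_labels", [1]).hex(" "))
```

For ("next_sentence_labels", [1]) this produces 33 bytes: 0a 1f 0a 1d 0a 14, the 20 name bytes, then 12 05 1a 03 0a 01 01. Apart from the outer 0a 1f, that is the same byte sequence as the next_sentence_labels entry in the full Example dump. The sketch assumes packed encoding for Int64List, matching the [packed = true] option in feature.proto.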

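As before, the recursive structure of the expanded tree (Example -> Features -> feature -> Int64List -> value) is key to the binary analysis that follows. This time we write the data with the following script and capture its serialized file: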

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
import sys

feature = {
    'masked_lm_positions': feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=[2, 10, 0])),
    'masked_lm_weights': feature_pb2.Feature(float_list=feature_pb2.FloatList(value=[1., 1., 0.])),
    'next_sentence_labels': feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=[1])),
}
f = feature_pb2.Features(feature=feature)
example = example_pb2.Example(features=f)
with open(sys.argv[1], "wb") as fo:
    fo.write(example.SerializeToString())  # raw Example bytes, no TFRecord framing

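The resulting binary file (xxd -b output; the || markers were added by hand to mark where each map entry begins):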

0000000: 00001010 01100110 00001010 00100101 00001010 00010001  .f.%..
0000006: 01101101 01100001 01110011 01101011 01100101 01100100  masked
000000c: 01011111 01101100 01101101 01011111 01110111 01100101  _lm_we
0000012: 01101001 01100111 01101000 01110100 01110011 00010010  ights.
0000018: 00010000 00010010 00001110 00001010 00001100 00000000  ......
000001e: 00000000 10000000 00111111 00000000 00000000 10000000  ..?...
0000024: 00111111 00000000 00000000 00000000 00000000 ||00001010  ?.....
000002a: 00011110 00001010 00010011 01101101 01100001 01110011  ...mas
0000030: 01101011 01100101 01100100 01011111 01101100 01101101  ked_lm
0000036: 01011111 01110000 01101111 01110011 01101001 01110100  _posit
000003c: 01101001 01101111 01101110 01110011 00010010 00000111  ions..
0000042: 00011010 00000101 00001010 00000011 00000010 00001010  ......
0000048: 00000000 ||00001010 00011101 00001010 00010100 01101110  .....n
000004e: 01100101 01111000 01110100 01011111 01110011 01100101  ext_se
0000054: 01101110 01110100 01100101 01101110 01100011 01100101  ntence
000005a: 01011111 01101100 01100001 01100010 01100101 01101100  _label
0000060: 01110011 00010010 00000101 00011010 00000011 00001010  s.....
0000066: 00000001 00000001

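Likewise, following the structure defined in the proto files, we can work out the binary breakdown, using the same notation as the proto tree above. Note that the entries do not appear in the order of the Python dict (masked_lm_weights comes first): protobuf leaves the wire order of map entries undefined.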

[00001010]Example's key: field_number==1,wire_type=2(Length-delimited)
[01100110]Example's payload: 102 bytes
[00001010, 00000001(end)]Example's value:
{
    [00001010]Example.Features's key
    [00100101]Example.Features's payload: 37 bytes
    [00001010, 00000000]Example.Features's value:
    {
        {
            [00001010]Example.Features.string's key
            [00010001]Example.Features.string's payload: 17 bytes
            [01101101, 01110011]Example.Features.string's value: masked_lm_weights
        }
        {
            [00010010]Example.Features.feature's key
            [00010000]Example.Features.feature's payload: 16 bytes
            [00010010, 00000000]Example.Features.feature's value:
            {
                [00010010]Example.Features.feature.FloatList's key
                [00001110]Example.Features.feature.FloatList's payload: 14 bytes
                [00001010, 00000000]Example.Features.feature.FloatList's value:
                {
                    [00001010]Example.Features.feature.FloatList.float's key
                    [00001100]Example.Features.feature.FloatList.float's payload: 12 bytes
                    [00000000, 00000000]Example.Features.feature.FloatList.float's value: three little-endian floats, e.g. [00000000 00000000 10000000 00111111] is float 1.0
                }
            }
        }
    }
    [00001010]Example.Features's key
    [00011110]Example.Features's payload: 30 bytes
    [00001010, 00000000]Example.Features's value:
    {
        {
            [00001010]Example.Features.string's key
            [00010011]Example.Features.string's payload: 19 bytes
            [01101101, 01110011]Example.Features.string's value: masked_lm_positions
        }
        ...
    }
}
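The manual breakdown above can be automated. Below is a minimal, schema-driven decoder sketch written for this walkthrough, not the protobuf or TensorFlow API (parse_fields, decode_feature and decode_example are hypothetical names). It handles only the wire types this Example schema uses (0 = varint, 2 = length-delimited), reads FloatList values as packed little-endian 32-bit floats, and returns a plain dict. DUMP is the 104-byte file from the xxd output above, written in hex.

```python
import struct


def read_varint(buf: bytes, pos: int):
    """Return (value, next_pos) for the varint starting at buf[pos]."""
    result, shift = 0, 0
    while True:
        b = buf[pos]
        pos += 1
        result |= (b & 0x7F) << shift
        if b < 0x80:
            return result, pos
        shift += 7


def parse_fields(buf: bytes):
    """Yield (field_number, wire_type, value) for every field of one message."""
    pos = 0
    while pos < len(buf):
        key, pos = read_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 0x7
        if wire_type == 0:                                  # varint
            value, pos = read_varint(buf, pos)
        elif wire_type == 2:                                # length-delimited
            length, pos = read_varint(buf, pos)
            value, pos = buf[pos:pos + length], pos + length
        else:
            raise ValueError(f"wire type {wire_type} not handled in this sketch")
        yield field_number, wire_type, value


def decode_packed_varints(buf: bytes) -> list:
    values, pos = [], 0
    while pos < len(buf):
        v, pos = read_varint(buf, pos)
        values.append(v - (1 << 64) if v >= 1 << 63 else v)  # back to signed int64
    return values


def decode_feature(buf: bytes) -> list:
    """Feature: oneof {BytesList = 1, FloatList = 2, Int64List = 3}; each list's value is field 1."""
    for kind, _, list_body in parse_fields(buf):
        chunks = [v for f, _, v in parse_fields(list_body) if f == 1]
        if kind == 1:                                       # BytesList: one field per element
            return chunks
        packed = b"".join(chunks)                           # assumes packed encoding
        if kind == 2:
            return list(struct.unpack(f"<{len(packed) // 4}f", packed))
        if kind == 3:
            return decode_packed_varints(packed)
    return []


def decode_example(buf: bytes) -> dict:
    result = {}
    for _, _, features in parse_fields(buf):                # Example.features = 1
        for _, _, entry in parse_fields(features):          # Features.feature = 1 (map entry)
            name, feature = "", b""
            for f, _, v in parse_fields(entry):
                if f == 1:
                    name = v.decode()                       # map key (string)
                elif f == 2:
                    feature = v                             # map value (Feature)
            result[name] = decode_feature(feature)
    return result


DUMP = bytes.fromhex(
    "0a660a250a116d61736b65645f6c6d5f77656967687473"
    "1210120e0a0c0000803f0000803f00000000"
    "0a1e0a136d61736b65645f6c6d5f706f736974696f6e73"
    "12071a050a03020a00"
    "0a1d0a146e6578745f73656e74656e63655f6c6162656c73"
    "12051a030a0101"
)
print(decode_example(DUMP))
# {'masked_lm_weights': [1.0, 1.0, 0.0], 'masked_lm_positions': [2, 10, 0], 'next_sentence_labels': [1]}
```

The keys come back in wire order, which is why masked_lm_weights is listed first. The sketch does not handle unpacked repeated fields or the fixed-width wire types 1 and 5, so it is a reading aid for this dump rather than a general parser.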

tfrecord

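On top of pb, tfrecord wraps each record with some checksum information in the format below. Looking at the TensorFlow 1.x tfrecord APIs, sample = TFRecordDataset().make_one_shot_iterator().get_next() and Example = tf.parse_single_example(sample, features), we can infer that fetching a sample means pulling the pb-parseable payloads out of the tfrecord file one at a time, i.e. "skipping" the tfrecord fields that are not native pb, whereas parsing is a fairly pure use of the pb API to decode those bytes in the Example format.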

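{
    uint64_t example_len            // little-endian
    uint32_t example_len_crc32      // masked CRC32C of the 8 example_len bytes
    Protocol_buffer_data            // example_len bytes; the part the protocol buffer API can parse
    uint32_t pb_data_crc32          // masked CRC32C of Protocol_buffer_data

    uint64_t example_len
    …
}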
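The framing can be sketched in pure Python. The details assumed here are little-endian fields and both checksums being CRC32C (Castagnoli) values masked as ((crc >> 15) | (crc << 17)) + 0xa282ead8 modulo 2^32, following TensorFlow's record format; treat them as assumptions to check against the TensorFlow source. The function names are hypothetical and this is not TensorFlow's API, which in practice is tf.io.TFRecordWriter for writing and tf.data.TFRecordDataset for reading. decode_records does roughly what a reader must do before the payload reaches the protobuf parser: read the length, verify both checksums, and hand over only the middle bytes.

```python
import struct


def _make_crc32c_table():
    """Lookup table for reflected CRC32C (polynomial 0x82F63B78)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant (TFRecord masking)."""
    crc = crc32c(data)
    rotated = (crc >> 15) | ((crc << 17) & 0xFFFFFFFF)
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def encode_record(payload: bytes) -> bytes:
    length = struct.pack("<Q", len(payload))            # uint64 example_len
    return (length
            + struct.pack("<I", masked_crc32c(length))  # uint32 example_len_crc32
            + payload                                   # protocol buffer data
            + struct.pack("<I", masked_crc32c(payload)))  # uint32 pb_data_crc32


def decode_records(data: bytes) -> list:
    records, pos = [], 0
    while pos < len(data):
        length_bytes = data[pos:pos + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (length_crc,) = struct.unpack("<I", data[pos + 8:pos + 12])
        if length_crc != masked_crc32c(length_bytes):
            raise ValueError("length checksum mismatch")
        payload = data[pos + 12:pos + 12 + length]
        (data_crc,) = struct.unpack("<I", data[pos + 12 + length:pos + 16 + length])
        if data_crc != masked_crc32c(payload):
            raise ValueError("payload checksum mismatch")
        records.append(payload)                         # only the pb bytes are kept
        pos += 16 + length
    return records


stream = encode_record(b"first") + encode_record(b"second")
print(decode_records(stream))  # [b'first', b'second']
```

Each record costs 16 bytes of framing on top of the serialized Example, and a corrupted length or payload is caught by its own checksum before any protobuf parsing happens.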

