根据BERT论文, 其12层transformer结构有110M参数, 24层更是高达340M, 虽然google公开了这两个网络的预训练模型, 用户只需在后面加一层Full-connected 但如果自己去做BERT预训练, 到底要耗费多少显存呢?
tensorflow中使用tf.variable()在模型中生成训练参数, 同时, tf.dense()内部也含有weight和bias两类参数. 比如,
tf.get_variable(shape=[vocab_size, embedding_size],initializer=…))
生成了shape是[vocab_size, embedding_size]的variable, 即 vocab_size * embedding_size 个参数,
to_tensor_2d = tf.layers.dense(from_tensor_2d, kernel_initializer=...))
内含 from_tensor_2d.shape[-1] * to_tensor_2d.shape[-1] 个weight 以及 to_tensor_2d.shape[-1] 个bias. 共计 (from_tensor_2d.shape[-1] + 1) * to_tensor_2d.shape[-1] 个参数, 这就是一个FC-Layer的参数量
有了上面的例子, BERT的参数量就不难分析, 整个BERT网络可以分为4个模块: embedding layer + transformer layers + pooled layer+ classifier layer
embedding-layer
最原始的token_one_hot, segment_one_hot和position_one_hot分别经过token_embedding_table, segment_embedding_table以及position_embedding_table的转换, 生成相应的token_embedding, segment_embedding以及position_embedding, 最终作和生成embedding再送入transformer 层. 这里, 三张转换表是要训练的对象. 通过代码:
embedding_table = tf.get_variable(shape=[vocab_size, embedding_size],initializer=...)
token_type_table = tf.get_variable(shape=[token_type_vocab_size, width], initializer=...)
full_position_embeddings = tf.get_variable(shape=[max_position_embeddings, width], initializer=...)
可知, 该层的参数量为: vocab_size * embedding_size + token_type_vocab_size * width + max_position_embeddings * width
transformer-layers
对于每一个transformer块, 又可以分为attention + (fc + layer_norm) + fc + (fc + layer_norm)几部分.
self-mutihead-attention QKV共有三个FC: (hidden_size + 1) * hidden_size * 3
第1次FC-layer_norm转换: fc: (hidden_size + 1) * hidden_size + layer_norm: hidden_size * 2
第2次FC转换: fc: (hidden_size + 1) * intermediate_size
第3次FC-layer_norm转换: fc:(intermediate_size + 1) * hidden_size + layer_norm: hidden_size * 2
上述只是一个transformer block的参数量, BERT通常使用12层或24层, 乘上相应倍数即可
pooled-layer
这一层主要解决每个sequence开头的 [CLS] 的embedding问题, 后续主要用于句子级分类. 参数量: (hidden_size + 1) * hidden_size
classifier-layer
该层在fine-tuning过程中使用. 无论是sentence classifier还是token classifier, 新增参数量均为 num_label * hidden_size + hidden_size
上述即是标准BERT模型的参数量计算. 使用下述参数构建BERT, 其参数量可以快速计算其pretrainning需训练参数量: 235848192
vocab_size = 195061
max_position_embeddings = 512
max_seq_length = 256
hidden_layers = 12
hidden_size = 768
attention_heads = 12
intermediate_size = 3072
type_vocab_size = 2