根据BERT论文, 其12层transformer结构有110M参数, 24层更是高达340M, 虽然google公开了这两个网络的预训练模型, 用户只需在后面加一层Full-connected 但如果自己去做BERT预训练, 到底要耗费多少显存呢?
tensorflow中使用tf.variable()在模型中生成训练参数, 同时, tf.dense()内部也含有weight和bias两类参数. 比如,
tf.get_variable(shape=[vocab_size, embedding_size],initializer=…))
生成了shape是[vocab_size, embedding_size]的variable, 即 vocab_size * embedding_size 个参数,
to_tensor_2d = tf.layers.dense(from_tensor_2d, kernel_initializer=...))
内含 from_tensor_2d.shape[-1] * to_tensor_2d.shape[-1] 个weight 以及 to_tensor_2d.shape[-1] 个bias. 共计 (from_tensor_2d.shape[-1] + 1) * to_tensor_2d.shape[-1] 个参数, 这就是一个FC-Layer的参数量
有了上面的例子, BERT的参数量就不难分析, 整个BERT网络可以分为4个模块: embedding layer + transformer layers + pooled layer+ classifier layer
embedding-layer
Continue reading →