所谓Transformer Distillation(TD),即对Transformer架构的蒸馏。假设教师模型和学生模型的层数分别为
N
N
N和
M
M
M,则首先定义一个映射函数
n
=
g
(
m
)
n=g(m)
n=g(m)表示用学生模型的第
m
m
m层去学习教师模型的第
n
=
g
(
m
)
n=g(m)
n=g(m)层的信息。文章通过数值实验选用了
g
(
m
)
=
3
m
g(m)=3m
g(m)=3m。定义第
0
0
0层为嵌入层,第
M
+
1
M+1
M+1层为预测层,则我们可以将模型的损失函数写作
L
m
o
d
e
l
=
∑
x
∈
X
∑
m
=
0
M
+
1
λ
m
L
l
a
y
e
r
(
f
m
S
(
x
)
,
f
g
(
m
)
T
(
x
)
)
(1)
\mathcal{L}_{model} = \sum_{x\in\mathcal{X}} \sum_{m=0}^{M+1} \lambda_m \mathcal{L}_{layer} (f_m^S(x), f_{g(m)}^T(x)) \tag{1}
Lmodel=x∈X∑m=0∑M+1λmLlayer(fmS(x),fg(m)T(x))(1),其中
L
l
a
y
e
r
\mathcal{L}_{layer}
Llayer表示
l
a
y
e
r
layer
layer层的损失函数,
f
m
S
(
x
)
,
f
g
(
m
)
T
(
x
)
f_m^S(x), f_{g(m)}^T(x)
fmS(x),fg(m)T(x)分别表示学生和教师模型在第
m
m
m或
g
(
m
)
g(m)
g(m)层的函数,
λ
m
\lambda_m
λm为超参数,表示第
m
m
m层的重要性。下面为针对不同层的蒸馏方式
Transformer-layer Distillation:
如上图所示,Transformer-layer Distillation包含以下两种蒸馏方法
Attention based distillation:蒸馏注意力机制矩阵,损失函数为
L
a
t
t
n
=
1
h
∑
i
=
1
h
M
S
E
(
A
i
S
,
A
i
T
)
(2)
\mathcal{L}_{attn} = \frac 1h \sum_{i=1}^h MSE(A_i^S, A_i^T) \tag{2}
Lattn=h1i=1∑hMSE(AiS,AiT)(2),其中
h
h
h为多头注意力机制的head数目,
M
S
E
MSE
MSE表示Mean Squared Error,
A
i
S
,
A
i
T
A_i^S, A_i^T
AiS,AiT分别表示学生模型和教师模型的注意力矩阵。
hidden tsates based distillation:蒸馏隐藏层(即FFN的输出层)状态,蒸馏的损失函数为
L
h
i
d
n
=
M
S
E
(
H
S
W
h
,
H
T
)
(3)
\mathcal{L}_{hidn} = MSE(H^SW_h, H^T) \tag{3}
Lhidn=MSE(HSWh,HT)(3),其中
H
S
,
H
T
H^S, H^T
HS,HT分别表示学生模型和教师模型的隐藏层状态,
W
h
W_h
Wh为可学习的参数,旨在将学生模型的隐藏向量映射到和教师模型隐藏状态相同的高维空间
Embedding-layer Distillation:对嵌入层进行蒸馏,损失函数为
L
e
m
b
d
=
M
S
E
(
E
S
W
e
,
E
T
)
(4)
\mathcal{L}_{embd} = MSE(E^SW_e, E^T) \tag{4}
Lembd=MSE(ESWe,ET)(4),其中
E
S
,
E
T
E^S, E^T
ES,ET分别表示学生模型和教师模型的嵌入层向量,
W
e
W_e
We和上述
W
h
W_h
Wh作用相同,旨在将学生模型的嵌入向量映射到和教师模型嵌入向量相同的高维空间
Prediction-layer Distillation:采用损失函数
L
p
r
e
d
=
C
E
(
z
T
/
t
,
z
S
/
t
)
(5)
\mathcal{L}_{pred} =CE(z^T/t, z^S/t) \tag{5}
Lpred=CE(zT/t,zS/t)(5),其中
z
S
,
z
T
z^S, z^T
zS,zT分别表示学生模型和教师模型的输出logits,
t
t
t表示蒸馏的温度。此设置参考原始蒸馏论文中的设置。 最后,将上述所有损失函数进行统一,得到
(
1
)
(1)
(1)式中的损失函数可表示为
L
l
a
y
e
r
=
{
L
e
m
b
d
,
m
=
0
L
h
i
d
n
+
L
a
t
t
n
,
M
≥
m
>
0
L
p
r
e
d
,
m
=
M
+
1
\mathcal{L}_{layer} = \begin{cases}\mathcal{L}_{embd}, &m = 0\\\mathcal{L}_{hidn} + \mathcal{L}_{attn}, &M\ge m >0\\\mathcal{L}_{pred}, &m=M+1\end{cases}
Llayer=⎩⎨⎧