神经网络与深度学习——第8章 注意力机制与外部记忆
本文讨论的内容参考自《神经网络与深度学习》http://nndl.github.io/ 第8章 注意力机制与外部记忆
注意力机制与外部记忆
认知神经学中的注意力
注意力机制
注意力机制的变体
硬性注意力
键值对注意力
多头注意力
结构化注意力
指针网络
自注意力模型
人脑中的记忆
记忆增强网络
端到端记忆网络
神经图灵机
基于神经动力学的联想记忆
Hopfield网络
使用联想记忆增加网络容量
总结和深入阅读
习题
假设隐藏神经元的数量为
D
D
D,输入层的维数为
M
M
M,分析一下LSTM结构,遗忘门那里接收上一时刻的隐藏状态
h
t
−
1
h_{t-1}
ht−1和输入
x
t
{x_t}
xt,权重矩阵为
W
f
W_f
Wf,即
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
=
σ
(
U
f
h
t
−
1
+
w
f
x
t
+
b
f
)
f_t=\sigma(W_f \cdot [h_{t-1},x_t]+b_f)=\sigma (U_f h _{t-1}+w_f x_t+b_f)
ft=σ(Wf⋅[ht−1,xt]+bf)=σ(Ufht−1+wfxt+bf),
W
f
W_f
Wf的参数数量是
D
∗
(
D
+
M
)
D*(D+M)
D∗(D+M),
b
f
b_f
bf的参数数量是
D
D
D。同理可知,输入门、输出门那里都有
D
∗
(
D
+
M
)
+
D
D*(D+M)+D
D∗(D+M)+D个参数,因为
i
t
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
+
b
i
)
i_t=\sigma(W_i \cdot [h_{t-1},x_t]+b_i)
it=σ(Wi⋅[ht−1,xt]+bi)和
o
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
o_t=\sigma(W_o \cdot [h_{t-1},x_t]+b_o)
ot=σ(Wo⋅[ht−1,xt]+bo),然后在新的候选状态
c
^
\hat c
c^那里有
c
^
=
t
a
n
h
(
W
c
⋅
[
h
t
−
1
,
x
t
]
+
b
c
)
\hat c=tanh(W_c \cdot [h_{t-1},x_t]+b_c)
c^=tanh(Wc⋅[ht−1,xt]+bc),也有
D
∗
(
D
+
M
)
+
D
D*(D+M)+D
D∗(D+M)+D个参数,因此一个LSTM隐藏层的参数总数为
4
D
∗
(
D
+
M
+
1
)
4D*(D+M+1)
4D∗(D+M+1)。
σ
(
x
i
)
=
e
x
i
∑
j
=
1
n
e
x
j
\sigma (x_i) =\frac{e^{x_i}}{\sum_{j=1}^ne^{x_j}}
σ(xi)=∑j=1nexjexi,如果指数特别大,那么总体分布的方差就非常大,出现梯度消失的情况,比如说,我们使用非标准化的softmax函数
σ
β
(
x
i
)
=
e
β
x
i
∑
j
=
1
n
e
β
x
j
\sigma_{\beta} (x_i) =\frac{e^{\beta x_i}}{\sum_{j=1}^ne^{\beta x_j}}
σβ(xi)=∑j=1neβxjeβxi,当
β
\beta
β趋于无穷的时候,
σ
β
(
x
i
)
=
a
r
g
m
a
x
(
x
)
\sigma_{\beta} (x_i)=argmax (x)
σβ(xi)=argmax(x),这样的话,当输入的方差或者数量级特别大的时候,softmax会接近不可导的
a
r
g
m
a
x
(
x
)
argmax (x)
argmax(x),梯度导数接近于0,通过缩放点积,除以维度的开方,降低了方差,可以缓解梯度消失的现象。
CNN
输入序列长度是
L
L
L,序列中每个元素的维度是
d
d
d,那么卷积核的维度也是
d
d
d,所以
k
d
∗
d
kd*d
kd∗d是卷积核对元素的复杂度,然后对整个序列的复杂度就是
k
L
d
2
kLd^2
kLd2,RNN
的每个时间步的隐藏状态的更新都依赖于上一个时间步的隐藏状态和当前输入,
h
t
=
U
h
t
−
1
+
W
x
t
h_t=Uh_{t-1}+Wx_t
ht=Uht−1+Wxt,输入和隐藏状态都是
d
d
d维的,U和W就都是
d
2
d^2
d2维的,所以复杂度是
d
2
d^2
d2,整个序列的复杂度是
L
d
2
Ld^2
Ld2,Transformer
的自注意力机制是
X
s
o
f
t
m
a
x
(
X
T
X
)
Xsoftmax(X^TX)
Xsoftmax(XTX),
X
d
L
X_{dL}
XdL,那么
X
T
X
X^TX
XTX的复杂度是
L
2
d
L^2d
L2d,然后
X
d
L
X_{dL}
XdL和一个
L
2
L^2
L2的矩阵相乘,复杂度也是
L
2
d
L^2d
L2d,所以总体的复杂度是
L
2
d
L^2d
L2d,这里解释一下矩阵相乘的复杂度,比如说
d
∗
L
d \ast L
d∗L的矩阵和
L
∗
L
L * L
L∗L的矩阵相乘,最终得到的
d
∗
L
d*L
d∗L的矩阵的元素
a
i
j
a_{ij}
aij为第
i
i
i行和第
j
j
j列元素的乘积加和,共有
L
L
L个元素,所以有
L
L
L次乘积,然后有
d
∗
L
d*L
d∗L个元素,所以是
L
2
d
L^2d
L2d次乘积,在这种神经网络中的复杂度,为了简化计算,所以忽略了加减法,只看浮点数的乘法,实际上加减法是有影响的。同理
L
∗
d
L*d
L∗d和
d
∗
L
d*L
d∗L相乘,最终得到
L
∗
L
L*L
L∗L的矩阵,每个元素经过了
d
d
d次乘法,所以复杂度也是
L
2
d
L^2d
L2d。
序列操作数是因为CNN
和Transformer
都可以并行,所以是
O
(
1
)
O(1)
O(1),而RNN
每一步都依赖于上一个时间步,所以是
O
(
L
)
O(L)
O(L)。
共同点是都通过外部记忆单元进行读取和写入,并且由控制器进行读写的调用。不同点是在读操作时,端到端记忆网络通过多跳操作来读取数据,而且多跳中的参数是共享的,神经图灵机在读取时直接基于内容寻址,控制器通过输出
h
t
h_t
ht产生查询向量,并计算读向量,并将读向量当作下一时刻控制器的输入。在写操作时,端到端记忆网络没有写操作,其外部记忆单元是只读的,而神经图灵机是可读可写的,写操作包括两个子操作(删除和添加),通过输出
h
t
h_t
ht产生删除向量和增加向量,进而写入外部记忆。