参考:https://www.bilibili.com/video/BV1Wv411h7kN?p=12

类神经网络训练不起来怎么办

局部最小值与鞍点

从上一篇笔记中我们知道,当训练时训练集上的loss达不到预期,可能时最优化方式出现问题,最优化失败的原因为优化停止在了critical point(驻点),包括local minima(局部最小值)、local maxima(局部最大值)和saddle point(鞍点)

image-20220212205623752

判断critical point类型的方法:

将损失函数在最优化停止的点θ\theta'进行泰勒展开,取前三项-第一项为处的损失函数,第二项包含loss在θ\theta'处导数,第三项包含Hassian矩阵。

image-20220212210603237

当出现critical point的情况时,第二项为0,为了判断具体是哪种驻点,需要看第三项的Hession矩阵

image-20220212210916971

对于局部最小值点,令v=θθv=\theta-\theta',对于局部最小值,L(θ)>L(θ)L(\theta)>L(\theta'),即对于任意向量v,有vTHv>0v_THv>0,H为正定矩阵 (positive definite matrix),H的特征值均大于零,对于局部最小值和鞍点,可以用同样的方法进行推导

image-20220212212010431

可得判断准则:

所有特征值(eigen value)都为正=》local minima

所有特征值为负=》local maxima

特征值有正有负=》saddle point

举个栗子

image-20220212212826891

对于一个最简单的神经网络,y=w1w2xy=w_1w_2x,绘制出损失函数的平面,可以看出平面存在较多的注定,当遇到驻点时,做法如下:

(1)计算矩阵H,并根据H的特征值判断驻点类型

image-20220212213115862

(二)在这个例子中驻点为鞍点,可以进一步优化,求出H矩阵的特征向量u

saddl point 解决方法

(1)

image-20220212214055311

如果判断最优化停在鞍点上,继续往θ=θ+u\theta=\theta'+u方向优化即可(这种方法计算量较大)

(2)momentum(动量)

vanilla gradient decent 计算方法是,选择初始值θ0\theta^0,计算loss函数在对参数的导数g0g^0,并移动到点θ1=θ0ηg0\theta^1=\theta^0-\eta g^0,重复直到找到最低点。这种方法会产生问题,即无法判断使进程停止的点到底是全局最小值还是其他驻点,只要遇到驻点,进程就会停止。momentum(动量)为变量更新的方向提供了一个调整的分量,变量下一步优化方向由当前导数和上一步移动的方向共同决定(可以与力学中的运动问题类比理解,小球运动的方向由当前受到的力和之前的运动方向有关),变量移动方向可以用下面的步骤来表示:

①选取初始变量θ0\theta^0

②初始moment m0=0m^0=0

③计算θ0\theta^0处导数g0g^0

④计算movement m1=λm0ηg0m^1=\lambda m^0-\eta g^0

⑤移动变量 θ1=θ0+m1\theta ^1=\theta^0 +m^1

⑥计算θ1\theta^1处导数g1g^1

image-20220214173258455

image-20220214173828271

如果进程停在极值点,很难进行进一步优化,但研究表明,最优化进程停在极值点且训练集loss函数值很小的情况极少发生

image-20220212230249767

批次(batch)大小怎么选

shaffle:每个epcho中的batch分法不同

batch size大:稳定,不容易受噪声影响

batch size小:gradient desent方向容易受噪声影响

使用GPU平行运算,当batch size不超过一定值时,batch大并不一定比batch小时计算更耗时,考虑一个epcho运算的时间,batch size大花费的时间更小。

image-20220213163043107

从下图中可以看出,small batch训练效果更好,产生的noise反而可以帮助训练,将full batch(所有数据放到一个batch中训练)中计算gradient decent容易陷入局部最小值,但对于samll batch,batch1计算出的gradient decent为0的点对于batch2不一定成立,因此最优化不会停止。

image-20220213163914534

image-20220213164505732

同时又研究表明small batch和big batch在训练集上loss相近,但在测试集上small batch明显效果更好。其中一种可能的解释是,small batch更容易找到周围数值变化更平缓的最小值(flat minima),这种最小值在测试集上的loss会更小image-20220213172208034

small batch与large batch对比

image-20220213172726525