本文分享一下我们在网络剪枝方面的新工作「Sparse Double Descent: Where Network Pruning Aggravates Overfitting」。这篇论文主要是受模型过参数化(over-parameterization)和彩票假说(lottery tickets)两方面研究的启发,探索分析了剪枝后的稀疏神经网络的泛化性能。
一句话结论: 稀疏神经网络的泛化能力受稀疏度的影响,随着稀疏度不断增加,模型的测试准确率会先下降,后上升,最后再次下降。
论文链接:
https://arxiv.org/abs/2206.08684
代码链接:
https://github.com/hezheug/sparse-double-descent
一、研究动机
根据传统机器学习的观点,模型难以同时最小化预测时的偏差与方差,因此往往需要权衡两者,才能找到最合适的模型。这便是广为流传的偏差-方差均衡(bias-variance tradeoff)曲线:随着模型容量增加,模型在训练集上的误差不断下降,然而在测试集上的误差却会先下降后上升。
虽然传统观点认为模型参数过多会导致过拟合,但是神奇的是,在深度学习实践中,大模型往往有着更好的表现。
今年来有学者发现,深度学习模型的测试误差和模型容量的关系,并非是 U 型曲线,而是具备的双下降(Double Descent)的特点,即随着模型参数变多,测试误差是先下降,再上升,然后第二次下降 [1,2]。
也就是说,过参数的神经网络非但不会发生严重的过拟合,反而有可能具有更好的泛化性能!
这究竟是为什么呢?
彩票假说(lottery tickets)[5] 为解释这一现象提供了一个新的思路。彩票假说认为,一个随机初始化的密集网络(未剪枝过的初始网络),包含着性能良好的稀疏子网络,这个子网络从原初始化(winning ticket)训练时,可以达到媲美原始密集网络的准确率,甚至还有可能更快收敛(而如果让这个子网络从一个新的初始化值开始训练,效果则往往大不如原始网络)。
当一个网络参数越多,它包含这样一个性能良好的子网络的概率就越大,也就是中彩票的可能性越高。从这个角度出发,一个过参数的神经网络中,真正对优化和泛化起作用可能只有相当少的一部分参数,而其余的参数只是作为冗余备份存在,即使被剪掉也不会对模型训练产生决定性影响。
彩票假说似乎说明,我们可以安全地剪掉模型当中的冗余参数,而不必担心是否会造成不利影响。还有一些其他文献,本着简单最优的奥卡姆剃刀原则,相信剪枝后的稀疏网络会具有更好的泛化能力 [4]。目前的剪枝文献也都强调自己的算法可以在剪去大量参数的情况下,仍保持与原模型相媲美的准确率。
但是联想到双下降现象,我们不禁反思一个基本问题:剪枝剪去的参数真的是完全冗余的吗?难道过参数更优的双下降在稀疏神经网络上并不成立吗?为了探寻这个问题的答案,我们参考 deep double descent [2] 的设置,在稀疏神经网络上进行了大量实验。
二、稀疏神经网络中的双下降现象
通过实验,我们惊讶地发现,网络中所谓"冗余"的参数其实并不完全冗余。当参数量逐渐减少,稀疏度逐渐上升时,即使模型训练准确率尚未受到影响,其测试准确率可能已经开始明显下降。这时,模型越来越严重地过拟合噪声。
如果进一步的增加模型稀疏度,可以发现当经过某个拐点后,模型的训练准确率开始快速下降,测试准确率开始上升,此时模型对噪声的鲁棒性逐步提高。至于当测试准确率达到最高点后,若继续减少模型的参数,则会影响模型的学习能力。此时,模型的训练与测试准确率同时下降,开始变得难以学习。
此外,我们还发现采用不同的标准来剪枝,得到的模型即使参数量相同,其模型容量/复杂度也不同。例如针对同一类拐点,采用基于权重的剪枝的模型稀疏度更高,而随机剪枝则对应着较低的稀疏度。说明随机剪枝对模型表达能力的破坏更大,想取得相同的效果只能剪更少的参数。
虽然我们的大部分实验都采用了彩票假说的 retrain 方式,但也尝试了其他几种不同的方法。有趣的是,即使是剪枝后微调(Finetuning)也可以观察到明显的双下降。可见稀疏双下降现象并不局限于从初始化训练一个稀疏网络,哪怕沿用剪枝前训练好的参数值也会有相似的结果。
我们还调整了标签噪声的比例,来观察双下降现象的变化。类似于 deep double descent,提高标签噪声的比例,会使得模型训练准确率下降的起始点,向更高模型容量方向移动(即更低的稀疏度)。而另一方面,标签噪声比例越高,为了取得对噪声的鲁棒性,越多的参数需要被剪去以避免过拟合。
三、如何解释稀疏神经网络的泛化性能与双下降现象?
在这里我们主要检验了两种可能的解释。其一是极小点平坦度假说(Minima Flatness Hypothesis)。一些文章指出,剪枝可以为模型增加扰动,这种扰动使得模型更易收敛到平坦的极小点 [5]。由于极小点越平坦,一般会具有更好的泛化能力,因此 [5] 认为剪枝通过影响极小点的平坦度影响着模型的泛化。
那么,极小点平坦度的变化可以解释稀疏双下降吗?我们对 loss 进行如图的可视化,间接比较了不同稀疏度下,模型极小点平坦度的大小。
遗憾的是,随着稀疏度提高,loss 曲线变得越来越陡峭(不平坦)。极小点平坦度与测试准确率之间并没有呈现出相关关系。
另一是学习距离假说(Learning Distance Hypothesis)。已有理论工作证明,深度学习模型的复杂度与参数到初始化的 l2 距离(学习距离)息息相关 [6]。学习距离越小,说明模型停留在离初始化越近的位置,好比早停时获得的模型参数,此时还没有足够的复杂度记忆噪声;反之,则说明模型在参数空间上的改变就越大,此时复杂度更高,容易过拟合。
那么,学习距离的变化可以反应双下降的趋势吗?
如图可见,当准确率下降时,学习距离整体呈上升趋势,且最高点恰好对应准确率的最低点;而当准确率上升时,学习距离也相应下降。学习距离的变化与稀疏双下降的变化趋势基本吻合(尽管当测试准确率第二次下降时,由于可训练的参数过少,学习距离难以再次上升了)。
四、与彩票假说的区别与联系
我们还进行了彩票假说中winning ticket与重新随机初始化的对比实验。有趣的是,在双下降情景下,彩票假说的初始化方式并不总是优于对网络重新初始化的效果。
由图可以看出,Reinit 的结果相比于 Lottery 整体左移,也就是说 Reinit 方式在保留模型的表达能力方面是逊于 Lottery 的。这也从另一方面验证了彩票假说的思想: 即使模型的结构完全相同,从不同的初始化训练时,模型的性能也可能相差甚远。
五、后记
在做这项研究的过程中,我们观察到了一些神奇的、反直觉的实验现象,并尝试进行了分析解释。然而,现有的理论工作还无法完全地解释这些现象存在的原因。比如说在训练准确率接近 100% 时,测试准确率会随着剪枝逐渐下降。为何此时模型没有遗忘数据中的复杂特征,反而对噪声更加严重的过拟合?
我们还观察到模型的学习距离会随着稀疏度增加先上升后下降,为何剪枝会导致模型学习距离发生这样的变化?以及深度学习模型的双下降现象往往需要对输入增加标签噪声才可以观察到 [2],决定双下降是否发生的背后机制是什么?
还有很多问题目前尚无答案。我们现在也在进行一个新的理论工作,以期能对其中的一个或几个问题进行解释。希望可以早日拨开迷雾,探明这一现象背后的本质原因。
参考链接
[1] Belkin, M., Hsu, D., Ma, S., & Mandal, S. (2018). Reconciling modern machine learning and the bias-variance trade-off.stat,1050, 28.
[2] Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., & Sutskever, I. Deep double descent: Where bigger models and more data hurt. ICLR 2020.
[3] Frankle J., & Carbin, M. The lottery ticket hypothesis: Finding sparse, trainable neural networks. ICLR 2019.
[4] Hoefler, T., Alistarh, D., Ben-Nun, T., Dryden, N., & Peste, A. Sparsity in deep learning: Pruning and growth for efficient inference and training in neural networks. arXiv preprint arXiv:2102.00554, 2021.
[5] Bartoldson, B., Morcos, A. S., Barbu, A., and Erlebacher, G. The generalization-stability tradeoff in neural network pruning. NIPS, 2020.
[6] Nagarajan, V. and Kolter, J. Z. Generalization in deep networks: The role of distance from initialization. arXiv preprint arXiv:1901.01672, 2019.
文章来源:https://zhuanlan.zhihu.com/p/542739368作者:Hez
Illustration by Violetta Barsuk from icons8
-The End-