ICLR 2024 Oral | 三行代码,即插即用!NUS尤洋团队新作—InfoBatch,无损数据集动态剪枝加速

2024-01-23 14:02 206 阅读 ID:1824
将门
将门

本文介绍来自NUS尤洋团队的最新科研成果 - InfoBatch。这是首篇关于无损数据剪枝加速的工作,覆盖分类、分割、检测、图片生成、LLM指令微调等任务。作为一个即插即用且与架构无关的框架,在实际应用中,InfoBatch 可以无损地节省 40% 的总开销(时间和计算)。

随着深度学习的网络参数量和数据集规模增长,算力需求日益增加,如何节省训练成本正在成为逐渐凸显的需求。现有的数据集压缩方法大多开销较高,且难以在达到无损的情况下获得可观的节省率;加权抽样的相关方法则对于模型和数据集的特点较为敏感且依赖于重复抽样假设,在实际应用中难以和已完成调参的学习率调整策略结合。两种从数据角度出发的方法在实践中很难真正帮助节省计算。

在本篇工作中,研究者从数据迭代这个角度切入进行了研究。长久以来,数据集的迭代方式大都采用随机迭代。对此,作者提出了InfoBatch框架,根据网络对样本的拟合情况进行动态剪枝采样的方法,并利用重缩放(rescaling)来维持剪枝后的梯度更新(Gradient Update)期望,以此在性能无损的情况下提高训练效率,加快训练速度。

                                                                 InfoBatch在不同任务上的表现
                                                                      仅需三行代码即可使用

在CIFAR10/100(ResNet,分类)、ImageNet-1K(ResNet/ViT,分类)和ADE20K(UperNet,语义分割)上,InfoBatch无损节省了40%的总开销(时间和计算);在检测任务上(YOLOv8),InfoBatch无损节省了30%;

对于MAE预训练(ViT)和FFHQ图片生成(Latent Diffusion), InfoBatch分别节省了24.8%和27%的开销。

LLaMA的指令微调上, InfoBatch成功在静态压缩方法DQ[1]的基础上额外节省了20%开销,总开销为原始的1.6%,并且和LoRA兼容。

论文题目: 

InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning 

论文链接:

https://arxiv.org/abs/2303.04947 

代码链接: 

https://github.com/henryqin1997/InfoBatch

一、动机

在过去的十年里,深度学习取得了长足的进步。与之相应的是大部分最先进的深度学习工作大都使用了超大规模的数据集,这对于很多资源有限的研究者来说是难以负担的。为了降低训练开销,研究者们进行了一系列不同研究。

一个比较直接的方法是降低数据集规模。数据集蒸馏(Dataset Distillation)[2]和核心集合选择(Coreset Selection)[1]分别从原有的数据集中合成/选择一个更小但更有信息量的新数据集(子集)。然而,虽然样本数量减少了,这两种方法本身却引入了不可忽略的额外开销。此外,这两种方法达到无损性能比较困难。另外的工作有加权抽样(weighted sampling)[3],可以通过改变样本采样率来提高训练收敛速度,相应的缺点是加速比对模型和数据集敏感,难以直接和学习率调整策略结合。

近期,一些工作试图通过减少迭代来加速训练。其中一类方法和核心集合选择类似,通过给样本打分并排序来选取更有信息量的样本,其余样本不参加训练,作者称之为数据静态剪枝;另一类方法在此基础上,于训练过程中动态打分并周期性选取子集,作者称之为数据动态剪枝。相比于静态方法,动态方法的单次额外开销更小,而且同计算量性能更好,但是现有方法依旧难以达到无损性能。

为了应对以上方法的缺点,作者提出了InfoBatch训练框架。InfoBatch的主要改进如图1所示,它在数据迭代过程中动态剪枝,通过Soft Pruning(概率剪枝)和Gradient Rescaling(梯度重缩放)维护了总更新量的期望值不变,以此达到了无损加速的目的。为了防止剩余训练轮次不足时的残余偏差,InfoBatch在最后的少部分轮次中使用原始数据集随机采样训练。作者在分类,语义分割,目标检测,Diffusion图片生成,LLaMA指令微调等任务上验证了方法的无损加速。

二、方法

2.1 总览

现有的静态/动态数据剪枝方法,会通过某种方式给样本打分,然后对样本得分排序,选取“对训练更有帮助”的样本进行训练。这种选择通常是确定性的,和目标的剪枝百分比直接挂钩。与之相对应的问题是,直接剪枝导致了梯度期望值方向偏差以及总更新量的减少。

为了解决梯度更新的期望偏差,如图2所示,InfoBatch前向传播中维护了每个样本的分值,并以均值为阈值,对一定比例的低分样本进行了动态剪枝。为了维护梯度更新期望,剩余的低分样本的梯度被相应放大。通过这种方式,InfoBatch训练结果和原始数据训练结果的性能差距相比于之前方法得到了改善。为了进一步减少残余的梯度期望值偏差,InfoBatch在最后几个轮次中使用全数据训练。

2.2 无偏剪枝和重缩放(Unbiased Prune and Rescale)

2.3 退火(Annealing)

虽然理论上的期望更新基本一致,上述的期望值实际包含时刻 t 的多次取值。在训练中,如果一个样本在中间的某个轮次被剪枝,后续依旧大概率被训练到;而在剩余更新轮次不足时,这个概率会大幅下降,导致残余的梯度期望偏差。因此,在最后的几个训练轮次中(通常是12.5%~17.5%左右),InfoBatch采用完整的原始数据进行训练。

三、实验

3.1 实验设置

作者在多个数据集上验证了InfoBatch的有效性,包括(分类)CIFAR-10/100,ImageNet-1K,(分割)ADE20K,(图片生成)FFHQ,(指令微调)Alpaca。训练的模型包括(分类)ResNet18,ResNet-50,ViT-Base(MAE), Swin-Tiny,(分割)UperNet,(图片生成)Latent Diffusion, (指令微调)LLaMA-7B。

3.2 实验结果

这里展示主要结果,更多结果请参考论文。

另外,根据作者最新更新,InfoBatch在检测任务上也取得了无损加速30%的效果,代码将会在github更新。

四、总结与展望

在这项工作中,作者提出了InfoBatch框架,能够在广泛的任务上可观地节省训练开销并加速。其核心的思想是根据样本拟合情况动态调整采样剪枝策略,并利用重缩放维持更新量的一致。作者在文中进一步探讨了该策略的适用范围和进一步的优化,期待此类工作以后能取代传统数据迭代方式,助力大模型时代训练加速。

参考

[1]Zhou, Daquan, et al. "Dataset quantization." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023.

[2]Wang, Kai, et al. "Cafe: Learning to condense dataset by aligning features." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

[3]Csiba, Dominik, and Peter Richtárik. "Importance sampling for minibatches." The Journal of Machine Learning Research 19.1 (2018): 962-982.

Illustration From IconScout By Delesign Graphics

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn