ICLR 2023 | RevCol:可逆的多 column 网络,大模型架构设计新范式

2023-04-02 21:23 532 阅读 ID:931
将门
将门

该工作将解耦学习的思想引入模型设计中,提出以reversible column为单元来传递信息,既保证特征解耦,同时信息在网络中的传递不受到损失。实验结果,在ImageNet-1K上达到了90%的Top-1 Accuracy

论文标题:

Reversible Column Networks

论文链接:

https://arxiv.org/pdf/2212.11696.pdf

代码链接:

https://github.com/megvii-research/RevCol

我们给神经网络架构增加了一个维度

自ViT时代到来之后,由一叠blocks堆起来构成的基础模型已经成为了广泛遵循的基础模型设计范式,一个神经网络的宏观架构由width宽度(channel数)和depth深度(block数)来决定。有没有想过,一个神经网络未必是一叠blocks组成的?可能是2叠,4叠,或者…16叠?

太长不看系列

介绍一下我们最新的工作“Reversible Column Networks”,将解耦学习(disentangled feature learning)的思想引入模型设计中,提出以reversible column为单元来传递信息,既保证特征解耦,同时信息在网络中的传递不受到损失。整个网络结构包括了多个子网络(我们称为column),column间加入可逆的连接,通过将输入反复接入column,逐渐分离low-level的纹理细节和semantic语义信息。这样做的好处在于,既能够保证在预训练中保持高精度,又保证了low-level的信息不丢失以在下游任务(detection,segmentation)中能够达到更好效果。

为了验证这套设计模式在大模型大数据下的表现,我们在RevCol上做了一个2B参数的纯CNN超大模型,且只使用了3x3的卷积核。在ImageNet-1K上达到了90%的Top-1 Accuracy,下游的检测和分割任务上双双达到60+的水平,COCO AP box 63.8%,ADE 20k mIoU 61.0%。此外,RevCol架构依然遵循了可逆神经网络的设计范式,也就继承了可逆网络天然的节省显存的优势,文中的大部分实验均可在2080ti上完成。而节省显存这件事,对于大模型训练无疑是重要的。

一、背景

早在CNN时代,我们就发现单纯把像ResNet用在下游任务上难以发挥优势,但使用HRNet,FPN这样的多尺度融合的方式后,就能够取得较好的效果。然而单纯把HRNet和FPN放在上游分类任务中,效果又不如ResNet这样的直筒网络。那么多尺度融合的网络和直筒网络到底是哪里做对了,又哪里不足呢?
要回答这个问题,我们可以借助Tishby提出的Information Bottleneck principle来审视这些结构。Information Bottleneck讲,网络向前传播的时候,会逐步丢弃(压缩)和任务无关的信息,而保留对任务有帮助的信息。对于分类任务上的网络来说,靠近输入的浅层feature蕴含了大量和分类无关的low-level信息,而靠近输出的深层主要是semantic语义信息。单从分类上看,这样似乎是合理的,但如果把在分类任务上预训练得到的网络再用于下游,预训练阶段的信息损失就会影响到下游任务上的效果。

所以,最好的backbone应该是具备把task-relevant的语义信息decoupled 出来放到一些dimensions中,且在网络中保留尽可能多的输入信息的。这话Bengio在2012年Disentangling factors of variation的文章中也提到了,原句是

we require a means of feature extraction that disentangles these factors in the data rather than simply learn to represent some of these factors at the expense of those that are lost in the filter pooling operation.

我们的RevCol通过在结构上精妙的设计,在e2e的训练pipline下实现了disentangle的目标。如Figure 1 所示,一般的直筒网络(single column)的信息传递方式是,越靠近input的部分的信息越偏向low-level,越靠近loss的位置越semantic。而RevCol采用了multi-input的设计,每个column的起始位置都是low-level的信息。

随着column的iteration,在column的最末端,feature中的语义信息就逐渐被提取了出来。在column之间采用了Reversible的连接设计,也就是说后面column可以倒推出前面column的信息。这样从最后一个column可以一路倒推回第一个column,以此保证信息在column间传递的时候是无损的。同时在column的最末端加入中间监督,显式地约束每个column的输出信息的表征,以此来保证语义信息能随着column iteration被decouple出来。

二、方法

RevCol结构中包含了很多subnet,我们称作为column。依照column iteration的形式把multi-level reversible 单元排列起来就构成了网络的宏观架构。

2.1 Multi-Level Reversible Unit

2.2 基本结构

【宏观结构】

【微观结构】

Building Block我们用的ConvNeXt,不过它不太稳定,所以我们改了点东西。(见上图)如果用过ConvNeXt-L/XL,就会发现这个模型不太稳定,我曾在他家官方github里提了个issue,说它variance很大。这个问题在更大尺寸更多数据训练下的RevCol中会愈发剧烈。表现为fp16训练模型很快就炸了。如果大家有兴趣,稳定性问题我可以单开帖子讲。所以,我们把downsample的位置改为了postnorm,并且Reversible连接上使用了scaling,并且通过weight decay等方式来限制feature的Magnitude。我们还把原本的ConvNeXt block中7x7的kernel size改成了3x3,大kernel的收益在revcol上有限,并且大kernel会加剧不稳定现象(这个问题早在replknet中就观察到了),另外,小kernel是真的快!

2.3 中间监督

我们还做了个plug-in的中间监督方法,在不改变e2e训练的pipline下,上下游能额外带来1个点以上的收益。我们发现,由于Reversible unit的后加和设计,在每个column的最后,都用Reversible operation的方式一路加到了最后一个column然后接loss,这样就会让前面column的底部离loss很近。离loss近,就意味着这个位置的feature中包含的主要是语义信息。那么第一个column本身是会丢信息的,第一个column如果包含了太多语义信息丢失了太多纹理信息,那后面再Reversible,收益也很小了。它不能坍缩掉。

所以我们在column底部加了一个分类head,和一个feature重建的head,分别接CE、BCE Loss,然后随着column加深,逐步调节这俩loss的比值,最后一个column重建的loss为0,分类的loss占比为100%。

2.4 模型设计

在RevCol这套架构下,一个模型有三个维度了:channel数(width),单个column的block数(depth),还有column数。我们在设计模型的时候,发现增加column数的收益几乎等同于同时增加width和depth,所以做了个简单粗暴的scale up rule:一个small(8 column)就是俩tiny(2x4 column),一个base(16 column)就是四个tiny(4* 4column)。只不过这里在base上为了和竞品对齐计算量稍微做了点调整。

哦对了,RevNet还有个人见人爱的特性,也就是其他Reversible papers里纷纷宣传的节省显存。RevCol-T/S/B用的单column计算量几乎是一样的,增加column后显存中只增加了param的存储,所以这仨尺寸的模型基本上占用的显存是一样的。他们都可以在RTX 2080ti (11G)中训练。在Huge(2B参数)上,我们也开启了Reversible的计算方式,以此提高训练效率。如果无法节省显存,Huge的训练代价恐怕要增加很多倍。

我们在开源的code里放了Reversible Autograd Function,简单来说就是重写每个Column的前传反传逻辑,通过自定义在前传中存储的中间结果和反传中的倒推梯度逻辑来节省GPU Memory中存储的内容。

详见:

https://github.com/megvii-research/RevCol/blob/main/models/revcol_function.py

三、 实验结果

【ImageNet Classification】

这年头把ImageNet刷到89其实不算难,但你体会过把它刷到90有多难吗?我体会过。其实CNN是一个很稳定的东西,私有数据集的弱标签精度和由于模型变宽变身带来的不稳定性都只有在XL级别以上的模型尺寸中才会出现。Transformer不一样,它更加容易崩。可想而知GPT3这种模型的训练有多么困难。

除了2B参数的模型以外,我们还收集了168Million的私有数据集(Megdata-168M),weakly-label的标签。用来预训练。XL模型(800M param),在22k下能达到88.2,经过Megdata-168M,的训练后能够涨到89.4。Huge(2.1 B param)224 pretrain,640x640 Finetune,能够达到90.0% Top-1 Accuracy。这个模型的训练开销:预训练总共1600个ImageNet Epochs,训练一次使用80块A100需要14天。

【COCO Object Detection】

【ADE Semantic Segmentation】

在COCO 上,使用DINO的框架,经过Object 365进一步Finetune之后,RevCol-H能够达到63.8的 AP box。在ADE 20k,使用Mask2Former框架,mIoU 能够达到61%。

【Foundation Models】

我们列举了各家的Foundation Models并且做了个对比。RevCol-H作为一个单模态模型(Megdata-168M数据集只包含图片没有language)且标签用的是semi-labeled方式,我们没有使用Mask Image Modeling预训练,我们还是个CNN。最终的上下游任务都能够达到和其他单模态多模态大模型comparable的结果,比如多模态模型BEiT3,多模态模型Florence,单模态超大模型外加MIM预训练 setting下的Swinv2-G。

四、 结论和展望

我们做RevCol的第一个版本的目标,是在纯视觉任务下验证了它 scale up的 capability。但CV大模型不应只是局限于分类检测等等这些任务下,尤其是ChatGPT现象级爆火之后,CV的大模型未来在哪里更加值得深思。我们能看到的潜在未来,包括视频理解,多模态模型,生成模型,自动驾驶等等。我们坚信RevCol这样一套宏观架构是能够普遍使用的,CV的未来在哪里,RevCol就在哪里,顶峰相见吧各位。

【附, 一些个人的想法和Rebuttal中提到的问题】

1. 点还能再高点吗?

能,不过这个问题要从两方面来看。

1. 我做Reversible的初衷并不是提供一个、跑的贼快点贼高即插即用帮助大家打比赛的模型的,所以我没把天赋点点在刷点上面。我设计tiny small base的时候多少有点秀的成分,就用一个column不停地叠4次8次16次,显然这是sub-optimial的。真想快,全部都搞成4column,我试过,上下游基本都不掉点。

2. Reversible的优势其实不是上游涨点,想法,通过Reversible保留信息的时候保留了一大堆对分类没用的东西,能不掉点就很不错了,下游才是Reversible的关键。尤其是大模型。

2. 来自GLOM的启发

其实最初这篇work的idea来自Hinton的一篇文章"How to represent part-whole hierarchies in a neural network.",但我们着实没按照他的想法做(因为不work)。

Hinton这篇文章里主要说:Features of a neural network should embed the visual input from object level to scene level gradually. In other words, representation lies in a disentangled manner, from the part to the whole. GLOM use a time sequence based method to gradually settle down the representation in feature dimensions, which is also referred as formulating islands.

我们能想到multi-column的架构就是借鉴了Hinton的GLOM,通过将输入反复接入网络,将静态图片统一到视频处理的框架中,方便FFA等local training策略的引入。但是,我们没有通过contrastive learning的方式来逐column逐level监督,使得feature在不同column间呈现disentangle的特点,而是通过Reversible的设计,使得feature在最后一个column的时候不同level间解耦。后来Hinton提出的Forward-Forward Algorithm延续了GLOM的设计,但GLOM这套架构离真正使用依然还有很长的路要走。我们算是做了个更加贴近现实的版本。

3. 会有V2吗?

会。我们欢迎各种新的想法,也欢迎各路有志之士。

在这里由衷感谢我的队友们Yizhuang Zhou, Qi Han(知乎:@Qi Han), Jianjian Sun, Xiangwen Kong, Jun Li, 还有老板:Xiangyu Zhang(知乎:@Xiangyu Zhang)

感谢已故恩师孙剑老师。

作者:蔡宇轩

文章来源:知乎@蔡宇轩 https://zhuanlan.zhihu.com/p/606594311

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn