NeurIPS 2022|知识蒸馏想要获得更好的性能?那就来一个更强的教师模型吧!

2022-12-20 21:07 672 阅读 ID:616
机器学习算法与自然语言处理
机器学习算法与自然语言处理

论文名称: Knowledge Distillation from A Stronger Teacher

在知识蒸馏中发挥更大教师模型和更强训练策略的作用(NeurIPS 2022)

论文地址:https://arxiv.org/pdf/2205.10536.pdf

1.『知识蒸馏背景和本文动机』

在深度神经网络性能提升的过程中,模型通常会变得更深更宽。然而,由于计算和内存资源的限制,这种沉重的模型在实际应用中部署起来比较笨拙。

知识蒸馏是指:通过在训练过程中蒸馏更大的模型 (教师) 的知识来提高小模型 (学生) 的性能。

知识蒸馏的本质在于如何将知识从教师模型提炼到学生模型里面。最直观有效的方法是通过 Kullback-Leibler (KL) 散度[1]来匹配教师和学生之间的预测分数。一般来讲,KL 散度使得在训练过程中,可以用更有信息量的监督信息指导学生模型的训练,以期望获得更好的性能。

题目中 "更强的教师模型",有两个含义:尺寸更大,数据增强策略更先进

获得更好的知识蒸馏性能的方式之一是尝试不同类型的教师模型 (比如说使用更大的教师模型或者更强的训练策略),作者在本文中认为:应该借助 "更强的教师模型" 进行知识蒸馏。而针对什么是 "更强的教师模型",作者推广实验给出了一些建议:

  • 除了扩大模型规模,还可以通过先进的训练策略,如标签平滑和数据增强 (label smoothing and data augmentation),以获得更强的教师模型。但是仅仅有这些是不够的。配备了更强的教师模型之后,学生模型在正常 KD 下的表现可能会下降,甚至性能还不如不用 KD。

为什么是这样呢?作者觉得:

  • 当将教师和学生的训练策略转换为更强的训练策略时,教师和学生之间的差异往往会变得相当大。在这种情况下,通过 KL 散度来精确恢复预测可能具有挑战性,并导致 KD 的失败。

所以,作者在本文的动机是:

  • 保留教师和学生模型之间的预测关系非常重要。在将知识从 teacher 传给 student 时,我们其实真正关心的是教师模型的偏好 (预测的相对 Rank),而不是去恢复其预测结果的绝对值。教师预测与学生预测之间的相关性有利于放松 KL 散度的精确匹配,提取内在关系 (intrinsic relations)。

2.『知识蒸馏中的匹配问题』

前面提到,当将教师和学生的训练策略转换为更强的训练策略时,教师和学生之间的差异往往会变得相当大。在这种情况下,通过 KL 散度来精确恢复预测可能具有挑战性,并导致 KD 的失败。

如下图1所示,分别是τ为1和τ为4时,使用训练策略 B1 和 B2 直接训练 ResNet18 和 ResNet50 的结果,结果标记为 (R18B1, R18B2, R50B1, R50B2),训练策略的详细说明如图2所示。

                    图1:ImageNet 验证集上的不同训练策略下的 ResNet 预测结果的差异 (KL 散度)
                                                                  图2:详细的训练策略

根据图1结果有如下观察:

  1. 与 ResNet-50 相比,ResNet-18 的输出在更强的训练策略下变化不大。这意味着学生模型的表征能力限制了其性能。同样的道理,当学生模型和教师模型的差异变得足够大时,使学生模型要完全匹配教师模型的输出往往具有挑战性。
  2. 当采用较强的策略训练师生模型时,师生之间的差异会更大。这说明当我们在更强的训练策略下采用 KD 时,KD 损失和常规分类损失的不一致会更严重,从而干扰学生的训练。

所以总的结论就是:当训练策略变得更强时,教师和学生差异拉大,传统的 KL 散度的精确匹配的模式 (即当且仅当教师和学生的输出完全相同时,损失达到最小值) 就显得过于苛刻。所以作者在本文的直觉是搞一种轻松的方式来匹配老师和学生之间的预测。希望在知识蒸馏的过程中,student 不用费劲地去匹配 teacher 的精确的输出结果,而是去匹配真正有用的东西。

3『DIST 中的类间匹配』

预测分数代表了一个模型对于所有类别的置信度,那么如果希望教师模型和学生模型的输出以一种更加轻松的方式匹配,作者认为只需要把各个类别输出置信度的关系匹配好就可以了,也就是去匹配教师模型预测结果的相对顺序。具体如下:

因此,一个简单而有效的映射方式是正线性变换,它就满足松弛匹配的要求:

4.『DIST 中的类内匹配』

除了类间的关系,作者还考虑到了单个样本中类内的关系。作者认为每个类的,多个实例的,预测分数也是有用的。

举个例子, 我有三张图片, 里面的内容分别是 "猫", "狗", 和 "飞机", 它们对于 "猫" 这个类别的预测结果分别是e,f,g。那么, 猫图的对应 "猫" 类别的预测结果应该是最大的, 飞机那个图对应 "猫" 类别的预测结果应该是最小的。所以,e>f>g的关系也应当由教师模型传给学生。

换个例子, 我有两张图片, 里面的内容分别是 "猫1", "猫2", 它们对于 "猫" 这个类别的预测结果分别是e1,e2。如果猫1图的对应 "猫" 类别的预测结果较大, 猫1图的对应 "猫" 类别的预测结果较小, 那么e1>e2关系也应当由教师模型传给学生。

受到以上思想的启发,作者也觉得应该进行类内匹配:

因此,如下图3所示,本文最终的蒸馏策略是:

式中,α,β ,γ是超参数。作者通过类间匹配和类内匹配的方式,赋予了学生或多或少的自适应匹配教师网络输出的自由,从而在很大程度上提高了蒸馏性能。

                                                               图3:DIST 方法的蒸馏策略

PyTorch 伪代码:

import torch.nn as nn

def cosine_similarity(a, b, eps=1e-8):
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)

def pearson_correlation(a, b, eps=1e-8):
    return cosine_similarity(a - a.mean(1).unsqueeze(1), b - b.mean(1).unsqueeze(1), eps)

def inter_class_relation(y_s, y_t):
    return 1 - pearson_correlation(y_s, y_t).mean()

def intra_class_relation(y_s, y_t):
    return inter_class_relation(y_s.transpose(0, 1), y_t.transpose(0, 1))

class DIST(nn.Module):
    def __init__(self, beta, gamma):
        super(DIST, self).__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(self, z_s, z_t):
        y_s = z_s.softmax(dim=1)
        y_t = z_t.softmax(dim=1)
        inter_loss = inter_class_relation(y_s, y_t)
        intra_loss = intra_class_relation(y_s, y_t)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss

5.『实验结果』

ImageNet-1K 实验结果

根据前文的描述,**"更强的教师模型",有两个含义:尺寸更大,数据增强策略更先进**。因此作者先从第一个维度进行实验,换更强的教师模型,结果如下图4所示。当教师规模较大时,student 模型 ResNet-18 的表现甚至比中等规模的 teacher ResNet-50 更差。但是,DIST 在大 teacher 模型的情况下呈现上升趋势,并且与 KD 相比改进也变得更加显著,这表明 DIST 更好地处理了学生和大教师之间的巨大差异。

                                                图4:更强的教师模型 ImageNet 实验结果

最近,通过复杂的训练策略和强大的数据增强,ImageNet 上的模型性能得到了显著提高 (TIMM 在 ResNet-50 上达到了 80.4% 的精度,而基线策略 B1 仅获得 76.1% 的精度)。但是,目前大多数 KD 方法仍然在简单的训练设置下进行实验,作者进行了高级训练策略 KD 的实验,并将我们的方法与原始 KD 方法进行了比较。

作者首先用强策略训练 ResNet-50,得到的准确率为 80.1%,如下图5所示。当学生模型使用 ResNet-18, ResNet-34,或者 MobileNetV2, EfficientNet-B0 时,DIST 都能达到最佳性能。DIST 在 Swin Transformer 上也能获得很好的效果。

                                                 图5:更强的训练策略 ImageNet 实验结果

目标检测实验结果

作者使用 MS COCO 数据集,同时将 DIST 作为一种类别的额外的监督信号,ResNeXt-101 主干网络 + Cascade Mask R-CNN 检测头作为教师模型,ResNet-50 主干网络 + Faster R-CNN 检测头作为学生模型;ResNeXt-101 主干网络 + RetinaNet 检测头作为教师模型,ResNet-50 主干网络 + RetinaNet 检测头作为学生模型,结果如下图6所示。DIST 在 COCO 验证集上取得了很有竞争力的结果。为了进行比较,作者在与 DIST 相同的设置下训练原始 KD,通过简单地替换损失函数,显著优于原始 KD 方法。

                                                            图6:COCO 目标检测实验结果

语义分割实验结果

作者还对语义分割这一具有挑战性的密集预测任务进行了实验,在 Cityscape 数据集上用 ResNet-18 骨干模型训练 DeepLabV3 和 PSPNet,对分类头的预测采用 DIST 蒸馏策略,教师模型使用 DeepLabV3 的 ResNet101 骨干模型。结果如下图如表7所示,仅在监督类预测的情况下,DIST 在语义分割任务上明显优于现有的知识蒸馏方法。DIST 比最近最先进的方法 CIRKD 在 PSPNet-R18 上的性能高出 1.58%,证明了 DIST 在关系建模方面的有效性。

                                                        图7:Cityscapes 语义分割实验结果

6.『总结』

本文来源于一个观察,即:当教师模型的体积增大时进行知识蒸馏过程,学生模型的性能提升并不明显;当使用更强的训练策略 (数据增强) 进行知识蒸馏过程,学生模型的性能提升也不明显。这启发作者的思考,可能是知识整理过程中经常使用的 KL 散度,这种精确匹配的模式 (即当且仅当教师和学生的输出完全相同时,损失达到最小值) 就显得过于苛刻。所以作者在本文的直觉是搞一种轻松的方式来匹配老师和学生之间的预测。所以本文提出 DIST,一种包含了类内关系和类间关系的蒸馏方法,在图像识别,目标检测和语义分割任务中均取得了不错的表现。

参考

  1. ^Distilling the Knowledge in a Neural Network
  2. ^K. Pearson. Vii. mathematical contributions to the theory of evolution.—iii. regression, heredity, and panmixia
免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn