ICML 2023 | Test time adaptation的理论理解与新的方法

2023-06-14 21:59 568 阅读 ID:1151
将门
将门

Domain adaptation(DA: 域自适应),Domain generalization(DG: 域泛化)一直以来都是各大顶会的热门研究方向。DA假设我们有有一个带标签的训练集(源域),这时候我们想让模型在另一个数据集上同样表现很好(目标域),利用目标域的无标签数据,提升模型在域间的适应能力是DA所强调的。以此为基础,DG进一步弱化了假设,我们只有多个源域的数据,根本不知道目标域是什么,这个时候如何提升模型泛化性呢?传统DG方法就是在源域finetune预训练模型,然后部署时不经过任何调整,核心在于如何利用多个源域带来的丰富信息。

然而一些文献表明,在不利用目标域信息的情况下实现很难实现泛化到任意分布这一目标。为了解决这一问题,测试时间自适应(TTA)方法被提出并得到了广泛研究,然而

1. 现有的 TTA 方法在推理阶段需要离线目标数据或更复杂的优化过程,如下图所示,各种 TTA 的方法,要么需要根据测试样本重新训练模型,要么需要更新模型的部分参数,或者需要额外的分支。

2. 绝大多数方法没有一个理论上的验证甚至是直觉。本文介绍我们发表于 ICML 2023 的文章《AdaNPC: Exploring Non-Parametric Classifier for Test-Time Adaptation》,感谢来自北大,Meta,阿里达摩院,普林斯顿的合作者们。

本文提出了非参数化测试时间自适应的方法,不需要任何的梯度更新。在此基础上,我们从理论上验证了该框架的有效性,说明了通过引入测试样本信息,我们能够取得更好的泛化效果。据我们所知,这也是第一篇对 TTA 进行理论分析的工作。

论文链接:https://arxiv.org/abs/2304.12566
代码链接:https://github.com/yfzhang114/AdaNPC  

一、Motivations

在最近的研究中,人们发现在没有在推理期间利用目标样本的情况下,使模型对任何未知分布具有鲁棒性几乎是不可能的。测试时自适应(TTA)方法近期受到了广泛关注,以利用具有计算可行性约束的目标样本。然而,当前的 TTA 方法存在几个缺点。

1. 计算开销:现有的TTA方法需要批处理目标数据进行梯度更新和/或一个额外的模型进行微调,这在目标样本以在线方式一个接一个到达时是不可接受的。

二、Method

为此,我们提出了一种名为 AdaNPC 的非参数适应方法。如下图所示

1. 训练阶段,我们依然可以使用 ERM,CoRAL 等算法进行训练,我们的目标是获得更好的 representation,因此我们的框架和目前绝大多数 DG 方法都是正交的,他们学到的 representation 越好,我们 TTA 的效果也会越好。

2. 测试阶段,我们只需要将模型最后的 Linear 分类器替换成 KNN。在 test time 的时候,我们将所有 training sample 的特征存入一个 memory bank 中,分类器每次 infer K 个最像的样本然后根据他们的标签生成最终的结果。

3. 模型自适应:这一步更加简单,我们只需要将目标域样本的特征和 pseudo label 存入 memory,就完成了整个 TTA 的流程。这一过程不涉及任何梯度反传,模型优化等,因此更为简单和高效。

除此之外,我们还介绍了一些可用的,但不是关键设计的 trick:

2. BN retraining:非参数分类器的性能高度依赖于模型表示,为了获得更强大的表示并保持 AdaNPC 的简单性,我们可以选择在分类器之前添加一个 BN 层。然后在评估时,通过最小化预测熵,只重新训练 BN 层参数。

三、理论分析

这里我们要做的事情有两个:

1. 我们从理论上验证使用 KNN 作为分类器可以显式地减少域散度。

2. 加入目标域样本,即非参数化的 Test-Time 自适应,将进一步减少目标域的期望误差。

本文不会跳入证明细节,只是简单的提供基本的 intuition 和最终的结果,用到的假设涉及到对分布 Density,测度空间,函数平滑性等内容的一些常用假设,基本上可以算作常用假设。

3.1 非参数化(KNN)分类器能够显式减小domain divergence

3.2 AdaNPC 通过引入目标域无标签样本进一步减小目标域期望损失

在本节中,我们开发了基于 covariate-shift 和 posterior-shift 设置下的目标域 excess error 的上届,这进一步阐明了影响我们算法性能的所有因素和使用在线目标数据的好处。首先,所谓的 excess error,实际上就是给定分类器和贝叶斯最优分类器的 error 之差(二分类)。

有了这个定义我们就可以得到如下结论(下面的Proposition 2.),

最后,本文提出的 AdaNPC 是一种特殊的 Test-Time 自适应方法,它可以利用在线目标样本来提高预测泛化。接下来,我们从理论上验证,通过将在线目标样本纳入 KNN 存储库,可以进一步减小 excess error。

我们想要强调的是,这个错误边界比不更新内存库的情况更加 tight,也就是说,通过将新的 test instances 引入 memory,我们得到了更好的 error bound。

四、实验验证

4.1 AdaNPC在域泛化,鲁棒性等benchmark上都取得SOTA

这里的 AdaNPC 即使用了 BN retraining 的策略。

除此之外,我们发现,当 batch size 非常小时,现有方法(基于梯度更新或者其他参数更新方法)往往会产生负面影响,因为单个样本的梯度噪声非常高,这不利于模型优化。然而,应该强调的是,批数据(batch of online data)不符合在线学习的设置,在线学习需要按需推理而不是等待一批一批的数据传入,或者当推理发生在边缘设备(如手机)上时,没有机会进行批处理。因此,AdaNPC 这种对批量大小不敏感的 TTA 方法对当前的研究领域具有重要价值。

4.2 AdaNPC克服了灾难性遗忘,有很强的知识可扩展性

以 Rotated MNIST 数据集为例,下图显示,最新的 TTA 方法,即 T3A 和 Tent,在进行测试时间适应的情况下,其性能略高于甚至低于 ERM 基线。相反,AdaNPC 记住了所有 adaptation 过程中的信息,因此取得了惊人的效果。也正是因为这个原因,AdaNPC 在源域的效果不会随着自适应的进行而变差,这是相对于现有算法的另一个突出优势。

4.3 AdaNPC:不需要任何源域训练的域泛化

下图显示了直接在目标域上评估预训练模型的结果,而没有在源域进行任何微调,对于 AdaNPC,只是将源域特征进行了存储。在 PACS 数据集上(下图a)和 Rotated MNIST 数据集上(下图b)使用 MLP 分类器的平均泛化性能低于25%,即使使用强大的主干(ViT-L16)。相反,使用 KNN 分类器可以达到平均泛化精度71.4%。如今,由于预训练模型的规模不断增长,微调通常在计算上是昂贵的。AdaNPC 要求的不是基于梯度的更新,而是外部大容量存储来存储用于图像分类的知识,例如图像特征图,这为利用预训练的知识提供了一个新的有前途的方向。此外,随着源域实例数量的增加,下图(c)表明 AdaNPC 获得了更好的性能,这验证了我们的理论结果。

4.4 AdaNPC有较强的可解释性,允许引入专家信息

下图显示了 KNN 分类器如何使用源域的知识。决策过程将不再是一个黑匣子。例如,下图(b)中的长颈鹿被分类为低置信度,因为它最近的邻居是大多数具有相似姿势的人或狗。也就是说,encoder representation 忽略了一些重要特征,例如面部形状。然而,这些特征很容易被人类识别;因此,当我们得到低置信度预测时,AdaNPC 允许我们手动删除一些明显错误的邻居。在这种情况下,我们的分类结果将更加准确和自信,这对于高风险任务来说很有希望结合专家知识以获得更好的分类结果。

五、总结与未来工作

该论文提出了一种新的域泛化测试时间自适应方法,AdaNPC,它引入了一个非参数分类器,即 KNN 分类器,用于预测和自适应。与需要模型更新且容易忘记先前知识的当前领域泛化或测试时间自适应方法不同,所提出的方法是无参数的并且可以记住所有知识,使得AdaNPC适用于实际设置,特别是模型需要适应一系列域的时候。

我们推导出协变量偏移和后验概率偏移设置下的误差界限,其中 AdaNPC 理论上显示能够减少看不见域的目标误差。此外,AdaNPC 具有更快的收敛性、更好的可解释性和强大的知识可扩展性。更重要的是,AdaNPC 无需对源域进行任何微调即可实现高泛化精度,这为利用规模不断增长的预训练模型提供了一个有前途的方向。

一个可能的 concern 是 AdaNPC 需要进行 dense vector searching,以及需要存储大量的源域特征,对显存/内存有较高要求。我们在文中实验部分对这些 concern 进行了解答,目前的搜索速度和内存要求和现有算法至少是可以相比较或者更快的。但是我们在未来工作中也会考虑如何更加简化 memory 的构造,加速整个框架的推理时间。

作者:张一帆

公众号:PaperWeekly

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn