​ICLR 2023 | 分布外泛化(OOD)中的优化困境

2023-07-03 13:23 855 阅读 ID:1207
将门
将门

随着深度学习模型的应用和推广,人们逐渐发现模型常常会利用数据中存在的虚假关联(Spurious correlation)来获得较高的训练表现。但由于这类关联在测试数据上往往并不成立,因此这类模型的测试表现往往不尽如人意 [1]。其本质是由于传统的机器学习目标经验损失风险最小化(Empirical Risk Minimization,ERM)假设了训练测试集的独立同分布特性,而在现实中该独立同分布假设成立的场景往往有限。

在很多现实场景中,训练数据的分布与测试数据分布通常表现出不一致性,即分布偏移(Distribution shifts),旨在提升模型在该类场景下性能的问题通常被称为分布外泛化(Out-of-Distribution Generalization OOD;或域泛化,Domain Generaliziation)问题。近年来,尽管围绕分布外泛化的研究取得了一定的进展和突破,但大量的证据表明,在严格的实验设置和真实世界的分别偏移下,现有的分布外泛化算法却常常难以超越传统的经验损失风险最小化 [2,3,4]。

我们发现,这一现象与分布外泛化中的优化困境(Optimization Dilemma)不可谓毫无关系。现有工作往往关注于提出更好的优化目标来约束 ERM 的学习,却忽视了其新的优化目标对优化过程带来的影响,以及随之而来的在模型选择(Model Selction)上的挑战。

为此,我们从多目标优化(Multi-Objective Optimization,MOO)的角度重新回顾了分布外泛化的优化过程,揭示了其优化困境的成因,并为此提供了新的优化方案。论文已在 ICLR 2023 发表,并将在 ICLR 域泛化研讨会上作口头报告(oral presentation at ICLR DG Workshop)。本工作由香港中文大学、腾讯 AI Lab 以及香港浸会大学合作完成。

论文链接:https://openreview.net/forum?id=esFxSb_0pSL
项目代码:https://github.com/LFhase/PAIR  

一、分布外泛化中的优化困境

分布外泛化(或域泛化)旨在让模型从一个或多个来自具有不同分布的环境(Environment;或域,Domain)的数据集中学习到具有稳定预测能力的关系,以使得模型能够泛化到来自不同于训练分布之外的分布中。由于传统的 ERM 只关注学习数据中存在的相关性,并不区分其是否为虚假关联,因此 ERM 学到的模型往往难以做到良好的分布外泛化。

为了解决这一问题,现有工作通过设计不同的优化目标(OOD目标)对 ERM 的训练过程进行约束,使其能够学习到环境之间的不变性。通常,其优化形式可以表示为:

1.1 优化目标的放缩

                                    ▲ 图1. IRM需要经过两次放缩到IRMv1才能在实践中使用.
                                                        ▲ 图2. IRMv1与IRM存在巨大的差异.

1.2 优化过程的矛盾

优化困境的另一方面则是公式(1)中,OOD 目标与 ERM 目标协同优化过程中的矛盾。直观上来说,ERM 关注学习数据中存在的相关性,而 OOD 优化目标则关注学习数据中存在的因果关联,这两个优化目标通常存在矛盾。

                                                  ▲ 图3. OOD与ERM优化目标普遍存在矛盾.

我们测量了几种主流 ERM 与 OOD 优化目标产生的梯度的余弦相似度,如上图 3 中展示了优化过程初期 ERM 与 OOD 优化目标梯度余弦相似度的平均值,横轴表示使用的不同的λ 值。可以看出,ERM 与 OOD 优化目标普遍存在矛盾。优化目标的矛盾则会进一步导致一些潜在的具有良好 OOD 泛化能力的解无法通过公式(1)的线性加权方式到达 [8]。

                                                       ▲ 图4. OOD权重需要精细的调参.

此外,由于 ERM 与 OOD 优化目标矛盾的存在,公式(1)中权重λ往往需要精细的调参才能获得一个不错的解。如上图4展示了在经典的 ColoredMNIST 数据集中 [5],主流的几类 OOD 优化目标对超参数设置较为敏感,当 λ 过大时,OOD 优化目标会阻止学习所有的关联,而当 λ 过小时,哪怕以潜在完美的 OOD 解作为初始化,OOD 优化目标也难以维持完美解。

1.3 模型选择的困境

因为 ERM 与 OOD 优化目标的矛盾,对 OOD 泛化中的模型选择也带来了额外的挑战 [2]。现有的解决方案通过指定验证集并根据验证集的模型表现来进行模型选择。其中,验证集既可以来自于一个和训练环境较为接近的分布,也可以来自于一个训练和测试环境之外的分布,亦可以来自于和测试环境较为接近的分布。

不同验证集的设定会对最终模型的表现带来很大的影响。往往,训练得到的模型会对接近训练分布的数据表现较好而对离训练分布较远的其他分布表现较差,如根据一个接近训练分布的验证集进行选择,往往会使得模型在分布外的测试集的表现较差,而如果根据一个远离训练分布的验证集进行模型选择,则会使得最终得到的模型在训练集上表现较差。

二、帕累托不变风险最小化

总的来说,分布外泛化中的优化困境一方面在于因优化目标过于复杂导致的放缩,另一方面在于 ERM 与 OOD 优化目标的矛盾带来优化过程难度的上升。特别地,上述的优化困境给分布外泛化提出了一个十分具有挑战性的问题:

如何处理 ERM 与 OOD 目标的优化困境并得到需要的 OOD 解?

2.1 帕累托视角下的优化困境

既然公式(1)中通过线性加权得到的单目标优化范式面临种种困境,那么很自然地,我们可以将公式(1)转换到多目标的视角:

                                                               ▲ 图5. IRMv1的帕累托前沿.

综上,我们可知,IRMv1 的失败主要是因为其多次放缩导致所需要的解离开了其帕累托前沿,因此无论使用何种精巧的优化方式都难以得到所需要的 OOD 解。

2.2 优化目标提升

为了解决上述问题,我们首先需要做的就是提升优化目标组合的 OOD 鲁棒性,使得所需要的 OOD 解至少落在对应优化目标组合的帕累托前沿上。为此,我们进一步考虑 IRM 分布外泛化能力的源来。

                                                               ▲ 图6. 提升IRMv1的外推能力.

Bottou et al., 在解释 IRM 分布外泛化能力的时候,提出 IRM 的解既是各个训练环境分布中 ERM loss 内插(Interpolation)组合的驻点(Stationary point),也是外插(Extrapolation)的驻点 [10]。根据先前的 IRMv1 失败的例子,我们知道 IRMv1 所需要的放缩会削弱 IRM 的外推能力。

为了弥补放缩带来的缺陷,很自然地,我们提出引入 VREx 优化目标 [11] 来直接提升训练环境 ERM loss 一定外推区域的泛化能力,并最终得到一个更鲁棒的 OOD 优化目标组合 IRMX:

                                                  ▲ 图7. IRMX有效提升了IRMv1的OOD能力.

如图 7 所示,理论上,我们证明了 IRMX 可以解决任意的 Two-bit Environment 问题 [7],包括 IRMv1 失败的例子。此外,我们在论文附录 C.2 提供了关于 VREx 缺陷的讨论,有兴趣的读者可以参考我们的论文。

2.3 优化过程提升

尽管 IRMX 可以解决优化困境中的优化目标鲁棒性偏弱的问题,但由于额外引入的优化目标,会导致 IRMX 的优化过程更加困难:

进一步地,上述对 ERM 以及 OOD 优化的权衡启发我们提出一种新的模型选择方案 PAIR-s。不同于现有 [2] 中讨论的根据指定验证集的 OOD 模型选择方案,我们考虑充分利用各个优化目标的归纳偏置(Inductive bias),简单地,选择一个能尽可能满足给定目标偏好的模型。实验中,我们发现考虑 ERM 和 OOD 优化权衡的模型选择方案可以充分缓解 OOD 泛化中模型选择的困境。

我们将整个解决方案统称为 PAIR,即帕累托不变风险最小化(Pareto Invariant Risk Minimization)。

三、实验与讨论

在实验中,我们使用了合成数据集,真实数据上的多种合成分布偏差,以及真实场景的分布外泛化数据集对 PAIR 进行充分的测试和验证。

3.1 因果不变性还原测试

首先,我们对 PAIR 解决 IRM 到 IRMv1 的目标表现差异还原能力进行测试。具体地,我们使用一个线性回归任务来测试 PAIR 对 IRM 定义的因果不变性的还原能力 [5,7]。

                                                               ▲ 图8. 因果不变性还原测试.

整个任务的设置以及其结果如上图 8 所示。输入的特征包含横轴和纵轴的值,而目标只取决于横轴的值。给定两个高斯分布采样得到的训练,一个满足因果不变性的模型应该在两个训练环境重叠的横轴区域,即[-2,2],能够正确识别其中的不变特征,即横轴值,进行预测,其预测结果形成的带应该在[-2,2],垂直于横轴。

如上图所示,哪怕经过充分的调参,IRMv1,VREx 以及 IRMX 都无法还原因果不变性。而经过 PAIR 加持后(IRMX+PAIR-o),模型可以充分还原所需要的因果不变性,以此尽可能弥补从 IRM 到 IRMv1 由于目标放缩导致的 OOD 泛化能力的削弱。

  3.2 真实数据集上合成分布偏移下的表现

                                                          ▲ 图9. DomainBed模型选择实验.

随后,我们也在 ColoredMNIST 的不同变体上进行了验证性实验。如图 9 所示,其中,CMNIST 对应着原始的 ColoredMNIST 设定,而 CMNIST-m 则对应先前讨论的 IRMv1 失败案例。我们测试了三种 PAIR-o 的变体,其主要原因是 PAIR-o 需要计算模型参数的梯度进行优化过程的再平衡,而实际应用中的模型往往会具有较大的参数,获取大模型的梯度往往需要消耗大量的计算开销。

为此,我们测试了采用模型不同部分梯度的 PAIR-o 表现。可以看到,IRMv1 如预期般在 CMNIST-m 中只取得与 ERM 相当的表现。尽管 VREx 在 CMNIST 以及 CMNIST-m 中表现良好,IRMX 表现却可能比 IRMv1 或者 VREx 更差。在使用 PAIR-o 后,IRMX 在 CMNIST 以及 CMNIST-m 上都取得了更好的表现。

有趣的是,PAIR-o 只使用最后一层分类器 ω 的梯度能获得比使用特征提取器φ 或全部参数 ϝ更好的表现。因此,在后续真实世界数据集实验中,我们在 PAIR-o 中只使用分类器梯度进行优化过程的平衡。

3.3 真实世界分布偏移下的表现

进一步地,我们在真实世界测试基准 Wilds 的 6 个数据集中测试了 PAIR-o 的表现,在 DomainBed 的 3 个数据集中测试了 PAIR-s 用于各类主流 OOD 目标的模型选择实验。

                                                       ▲ 图10. Wilds真实分布外泛化数据集实验.

  如上图 10 所示,在真实世界数据集中,我们也可以观察到与图 9 类似的现象。尽管 IRMX 由于其相较于 IRMv1 或者 VREx 额外的优化难度导致性能的下降,通过 PAIR-o 加持后,其可以达到比 IRMv1 以及 VREx 更好的性能,并且达到 6 个真实世界数据集中综合性能第一。

                                                             ▲ 图11. DomainBed模型选择实验.

此外,我们还在经典的 OOD 模型选择基准 DomainBed 上对 PAIR-s 进行测试。可以看到,对于几类主流的 OOD 优化目标,PAIR-s 选择得到的模型都能带来增益。尤其在虚假关联较为严重的 ColoredMNIST 上,简单地使用 PAIR-s 进行模型选择可以带来超过 10% 的性能增益,充分说明考虑 ERM 以及 OOD 优化权衡对于 OOD 泛化中模型的重要性。

我们还进行了大量的验证性实验对 PAIR 的效果进行探究,包括 PAIR 的优化过程,PAIR 对偏好的敏感性,PAIR 最终选择模型的好坏,以及 PAIR 用于更多的 OOD 目标的实验,欢迎感兴趣的读者参考我们论文的实验部分以及附录F。

四、总结及展望

本文从多目标优化的角度,讨论了分布外泛化中的优化困境,并提出了新的优化方案以及模型选择方案,为 OOD 时代的模型优化设计提供了新的思路。展望未来,基于 PAIR,我们可以探究自动化以及更精准的优化目标偏好学习,以及对多目标优化随机梯度噪声更鲁棒的优化器设计,优化过程的加速以及对于大模型来说更高效的优化方案。

参考文献

[1] CausalAdv: Adversarial Robustness through the Lens of Causality, ICLR 2022.

[2] In Search of Lost Domain Generalization, ICLR 2021.

[3] WILDS: A Benchmark of in-the-Wild Distribution Shifts, ICML 2021.

[4] GOOD: A Graph Out-of-Distribution Benchmark, NeurIPS 2022 D&B Track.

[5] Invariant Risk Minimization, arXiv 2020.

[6] https://zhuanlan.zhihu.com/p/567666715.

[7] Does Invariant Risk Minimization Capture Invariance? AISTATS 2020.

[8] Convex Optimization. Cambridge University Press, 2014.

[9] Rich Feature Construction for the Optimization-Generalization Dilemma, ICML 2022.

[10] Learning representations using causal invariance https://leon.bottou.org/talks/invariances.

[11] Out-of-distribution generalization via risk extrapolation (rex), ICML 2021.

[12] Multiple-gradient descent algorithm (mgda) for multiobjective optimization. Comptes Rendus Mathematique, 350(5):313–318, 2012.

[13] Multi-task learning with user preferences: Gradient descent with controlled ascent in pareto optimization, ICML 2020.

本文来源:公众号【PaperWeekly】

免责声明:作者保留权利,不代表本站立场。如想了解更多和作者有关的信息可以查看页面右侧作者信息卡片。
反馈
to-top--btn