下图展示了模型在领域数据随时间发生旋转和膨胀时的泛化表现。通过在一些随机时间点(蓝色标记点)的观测,模型可以在任意时刻生成适用的神经网络,其决策边界始终与数据分布保持协调一致。
01
摘要
在实际应用中,数据集的数据分布往往随着时间而不断变化,预测模型需要持续更新以保持准确性。时域泛化旨在预测未来数据分布,从而提前更新模型,使模型与数据同步变化。
然而,传统方法假设领域数据在固定时间间隔内收集,忽视了现实任务中数据集采集的随机性和不定时性,无法应对数据分布在连续时间上的变化。此外,传统方法也难以保证泛化过程在整个时间流中保持稳定和可控。
为此,本文提出了
连续时域泛化
任务,并设计了一个基于模型动态系统的时域泛化框架 Koodos,使得模型在连续时间中与数据分布的变化始终保持协调一致。Koodos 通过库普曼算子将模型的复杂非线性动态转化为可学习的连续动态系统,同时利用先验知识以确保泛化过程的稳定性和可控性。
实验表明,Koodos 显著超越现有方法,为时域泛化开辟了全新的研究方向。
02
论文信息
论文链接:
https://arxiv.org/pdf/2405.16075
开源代码:
https://github.com/Zekun-Cai/Koodos/
OpenReview:
https://openreview.net/forum?id=G24fOpC3JE
我们在代码库中提供了详细的逐步教程,涵盖了 Koodos 的实现、核心概念的解读以及可视化演示:
https://github.com/Zekun-Cai/Koodos/blob/main/Tutorial_for_Koodos.ipynb
整个教程流程紧凑,十分钟即可快使掌握 Koodos 的使用方法,力荐尝试!
03
情景导入
在实际应用中,训练数据的分布通常与测试数据不同,导致模型在训练环境之外的泛化能力受限。
领域泛化
(Domain Generalization, DG)作为一种重要的机器学习策略,旨在学习一个能够在未见目标领域中也保持良好表现的模型。
近年来研究人员发现,在动态环境中,领域数据(Domain Data)分布往往具有显著的时间依赖性,这促使了
时域泛化
(Temporal Domain Generalization, TDG)技术的快速发展。
时域泛化将多个领域视为一个时间序列而非一组独立的静态个体,利用历史领域预测未来领域,从而实现对模型参数的提前调整,显著提升了传统 DG 方法的效果。
然而,现有的时域泛化研究集中在“
离散时间域
”假设下,即假设领域数据在固定时间间隔(如逐周或逐年)收集。基于这一假设,概率模型被用于预测时域演变,例如通过隐变量模型生成未来数据,或利用序列模型(如 LSTM)预测未来的模型参数。
然而在现实中,
领域数据的观测并不总是在离散、规律的时间点上,而是随机且稀疏地分布在连续时间轴上
。例如,图 1 展示了一个典型的例子——基于推文数据进行社交媒体舆情预测。
与传统 TDG 假设的领域在时间轴上规律分布不同,实际中我们只能在特定事件(如总统辩论)发生时获得一个域,而这些事件的发生时间并不固定。同时,
概念漂移
(Concept Drift)在时间轴上发生,即领域数据分布随着时间不断演变:如活跃用户增加、新交互行为形成、年龄与性别分布变化等。
理想情况下,每个时态域对应的预测模型也应随时间逐渐调整,以应对这种概念漂移。最后,由于未来的域采集时间未知,我们希望可以泛化预测模型到未来任意时刻。
▲ 图1:连续时域泛化示意图。图中展示了通过推文训练分类模型进行舆情预测。其中训练域仅能在特定政治事件(如总统辩论)前后采集。我们希望通过这些不规律时间分布的训练域来捕捉分布漂移,并最终使模型能够推广到任意未来时刻。
事实上,领域分布在连续时间上的场景十分常见,例如:
事件驱动的数据采集
:仅在特定事件发生时采集领域数据,事件之间没有数据。
流数据的随机观测
:领域数据在数据流的任意时间点开始或结束采集,而非持续进行。
离散时态域但缺失
:尽管领域数据基于离散时间点采集,但部分时间节点的领域数据缺失。
为了应对这些场景中的模型泛化,我们提出了“
连续时域泛化
”(Continuous Temporal Domain Generalization, CTDG)任务,其中观测和未观测的领域均分布于连续时间轴上随机的时间点。
CTDG 关注于如何表征时态领域的连续动态,使得模型能够在任意时间点实现稳定、适应性的调整,从而完成泛化预测
。
04
核心挑战
CTDG 任务的挑战远超传统的 TDG 方法。CTDG 不仅需要处理不规律时间分布的训练域,更重要的是,它旨在让模型
泛化到任意时刻
,即要求在连续时间的每个点上都能精确描述模型状态。
而 TDG 方法则仅关注未来的单步泛化:在观测点优化出当前模型状态后,只需将其外推一步即可。这使得 CTDG 区别于 TDG 任务:
CTDG 的关键在于如何在连续时间轴上同步数据分布和模型参数的动态演变,而不是仅局限于未来某一特定时刻的模型表现
。
具体而言,与 TDG 任务相比,CTDG 的复杂性主要来自以下几个尚未被充分探索的核心挑战:
如何建模数据动态并同步模型动态
:CTDG 要求在连续时间轴上捕捉领域数据的动态,并据此同步调整模型状态。然而,数据动态本身难以直接观测,需要通过观测时间点来学习。此外,模型动态的演变过程也同样复杂。理解数据演变如何驱动模型演变构成了 CTDG 的首要挑战。
如何在高度非线性模型动态中捕捉主动态
:领域数据的预测模型通常依赖过参数化(over-parametrized)的深度神经网络,模型动态因此呈现出高维、非线性的复杂特征。这导致模型的主动态嵌藏在大量潜在维度中。如何有效提取并将这些主动态映射到可学习的空间,是 CTDG 任务中的另一重大挑战。
如何确保长期泛化的稳定性和可控性
:为实现未来任意时刻的泛化,CTDG 必须确保模型的长期稳定性。此外,在许多情况下,我们可能拥有数据动态的高层次先验知识。如何将这些先验知识嵌入 CTDG 的优化过程中,进而提升泛化的稳定性和可控性,是一个重要的开放性问题。
05
技术方法
5.1 问题定义
在 CTDG 中,一个域
表示在时间
采集的数据集,由实例集
组成,其中
和
分别为特征值,目标值和实例数。我们重点关注连续时间上的渐进性概念漂移,表示为领域数据的条件概率分布
随时间平滑变化。
在训练阶段,模型接收一系列在不规律时间点
上收集的观测域
,其中每个时间点
是定义在连续时间轴
上的实数,且满足 $t_1
在每个
上,模型学习到领域数据
的预测函数
,其中
表示
时刻的模型参数。CTDG 的目标是建模参数的动态变化,以便在任意给定时刻
上预测模型参数
,从而得到泛化模型
。
在后续部分中,我们使用简写符号
和
,分别表示在时间
上的
、
和
。
5.2 设计思路
我们的方法通过模型与数据的同步、动态简化表示,以及高效的联合优化展开。具体思路如下:
1.
同步数据和模型的动态
:我们证明了连续时域中模型参数的连续性,而后借助神经微分方程(Neural ODE)建立模型动态系统,从而实现模型动态与数据动态的同步。
2.
表征高维动态到低维空间
:我们将高维模型参数映射到一个结构化的库普曼空间(Koopman Space)中。该空间通过可学习的低维线性动态来捕捉模型的主要动态。
3.
联合优化模型与其动态
:我们将单个领域的模型学习与各时间点上的连续动态进行联合优化,并设计了归纳偏置的约束接口,通过端到端优化保证泛化的稳定性和可控性。
▲ 模型设计
5.3 解决方案
Step 1. 数据动态建模与模型动态同步
分布变化的连续性假设
:我们首先假设数据分布在时间上具有连续演化的特性,即条件概率分布
随时间平滑变化, 其演化规律可由一个函数
所描述的动态系统刻画。尽管真实世界中的渐进概念漂移可能较为复杂,但因概念漂移通常源于底层的连续过程(如自然、生物、物理、社会或经济因素),这一假设不失普适性。
分布变化引发的模型参数连续演化
:基于上述假设,模型的函数功能空间
应随数据分布变化同步调整。我们借助常微分方程来描述这一过程:
由此可推导出模型参数
的演化满足:
其中,
是
对
的雅可比矩阵。
这一结果表明,
如果数据分布的演化在时间上具有连续性,那么
的演化过程也具有连续性
,即模型参数会随数据分布的变化而平滑调整。
上式为
建立了一个由微分方程描述的模型动态系统
。
模型动态系统学习
:由于数据动态
的具体形式未知, 直接求解上述微分方程并不可行。为此, 我们引入一个由神经网络定义的连续动态系统, 用可学习的函数
描述模型参数
的变化。
通过鼓励模型动态和数据动态之间的拓扑共轭(Topological Conjugation)关系使
逼近真实动态。具体而言, 拓扑共轭要求通过泛化获得的模型参数与直接训练得到的参数保持一致。为此, 我们设定以下优化目标, 以学习
的参数
:
其中,
通过在时刻
的领域上直接训练获得,
则表示从时间
通过动态
演变至
的泛化参数:
通过这一优化过程,我们建立了模型动态与数据动态之间的同步机制。借助动态函数
,我们可以在任意时刻精确求解模型的状态。
Step 2. 通过库普曼算子简化模型动态
非线性动态线性化
在实际任务中, 预测模型通常依赖于过参数化的深度神经网络, 使得模型动态
呈现为在高维空间中纠缠的非线性动态。直接对
建模不仅计算量大,且极易导致泛化不稳定。
然而,
受数据动态