离散扩散模型:基于连续时间马尔可夫链的文本与序列生成新范式
发布时间:2026/6/22 2:58:42
分类:文化教育
浏览:1234

1. 项目概述当离散数据遇上连续时间在生成式AI的浪潮里扩散模型无疑是当前最耀眼的明星之一。从生成逼真图像的DALL-E、Stable Diffusion到合成音频、视频的Sora其核心都离不开扩散过程。然而绝大多数人熟悉的扩散模型无论是DDPM还是其后续变体都建立在一个默认的假设之上数据是连续的比如图像中的像素值可以平滑地从噪声过渡到清晰。但现实世界充满了离散数据——文本、分子结构、代码、用户点击序列这些数据点之间没有平滑的中间态。你无法想象一个介于单词“猫”和“狗”之间的“半猫半狗”的单词是什么样子。这就引出了一个核心挑战如何将强大的扩散生成能力迁移到离散数据领域传统的做法比如引入嵌入层并添加高斯噪声往往效果不佳或训练不稳定。而“基于连续时间马尔可夫链的离散扩散模型”正是为解决这一难题而生。它巧妙地绕开了“连续中间态”的障碍将离散数据的生成过程建模为一个在离散状态空间上、但时间维度连续演化的随机过程。简单来说它不再试图“模糊”一个单词而是定义了这个单词以何种速率“跳转”到词典中的其他单词包括一个特殊的[MASK]状态整个过程的时间是连续的。这种方法不仅理论优美在实践中也展现出了生成高质量、多样性离散序列的潜力为文本生成、分子设计、代码补全等领域提供了新的强大工具。如果你正在处理自然语言、图结构或任何符号序列的生成任务并且对传统自回归模型如GPT系列的逐 token 生成方式感到局限或者对VAE在离散数据上的表现不满意那么理解这套框架将为你打开一扇新的大门。它提供了一种全新的、并行化的生成范式。2. 核心原理拆解连续时间马尔可夫链如何驱动离散扩散要理解这个模型我们需要拆解两个核心概念“离散扩散”和“连续时间马尔可夫链”并看它们是如何结合在一起的。2.1 离散扩散的本质状态空间的随机游走首先忘掉图像扩散中那幅从模糊到清晰的渐变图。对于离散数据比如一个来自大小为V的词汇表的单词它的状态是离散的x可以是V个可能值中的任何一个。离散扩散过程描述的是这个状态如何随时间t从0到T变化。这个过程通常定义为一个前向过程和一个反向过程前向过程从真实数据x_0开始随着时间t增加逐步地、随机地破坏它。在连续时间设定下破坏不是按固定步长发生的而是在任意时刻数据点都有一定的概率“跃迁”到其他状态。最终当tT时数据被完全破坏成一个简单的先验分布例如一个均匀分布或一个吸收态如全部变成[MASK]。反向过程这是生成的核心。我们学习一个模型使其能够从简单的先验分布tT时的状态开始沿着时间t反向演化逐步“去噪”或“修复”数据最终得到真实数据分布x_0。关键点在于这个“破坏”和“修复”的“动作”是离散状态之间的瞬时切换而不是连续值的微小扰动。2.2 连续时间马尔可夫链为扩散提供数学引擎连续时间马尔可夫链是描述上述过程的完美数学工具。一个CTMC由两部分定义状态空间对我们来说就是所有可能的数据状态集合如所有单词。转移速率矩阵 Q(t)这是一个V x V的矩阵。矩阵中的每个元素Q_{ij}(t)代表了在时间t从状态i瞬时跃迁到状态j的速率。速率不是概率它的单位是“每单位时间”。在极短的时间Δt内从i跳到j的概率近似为Q_{ij}(t) * Δt。在离散扩散的前向过程中我们人为设计一个速率矩阵Q_fwd(t)。一个常见且简单的设计是让每个状态都以一个固定的速率β(t)跃迁到一个共同的“吸收态”或“掩码态”[M]。同时也可以定义状态之间以某种规则相互跃迁。这个设计保证了随着时间推移数据最终会趋于一个我们已知的简单分布。注意这里β(t)是一个关于时间的函数类似于连续扩散中的噪声调度。它控制了破坏的“速度”通常在开始时较小结束时较大。2.3 核心桥梁时间反转与得分匹配那么我们如何学习反向生成过程呢这里需要用到CTMC的一个深刻性质在一定的正则条件下一个CTMC的时间反转过程本身也是一个CTMC。假设前向过程由Q_fwd(t)定义。理论上存在一个反向过程的速率矩阵Q_rev(t)它使得从终点T开始、按照Q_rev(t)演化最终得到的路径分布正好是前向路径分布的时间反转。我们的生成模型的目标就是去参数化并学习这个反向速率矩阵Q_rev(t)。但是直接学习Q_rev(t)是困难的。一个关键的推导基于随机过程的反向时间生成器理论给出了一个可操作的公式Q_rev_{ij}(t) (p_t(j) / p_t(i)) * Q_fwd_{ji}(t)其中p_t(i)表示在时间t时状态为i的概率边缘概率。这个公式告诉我们反向速率不仅依赖于前向速率的转置Q_fwd_{ji}还依赖于两个状态的相对概率密度比p_t(j)/p_t(i)。这个概率密度比是未知的但正是我们的神经网络需要学习的东西我们可以定义一个模型s_θ(i, t)去估计这个“得分”或者说概率对数的梯度在离散空间的一种推广。最终学习目标可以转化为一种加权的交叉熵损失让模型预测在给定当前状态和时间下下一个瞬时可能发生的跃迁类型。实操心得理解这一点至关重要。模型的输出层通常是一个大小为V的向量但它不是直接预测下一个token而是预测每个token作为“跃迁目标”的“倾向性得分”。训练时我们从前向过程中采样一个真实跃迁事件在时间t从状态x_s跳到了x_t然后让模型去拟合这个事件发生的“速率”。这避免了在离散空间做难以处理的似然最大化。3. 模型训练全流程详解理论之后我们进入实战。训练一个离散扩散模型步骤比连续扩散模型要更精细地处理时间。3.1 前向过程设计与速率矩阵构造第一步是设计前向破坏过程。这里有几个常见的选择直接影响模型性能和训练稳定性。均匀吸收每个token以速率β(t)独立地变成[MASK] token。这是最简单的设计Q_fwd矩阵非常稀疏对角线元素为-β(t)表示离开状态i的总速率i- [MASK] 的速率为β(t)其他为0。优点简单易于实现反向过程的学习目标清晰。缺点破坏过程过于简单可能无法让模型学到足够复杂的依赖关系来生成高质量样本。均匀转移每个token以速率β(t)跃迁到词汇表中的任何其他token且概率均匀。此时Q_fwd的非对角线元素i≠j为β(t)/(V-1)对角线元素相应调整。优点破坏更“剧烈”迫使模型在反向过程中学习更强的生成能力。缺点计算开销稍大因为需要考虑所有可能的转移。基于嵌入相似度的转移跃迁到另一个tokenj的速率与当前tokeni的嵌入向量相似度成正比例如使用点积或余弦相似度。这更接近“语义模糊”的直觉。优点破坏过程更自然可能提升生成样本的语义连贯性。缺点需要预训练的嵌入矩阵且速率矩阵不再是静态的计算更复杂。我的建议对于初次实现强烈推荐从均匀吸收开始。它足够简单能让你快速搭建起整个训练-采样流程并验证代码的正确性。β(t)函数可以选择线性调度例如β(t) 0.01 (1.0 - 0.01) * (t/T)其中t从0到1。3.2 损失函数推导与实现损失函数是训练的核心。基于之前提到的理论对于一条从x_s在时间s跃迁到x_t在时间t的路径其中t s ΔtΔt很小我们可以推导出以下训练目标L(θ) E[ -log π_θ(yx_t | x_s, s) ]这里π_θ是模型参数化的反向跃迁概率分布。经过一系列推导涉及Girsanov定理和重要性采样这个期望可以转化为一个更易于计算的形式。在实践中一个稳定且常用的实现方式是加权交叉熵损失从训练数据中采样一个真实样本x_0。从[0, T]中均匀采样一个连续时间t。根据前向过程Q_fwd(t)模拟从x_0在时间t内的演化得到当前被破坏的状态x_t。由于CTMC的模拟需要事件时间我们可以利用一个技巧对于均匀吸收x_t中每个token保持原样的概率是exp(-∫_0^t β(s) ds)被掩码的概率是1 - exp(...)。我们可以直接按这个概率采样而无需真正模拟连续时间路径。将(x_t, t)输入神经网络模型s_θ。模型输出一个形状为(batch_size, seq_len, V)的张量可以理解为每个位置、每个可能目标token的未归一化“得分”。计算损失。这里的关键是损失关注的是模型如何预测“数据原本的样子”。对于被掩码的位置我们希望模型预测出原始的x_0对于未被掩码的位置我们希望模型“保持不动”或预测自身。因此损失可以构造为Loss mean( mask * CE(s_θ(x_t, t), x_0) )其中mask是一个指示哪些位置被破坏的掩码CE是交叉熵损失。mask的权重可以根据时间t进行调整通常给中间时间点更高的权重因为它们最难预测。代码片段示意PyTorch风格def compute_loss(model, x0, t, Q_fwd_integral): x0: 原始数据形状 (B, L) t: 连续时间形状 (B,) Q_fwd_integral: 计算出的累积速率积分用于计算掩码概率 # 1. 计算掩码概率 mask_prob 1 - torch.exp(-Q_fwd_integral(t)) # 形状 (B,) # 2. 采样掩码 mask torch.bernoulli(mask_prob.unsqueeze(-1).expand_as(x0)) # 形状 (B, L) # 3. 生成被破坏的状态 x_t x_t x0.clone() x_t[mask.bool()] MASK_TOKEN_ID # 将被掩码位置替换为[MASK] # 4. 模型前向传播 logits model(x_t, t) # 形状 (B, L, V) # 5. 计算加权交叉熵损失 loss F.cross_entropy(logits.transpose(1, 2), x0, reductionnone) # 形状 (B, L) weighted_loss (mask * loss).sum() / (mask.sum() 1e-8) # 只对被破坏的位置计算损失 return weighted_loss3.3 网络架构选择与输入处理模型架构没有严格限制任何能处理序列数据的网络都可以但需要处理连续时间输入t。主干网络Transformer Encoder 是最自然的选择因为它能很好地建模序列内部的依赖关系。对于中等长度的序列CNN如Temporal Convolutional Networks也是一个轻量高效的选项。时间嵌入连续时间t需要被编码成向量后注入模型。通常采用高斯随机特征映射Random Fourier Features即γ(t) [sin(ω_1 t), cos(ω_1 t), ..., sin(ω_m t), cos(ω_m t)]其中ω是从某个分布如高斯分布中采样的固定频率。然后将γ(t)通过一个线性层加到每个位置的token嵌入上或者作为AdaGN自适应组归一化的参数。输入表示模型的输入是已被部分破坏的序列x_t。我们需要一个标准的嵌入层将token ID映射为向量。对于[MASK] token使用一个可学习的专属嵌入向量。注意事项训练初期由于大部分位置未被掩码损失可能很小且下降很快但这不代表模型学得好。要密切关注在验证集上模型对被掩码部分的预测准确率。一个有效的监控指标是“掩码位置token预测准确率”。4. 采样算法从噪声到数据的连续时间逆过程训练完成后我们拥有了一个学习了反向速率Q_rev_θ的模型。采样过程就是从先验分布出发沿着反向CTMC模拟一条轨迹最终生成样本。4.1 采样器类型模拟与ODE求解与连续扩散类似离散扩散的采样也有两种主要思路随机采样器直接模拟反向CTMC。这需要生成随机跃迁事件和时间。可以使用“拒绝采样”或“ thinning”算法来模拟非齐次时间依赖的CTMC。这种方法生成的样本多样性好但速度较慢且实现复杂。确定性采样器通过求解一个概率流常微分方程进行采样。对于离散扩散也存在一个对应的ODE其“速度场”由模型预测的得分和速率矩阵决定。使用ODE求解器如欧拉法、Heun法可以对其进行数值积分。这种方法速度快、确定性强适合需要快速生成或插值的场景。对于大多数应用确定性采样器是更实用、更推荐的选择。它的步骤可以概括为初始化在时间t T从先验分布采样初始状态x_T。对于均匀吸收前向过程先验就是全部为[MASK]的序列。迭代去噪从tT到t0将时间离散化为多个步长{τ_0T, τ_1, ..., τ_N0}。对于每一步k a. 在时间τ_k计算当前状态x_{τ_k}对应的模型输出得分。 b. 根据ODE公式计算状态在时间上的导数dx/dt。这个导数本质上描述了每个位置从当前token变为其他token的“倾向性流量”。 c. 使用ODE求解器更新状态x_{τ_{k1}} ODESolverStep(x_{τ_k}, dx/dt, τ_k, τ_{k1})。对于离散状态这个“更新”不是简单的加法。一种常见的方法是计算一个转移概率矩阵P expm(Δt * Q_rev_θ)然后根据这个矩阵对每个位置的状态进行采样或取期望贪心选择概率最大的token。expm是矩阵指数对于均匀吸收等简单设计它有闭式解。4.2 实用采样步骤与调参在实际实现中为了平衡速度和质量我们通常采用以下步骤设定采样步数例如N50或100。步数越多采样越精确但耗时越长。时间离散化将连续时间区间[0, T]划分为N步。可以使用线性间隔或者更优的、根据噪声调度β(t)调整的非线性间隔在β(t)大的区域步长更小。循环采样def deterministic_sample(model, seq_len, steps50): # 初始化全部为[MASK] x torch.full((1, seq_len), MASK_TOKEN_ID).to(device) # 创建时间步 timesteps torch.linspace(T, 0, steps1) # 从T到0 for i in range(steps): t timesteps[i] dt timesteps[i] - timesteps[i1] # 负的时间步长 # 1. 获取模型预测的得分/logits with torch.no_grad(): logits model(x, t.unsqueeze(0)) # (1, L, V) # 2. 根据前向Q_fwd和模型logits构造反向转移矩阵P # 对于均匀吸收P有一个简洁形式 # P(keep) exp(-beta_int) , P(switch to j) (1-exp(-beta_int)) * softmax(logits)[j] # 其中 beta_int 是速率β从t到tdt的积分 beta_int integral_of_beta(t, dt) p_keep torch.exp(-beta_int) p_change_dist F.softmax(logits, dim-1) # 模型预测的分布 # 3. 计算每个位置的新状态分布 # 首先构建一个 (L, V) 的分布矩阵 # 对于当前位置是token k新分布 p_keep * one_hot(k) (1-p_keep) * p_change_dist # 这里需要一点张量操作 eye torch.eye(V).to(x.device) # (V, V) current_one_hot eye[x] # (1, L, V) transition_dist p_keep.unsqueeze(-1) * current_one_hot (1-p_keep.unsqueeze(-1)) * p_change_dist # 4. 贪心采样选择每个位置概率最大的token x transition_dist.argmax(dim-1) return x温度调节在从transition_dist采样时可以对logits除以一个温度参数τ。τ 1会使分布更尖锐确定性更强多样性降低τ 1会使分布更平滑多样性增加但可能影响质量。实操心得采样步数N是一个关键的超参数。开始时可以用较少的步数如20快速测试采样效果。如果发现生成结果不连贯或重复尝试增加到50或100。另外在最后几步t接近0时可以尝试减小步长以获得更精细的调整。5. 实战挑战与性能优化策略将理论转化为实践的路上你会遇到几个典型的挑战。以下是我在复现和实验过程中总结出的问题和解决方案。5.1 训练不稳定的常见原因与对策问题现象可能原因排查与解决策略损失NaN或爆炸学习率过高梯度裁剪未启用β(t)调度过于激进导致某些概率计算出现数值溢出如exp(-large_number)。1. 从较低的学习率开始如3e-5。2. 务必添加梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0)。3. 检查β(t)函数确保其积分∫β(s)ds在[0, T]区间内不会过大导致exp(-integral)下溢为0。可以尝试更平缓的调度如β(t) 0.001 * t。模型预测准确率始终很低模型容量不足时间嵌入方式不佳模型无法有效利用时间信息掩码率过高任务过难。1. 增加模型深度或宽度。2. 尝试不同的时间嵌入方法确保γ(t)被充分注入到网络的每一层例如除了加在输入嵌入还可以作为AdaGN的参数。3. 可视化不同时间t下的平均掩码率如果一开始就接近1模型学不到从轻微破坏中恢复的能力。调整β(t)使t在0附近时掩码率较低。生成样本重复、缺乏多样性采样温度τ过低使用了贪心解码argmax模型陷入了某种退化的模式。1. 在采样时尝试τ1.0或更高如1.2并从分布中随机采样而不是argmax。2. 检查训练数据是否本身多样性不足。3. 尝试在训练损失中加入轻微的熵正则化项鼓励模型预测分布更均匀。5.2 效率瓶颈分析与优化离散扩散模型在序列长度L和词汇表大小V较大时计算和内存开销可能成为瓶颈。计算复杂度核心操作是模型前向传播和Softmax over Vocabulary。Transformer的复杂度是O(L^2 * d_model)而最后的Logits矩阵是O(L * V * d_model)。当V很大如数万时最后的线性输出层是主要瓶颈。优化策略词汇表裁剪对于特定领域任务可以考虑使用子词词元如BPE或领域相关的精简词汇表。自适应Softmax或采样Softmax在训练时对于每个批次只计算目标token所在类别的梯度可以大幅减少计算量。PyTorch的F.cross_entropy已经非常优化但对于极大的V可能需要更高级的技巧。模型蒸馏训练一个大模型后将其知识蒸馏到一个更小、更高效的模型中进行部署采样。内存占用存储(B, L, V)的Logits张量会消耗大量显存。优化策略使用混合精度训练AMP可以显著减少显存占用并加速计算。在PyTorch中只需几行代码即可启用。5.3 高级技巧条件生成与引导基础的模型是无条件生成。如何实现像“生成一段关于夏天的文字”这样的条件生成Classifier-Free Guidance这是目前最流行且有效的方法。在训练时随机以一定概率如10%将条件信息如类别标签、文本描述置空。这样模型同时学会了无条件生成p(x)和条件生成p(x|c)。在采样时使用以下公式进行引导logits_guided logits_uncond guidance_scale * (logits_cond - logits_uncond)其中guidance_scale 1。这个技巧能显著提升生成样本与条件的对齐质量但过大的guidance_scale会损害多样性。后处理编辑对于文本可以先使用模型生成一个草案然后利用一个判别式模型如情感分类器、语法检查器对生成结果进行打分并通过MCMC采样等方法迭代地优化草案使其满足特定条件。这种方法更灵活但计算成本更高。一个重要的提醒离散扩散模型的训练和采样在概念上比自回归模型更复杂调试周期可能更长。建议从一个极小的数据集如一个简单的文本数据集和一个小模型开始确保前向破坏、损失计算、反向采样整个流程的每个环节都符合预期再逐步扩展到更大规模的任务上。理解每个张量的形状和含义是成功复现的关键。