> 技术文档 > 【AI模型学习】Gumbel-Softmax —— “软硬皆吃”的函数

【AI模型学习】Gumbel-Softmax —— “软硬皆吃”的函数


文章目录

  • 1.介绍
    • 1.1 公式
    • 1.2 和 Softmax 的区别
  • 2. 例子
    • 2.1 公式计算
    • 2.2 温度
    • 2.3 温度调节器
    • 2.4 梯度回传
  • 3. Straight-Through Gumbel-Softmax(硬采样)
    • 3.1 作用
    • 3.2 GroupViT

1.介绍

Gumbel-Softmax(又叫 Concrete distribution)是一种可微分的替代方案,用于在神经网络中对离散类别变量进行建模与采样。它的核心用途是——在反向传播中保留梯度信息的同时,实现对离散变量的“近似采样”

1.1 公式

给定 logits 向量 z=[ z 1 , z 2 ,…, z k ] \\mathbf{z} = [z_1, z_2, \\dots, z_k] z=[z1,z2,,zk],Gumbel-Softmax 的计算过程如下:

  1. 首先为每个类别添加 Gumbel 噪声:

    g i = − log ⁡ ( − log ⁡ ( U i ) ) , U i ∼ Uniform ( 0 , 1 ) g_i = -\\log(-\\log(U_i)), \\quad U_i \\sim \\text{Uniform}(0, 1) gi=log(log(Ui)),UiUniform(0,1)

  2. 加入噪声后进行 softmax 归一化:

    y i = exp ⁡ ( ( z i + g i ) / τ ) ∑ j = 1 k exp ⁡ ( ( z j + g j ) / τ ) y_i = \\frac{\\exp\\left((z_i + g_i)/\\tau\\right)}{\\sum_{j=1}^{k} \\exp\\left((z_j + g_j)/\\tau\\right)} yi=j=1kexp((zj+gj)/τ)exp((zi+gi)/τ)

    其中 τ > 0 \\tau > 0 τ>0温度参数,决定了输出的平滑程度。


1.2 和 Softmax 的区别

特性 Softmax Gumbel-Softmax 输入 logits(可学习) logits + Gumbel噪声 输出 概率分布(连续) 概率分布(连续,且近似 one-hot) 可微性 可 可 用途 常规分类 离散采样的可微近似 温度控制 通常固定,有时也会动 可以控制采样“硬度” 是否引入随机性 不可 (通过 Gumbel Noise)

关键点:采样 vs 分类

  • Softmax:直接用于分类,输出连续概率分布,训练时通常和 CrossEntropyLoss 一起使用。
  • Gumbel-Softmax:用于模拟对离散变量的采样操作(例如 One-Hot 向量),但是保留可微性,使得可以端到端训练(如在 GAN、VAE 中)。

温度参数 τ \\tau τ 的作用

  • τ → ∞ \\tau \\rightarrow \\infty τ:分布变得非常平滑,趋于均匀分布。
  • τ → 0 \\tau \\rightarrow 0 τ0:输出趋近于 one-hot,近似离散采样。
  • 通常在训练过程中使用 退火策略:逐步减小 τ \\tau τ,使得采样从平滑到离散。

应用场景

  • 离散变量的生成模型(如 VAE)
  • 强化学习中的离散动作空间
  • 神经网络架构搜索(如 one-hot 选择某个子结构)
  • 零 shot learning / attention over categorical variables

2. 例子

2.1 公式计算

普通 Softmax

给定 logits 向量 z=[ z 1 , z 2 ,…, z k ] \\mathbf{z} = [z_1, z_2, \\dots, z_k] z=[z1,z2,,zk],Softmax 公式为:

y i = exp ⁡ ( z i / τ ) ∑ j = 1 k exp ⁡ ( z j / τ ) y_i = \\frac{\\exp(z_i / \\tau)}{\\sum_{j=1}^{k} \\exp(z_j / \\tau)} yi=j=1kexp(zj/τ)exp(zi/τ)

其中:

  • τ \\tau τ 是温度(默认=1),越小越“尖锐”,越大越“平滑”;
  • 输出是一个连续的概率分布;
  • 不可用于采样,也没有随机性。

Gumbel-Softmax

y i = exp ⁡ ( ( z i + g i ) / τ ) ∑ j = 1 k exp ⁡ ( ( z j + g j ) / τ ) y_i = \\frac{\\exp\\left((z_i + g_i)/\\tau\\right)}{\\sum_{j=1}^{k} \\exp\\left((z_j + g_j)/\\tau\\right)} yi=j=1kexp((zj+gj)/τ)exp((zi+gi)/τ)

其中:

  • g i = − log ⁡ ( − log ⁡ ( U i ) ) g_i = -\\log(-\\log(U_i)) gi=log(log(Ui)),每个 U i ∼ Uniform ( 0 , 1 ) U_i \\sim \\text{Uniform}(0, 1) UiUniform(0,1)
  • 即 logits 上加了 Gumbel 噪声,相当于做了可导的近似采样;
  • 同样是 softmax 结构,但输入是随机扰动过的;
  • 输出可以 近似 one-hot(随温度 τ 变化)

数值举例:logits = [2.0, 1.0, 0.1]

我们令温度 τ=0.5 \\tau = 0.5 τ=0.5,并计算两种输出。

  1. 普通 Softmax(τ=0.5)

    y 1 = exp ⁡ ( 2.0 / 0.5 ) exp ⁡ ( 2.0 / 0.5 ) + exp ⁡ ( 1.0 / 0.5 ) + exp ⁡ ( 0.1 / 0.5 ) = exp ⁡ ( 4 ) exp ⁡ ( 4 ) + exp ⁡ ( 2 ) + exp ⁡ ( 0.2 ) ≈ 54.6 54.6 + 7.4 + 1.2 ≈ 0.84 y 2 ≈ 0.11 , y 3 ≈ 0.02 \\begin{align*} y_1 &= \\frac{\\exp(2.0 / 0.5)}{\\exp(2.0/0.5) + \\exp(1.0/0.5) + \\exp(0.1/0.5)} = \\frac{\\exp(4)}{\\exp(4) + \\exp(2) + \\exp(0.2)} \\\\ &\\approx \\frac{54.6}{54.6 + 7.4 + 1.2} \\approx 0.84 \\\\ y_2 &\\approx 0.11,\\quad y_3 \\approx 0.02 \\end{align*} y1y2=exp(2.0/0.5)+exp(1.0/0.5)+exp(0.1/0.5)exp(2.0/0.5)=exp(4)+exp(2)+exp(0.2)exp(4)54.6+7.4+1.254.60.840.11,y30.02

    是一个 平滑的概率分布

  2. Gumbel-Softmax(τ=0.5 + 噪声)

    先采样一组 Gumbel 噪声(假设):

    g = [ 0.1 ,   − 0.3 ,   1.2 ] g = [0.1,\\ -0.3,\\ 1.2] g=[0.1, 0.3, 1.2]

    然后计算 perturbed logits:

    z + g = [ 2.1 ,   0.7 ,   1.3 ] z + g = [2.1,\\ 0.7,\\ 1.3] z+g=[2.1, 0.7, 1.3]

    再 softmax:

    exp ⁡ ( [ 4.2 , 1.4 , 2.6 ] ) ≈ [ 66 , 4.0 , 13.5 ] ⇒ y ≈ [ 0.78 ,   0.05 ,   0.17 ] \\exp([4.2, 1.4, 2.6]) ≈ [66, 4.0, 13.5] \\Rightarrow y ≈ [0.78,\\ 0.05,\\ 0.17] exp([4.2,1.4,2.6])[66,4.0,13.5]y[0.78, 0.05, 0.17]

    结果:

    • 同样是概率分布;
    • 有随机性(每次 g 都不同);
    • 输出更偏向某个类,有时接近 one-hot。

2.2 温度

  • logits: z = [ 2.0 ,   1.0 ,   0.1 ] \\mathbf{z} = [2.0,\\ 1.0,\\ 0.1] z=[2.0, 1.0, 0.1]
  • 一组固定的 Gumbel noise:假设 g = [ 0.1 ,   − 0.3 ,   1.2 ] \\mathbf{g} = [0.1,\\ -0.3,\\ 1.2] g=[0.1, 0.3, 1.2]
  • 温度 τ 分别取:1.0 → 0.5 → 0.1 → 0.01

我们每次计算:

y i = exp ⁡ ( ( z i + g i ) / τ ) ∑ j exp ⁡ ( ( z j + g j ) / τ ) y_i = \\frac{\\exp\\left((z_i + g_i)/\\tau\\right)}{\\sum_j \\exp\\left((z_j + g_j)/\\tau\\right)} yi=jexp((zj+gj)/τ)exp((zi+gi)/τ)

1. τ = 1.0 (比较平滑)

z + g = [ 2.1 ,   0.7 ,   1.3 ] ⇒ exp ⁡ ( [ 2.1 , 0.7 , 1.3 ] ) ≈ [ 8.17 , 2.01 , 3.67 ] z + g = [2.1,\\ 0.7,\\ 1.3] \\Rightarrow \\exp([2.1, 0.7, 1.3]) ≈ [8.17, 2.01, 3.67] z+g=[2.1, 0.7, 1.3]exp([2.1,0.7,1.3])[8.17,2.01,3.67]归一化:

y ≈ [ 0.58 ,   0.14 ,   0.26 ] y ≈ [0.58,\\ 0.14,\\ 0.26] y[0.58, 0.14, 0.26]

2. τ = 0.5 (更陡峭)

( z + g ) / 0.5 = [ 4.2 ,   1.4 ,   2.6 ] ⇒ exp ⁡ ≈ [ 66.7 ,   4.05 ,   13.46 ] ⇒ y ≈ [ 0.78 ,   0.05 ,   0.17 ] (z+g)/0.5 = [4.2,\\ 1.4,\\ 2.6] \\Rightarrow \\exp ≈ [66.7,\\ 4.05,\\ 13.46] \\Rightarrow y ≈ [0.78,\\ 0.05,\\ 0.17] (z+g)/0.5=[4.2, 1.4, 2.6]exp[66.7, 4.05, 13.46]y[0.78, 0.05, 0.17]

3. τ = 0.1 (非常接近 one-hot)

( z + g ) / 0.1 = [ 21 ,   7 ,   13 ] ⇒ exp ⁡ ≈ [ 1.3 e 9 ,   1.1 e 3 ,   4.4 e 5 ] ⇒ y ≈ [ 0.9995 ,    0.000001 ,    0.0004 ] (z+g)/0.1 = [21,\\ 7,\\ 13] \\Rightarrow \\exp ≈ [1.3e9,\\ 1.1e3,\\ 4.4e5] \\Rightarrow y ≈ [0.9995,\\ ~0.000001,\\ ~0.0004] (z+g)/0.1=[21, 7, 13]exp[1.3e9, 1.1e3, 4.4e5]y[0.9995,  0.000001,  0.0004]

4. τ = 0.01(几乎就是 one-hot)

( z + g ) / 0.01 = [ 210 ,   70 ,   130 ] ⇒ exp ⁡ ≈ [ e 210 , e 70 , e 130 ] ⇒ y ≈ [ 1.0 ,   0 ,   0 ] (z+g)/0.01 = [210,\\ 70,\\ 130] \\Rightarrow \\exp ≈ [e^{210}, e^{70}, e^{130}] \\Rightarrow y ≈ [1.0,\\ 0,\\ 0] (z+g)/0.01=[210, 70, 130]exp[e210,e70,e130]y[1.0, 0, 0]

这时:

  • 输出已经几乎就是 [ 1.0 ,   0.0 ,   0.0 ] [1.0,\\ 0.0,\\ 0.0] [1.0, 0.0, 0.0],即 one-hot;
  • 但由于 没有真实 argmax 操作,依然保持了可导性!

对比表:

温度 τ 输出分布 y(近似) 1.0 [0.58, 0.14, 0.26] 0.5 [0.78, 0.05, 0.17] 0.1 [0.9995, 0.000001, 0.0004] 0.01 [1.0, 0.0, 0.0] ← 近似 one-hot
  • 温度越低,分布越尖锐
  • 最终趋向于 argmax 的 one-hot,但保留了梯度
  • 这是 Gumbel-Softmax 的最大价值:在训练时可以平滑可导地模拟采样过程

2.3 温度调节器

温度调度器(temperature annealing) 是使用 Gumbel-Softmax 时常见的一种机制,它能让模型:

  • 初期探索更多 group(温度高,分布平滑);
  • 后期收敛到明确分组(温度低,趋于 one-hot);
  • 从而达到“软→硬”的逐步可导采样。
  1. 指数退火(Exponential decay)

τ t = max ⁡ ( τ min ⁡ ,   τ 0 ⋅ exp ⁡ ( − k ⋅ t ) ) \\tau_t = \\max(\\tau_{\\min},\\ \\tau_0 \\cdot \\exp(-k \\cdot t)) τt=max(τmin, τ0exp(kt))

  • τ 0 \\tau_0 τ0:初始温度(比如 1.0)
  • k k k:退火速率(如 0.01)
  • t t t:当前 epoch
  • τ min ⁡ \\tau_{\\min} τmin:最终保持的最小温度,避免梯度消失
  1. 线性退火

τ t = max ⁡ ( τ min ⁡ ,   τ 0 − r ⋅ t ) \\tau_t = \\max(\\tau_{\\min},\\ \\tau_0 - r \\cdot t) τt=max(τmin, τ0rt)

  • r r r:线性下降速度

代码

class TemperatureScheduler: def __init__(self, tau0=1.0, tau_min=0.1, decay_rate=0.03): self.tau0 = tau0 self.tau_min = tau_min self.decay_rate = decay_rate def get_tau(self, epoch): tau = self.tau0 * np.exp(-self.decay_rate * epoch) return max(self.tau_min, tau)

用法

scheduler = TemperatureScheduler(tau0=1.0, tau_min=0.1, decay_rate=0.05)for epoch in range(50): tau = scheduler.get_tau(epoch) logits = model(x) gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits))) y = F.softmax((logits + gumbel_noise) / tau, dim=-1) ...

2.4 梯度回传

Softmax

Forward:

y i = exp ⁡ ( z i ) ∑ j exp ⁡ ( z j ) y_i = \\frac{\\exp(z_i)}{\\sum_j \\exp(z_j)} yi=jexp(zj)exp(zi)

Backward:

Softmax 的梯度对输入 logits z i z_i zi 的导数是:

∂ y i ∂ z k = y i ( δ i k− y k ) \\frac{\\partial y_i}{\\partial z_k} = y_i(\\delta_{ik} - y_k) zkyi=yi(δikyk)

如果我们和交叉熵 loss L=− ∑ i y i true log⁡ y i L = -\\sum_i y_i^{\\text{true}} \\log y_i L=iyitruelogyi 联合使用:

∂ L ∂ z k = y k − y k true \\frac{\\partial L}{\\partial z_k} = y_k - y_k^{\\text{true}} zkL=ykyktrue

Gumbel-Softmax

Gumbel-Softmax 的前向过程是:

y i = exp ⁡ ( ( z i + g i ) / τ ) ∑ j exp ⁡ ( ( z j + g j ) / τ ) y_i = \\frac{\\exp((z_i + g_i)/\\tau)}{\\sum_j \\exp((z_j + g_j)/\\tau)} yi=jexp((zj+gj)/τ)exp((zi+gi)/τ)

虽然加了随机性,但 g_i 被视为常量(不会回传)

因此,在反向传播时:

  • 梯度仍然是从 softmax 函数计算;
  • 只是 softmax 的输入变成了 z ~ i = ( z i + g i ) / τ \\tilde{z}_i = (z_i + g_i)/\\tau z~i=(zi+gi)/τ

∂ y i ∂ z k = 1 τ ⋅ y i ( δ i k− y k ) \\frac{\\partial y_i}{\\partial z_k} = \\frac{1}{\\tau} \\cdot y_i(\\delta_{ik} - y_k) zkyi=τ1yi(δikyk)

跟 Softmax 几乎一样,但额外多了一个 1 τ \\frac{1}{\\tau} τ1 缩放因子,温度越低,梯度越大

注意:g_i 不传梯度,只是扰动 forward 的 logits;反向传播只看 z 的梯度


3. Straight-Through Gumbel-Softmax(硬采样)

  • 前向输出 one-hot(使用 argmax);
  • 反向仍使用 soft Gumbel 的梯度;
  • 类似于 detach 技巧:
y_hard = one_hot(torch.argmax(y, dim=-1))y = (y_hard - y).detach() + y # 前向为硬,反向为软

3.1 作用

问题背景:

我们经常希望模型明确地做选择,比如:

  • 哪个 token 属于哪个 group(GroupViT)
  • 从 N 个动作中选一个(RL / controller)
  • 选择某个 module、某个结构(NAS)

但如果你直接用 argmax

  • 前向可以选;
  • 反向就挂了(不可导)。

ST Gumbel-Softmax 的解决方案:

你可以:

  • 前向:使用 one-hot(离散采样)
  • 反向:仍然使用 Gumbel-Softmax 的软梯度(保持可导)

场景举例

应用场景 作用 GroupViT 每个 patch 真的只分给一个 group(不是混合) 控制器/策略网络 模拟离散动作(像强化学习中的 action selection) 神经架构搜索(NAS) 从多个结构中 hard 选一个结构路径 VQ-VAE 对 latent 空间进行离散编码 多任务选择 在多个 heads/branches 中选一条路径

3.2 GroupViT

GroupViT 要做的事是:让每个 patch token 分配给一个明确的 group token

也就是:每个 patch → 只属于一个 group(语义聚类);然后对同组 patch 聚合成一个语义 group。

为什么不能直接 argmax?

  • argmax 是硬选择:(想要)
  • argmax 不可导:(不能训练)

如果把 Gumbel 换成普通 Softmax 会怎么样?

结果:模型会退化为“软分配”,语义不清晰

具体表现如下:

  • 输出是概率分布;

  • 每个 patch 会被“部分分配”到所有 group 上;

    patch i = 0.4 ⋅ g 1 + 0.3 ⋅ g 2 + 0.3 ⋅ g 3 \\text{patch}_i = 0.4 \\cdot g_1 + 0.3 \\cdot g_2 + 0.3 \\cdot g_3 patchi=0.4g1+0.3g2+0.3g3

  • 聚类效果变差,语义边界模糊;

  • 训练容易,但分组不清晰、不稳定

总结句

GroupViT 使用 ST Gumbel-Softmax,是为了同时实现:

  1. 离散分组(清晰 one-hot)
  2. 可导训练(不依赖 argmax)

如果换成普通 softmax,会变成“模糊分组”,模型将失去“明确聚类语义区域”的能力,本质上不再是 GroupViT 的设计初衷了。