【AI模型学习】Gumbel-Softmax —— “软硬皆吃”的函数
文章目录
- 1.介绍
-
- 1.1 公式
- 1.2 和 Softmax 的区别
- 2. 例子
- 3. Straight-Through Gumbel-Softmax(硬采样)
-
- 3.1 作用
- 3.2 GroupViT
1.介绍
Gumbel-Softmax(又叫 Concrete distribution)是一种可微分的替代方案,用于在神经网络中对离散类别变量进行建模与采样。它的核心用途是——在反向传播中保留梯度信息的同时,实现对离散变量的“近似采样”。
1.1 公式
给定 logits 向量 z=[ z 1 , z 2 ,…, z k ] \\mathbf{z} = [z_1, z_2, \\dots, z_k] z=[z1,z2,…,zk],Gumbel-Softmax 的计算过程如下:
-
首先为每个类别添加 Gumbel 噪声:
g i = − log ( − log ( U i ) ) , U i ∼ Uniform ( 0 , 1 ) g_i = -\\log(-\\log(U_i)), \\quad U_i \\sim \\text{Uniform}(0, 1) gi=−log(−log(Ui)),Ui∼Uniform(0,1)
-
加入噪声后进行 softmax 归一化:
y i = exp ( ( z i + g i ) / τ ) ∑ j = 1 k exp ( ( z j + g j ) / τ ) y_i = \\frac{\\exp\\left((z_i + g_i)/\\tau\\right)}{\\sum_{j=1}^{k} \\exp\\left((z_j + g_j)/\\tau\\right)} yi=∑j=1kexp((zj+gj)/τ)exp((zi+gi)/τ)
其中 τ > 0 \\tau > 0 τ>0 是温度参数,决定了输出的平滑程度。
1.2 和 Softmax 的区别
关键点:采样 vs 分类
- Softmax:直接用于分类,输出连续概率分布,训练时通常和
CrossEntropyLoss
一起使用。 - Gumbel-Softmax:用于模拟对离散变量的采样操作(例如 One-Hot 向量),但是保留可微性,使得可以端到端训练(如在 GAN、VAE 中)。
温度参数 τ \\tau τ 的作用
- τ → ∞ \\tau \\rightarrow \\infty τ→∞:分布变得非常平滑,趋于均匀分布。
- τ → 0 \\tau \\rightarrow 0 τ→0:输出趋近于 one-hot,近似离散采样。
- 通常在训练过程中使用 退火策略:逐步减小 τ \\tau τ,使得采样从平滑到离散。
应用场景
- 离散变量的生成模型(如 VAE)
- 强化学习中的离散动作空间
- 神经网络架构搜索(如 one-hot 选择某个子结构)
- 零 shot learning / attention over categorical variables
2. 例子
2.1 公式计算
普通 Softmax
给定 logits 向量 z=[ z 1 , z 2 ,…, z k ] \\mathbf{z} = [z_1, z_2, \\dots, z_k] z=[z1,z2,…,zk],Softmax 公式为:
y i = exp ( z i / τ ) ∑ j = 1 k exp ( z j / τ ) y_i = \\frac{\\exp(z_i / \\tau)}{\\sum_{j=1}^{k} \\exp(z_j / \\tau)} yi=∑j=1kexp(zj/τ)exp(zi/τ)
其中:
- τ \\tau τ 是温度(默认=1),越小越“尖锐”,越大越“平滑”;
- 输出是一个连续的概率分布;
- 不可用于采样,也没有随机性。
Gumbel-Softmax
y i = exp ( ( z i + g i ) / τ ) ∑ j = 1 k exp ( ( z j + g j ) / τ ) y_i = \\frac{\\exp\\left((z_i + g_i)/\\tau\\right)}{\\sum_{j=1}^{k} \\exp\\left((z_j + g_j)/\\tau\\right)} yi=∑j=1kexp((zj+gj)/τ)exp((zi+gi)/τ)
其中:
- g i = − log ( − log ( U i ) ) g_i = -\\log(-\\log(U_i)) gi=−log(−log(Ui)),每个 U i ∼ Uniform ( 0 , 1 ) U_i \\sim \\text{Uniform}(0, 1) Ui∼Uniform(0,1);
- 即 logits 上加了 Gumbel 噪声,相当于做了可导的近似采样;
- 同样是 softmax 结构,但输入是随机扰动过的;
- 输出可以 近似 one-hot(随温度 τ 变化)。
数值举例:logits = [2.0, 1.0, 0.1]
我们令温度 τ=0.5 \\tau = 0.5 τ=0.5,并计算两种输出。
-
普通 Softmax(τ=0.5)
y 1 = exp ( 2.0 / 0.5 ) exp ( 2.0 / 0.5 ) + exp ( 1.0 / 0.5 ) + exp ( 0.1 / 0.5 ) = exp ( 4 ) exp ( 4 ) + exp ( 2 ) + exp ( 0.2 ) ≈ 54.6 54.6 + 7.4 + 1.2 ≈ 0.84 y 2 ≈ 0.11 , y 3 ≈ 0.02 \\begin{align*} y_1 &= \\frac{\\exp(2.0 / 0.5)}{\\exp(2.0/0.5) + \\exp(1.0/0.5) + \\exp(0.1/0.5)} = \\frac{\\exp(4)}{\\exp(4) + \\exp(2) + \\exp(0.2)} \\\\ &\\approx \\frac{54.6}{54.6 + 7.4 + 1.2} \\approx 0.84 \\\\ y_2 &\\approx 0.11,\\quad y_3 \\approx 0.02 \\end{align*} y1y2=exp(2.0/0.5)+exp(1.0/0.5)+exp(0.1/0.5)exp(2.0/0.5)=exp(4)+exp(2)+exp(0.2)exp(4)≈54.6+7.4+1.254.6≈0.84≈0.11,y3≈0.02
是一个 平滑的概率分布。
-
Gumbel-Softmax(τ=0.5 + 噪声)
先采样一组 Gumbel 噪声(假设):
g = [ 0.1 , − 0.3 , 1.2 ] g = [0.1,\\ -0.3,\\ 1.2] g=[0.1, −0.3, 1.2]
然后计算 perturbed logits:
z + g = [ 2.1 , 0.7 , 1.3 ] z + g = [2.1,\\ 0.7,\\ 1.3] z+g=[2.1, 0.7, 1.3]
再 softmax:
exp ( [ 4.2 , 1.4 , 2.6 ] ) ≈ [ 66 , 4.0 , 13.5 ] ⇒ y ≈ [ 0.78 , 0.05 , 0.17 ] \\exp([4.2, 1.4, 2.6]) ≈ [66, 4.0, 13.5] \\Rightarrow y ≈ [0.78,\\ 0.05,\\ 0.17] exp([4.2,1.4,2.6])≈[66,4.0,13.5]⇒y≈[0.78, 0.05, 0.17]
结果:
- 同样是概率分布;
- 有随机性(每次 g 都不同);
- 输出更偏向某个类,有时接近 one-hot。
2.2 温度
- logits: z = [ 2.0 , 1.0 , 0.1 ] \\mathbf{z} = [2.0,\\ 1.0,\\ 0.1] z=[2.0, 1.0, 0.1]
- 一组固定的 Gumbel noise:假设 g = [ 0.1 , − 0.3 , 1.2 ] \\mathbf{g} = [0.1,\\ -0.3,\\ 1.2] g=[0.1, −0.3, 1.2]
- 温度 τ 分别取:1.0 → 0.5 → 0.1 → 0.01
我们每次计算:
y i = exp ( ( z i + g i ) / τ ) ∑ j exp ( ( z j + g j ) / τ ) y_i = \\frac{\\exp\\left((z_i + g_i)/\\tau\\right)}{\\sum_j \\exp\\left((z_j + g_j)/\\tau\\right)} yi=∑jexp((zj+gj)/τ)exp((zi+gi)/τ)
1. τ = 1.0 (比较平滑)
z + g = [ 2.1 , 0.7 , 1.3 ] ⇒ exp ( [ 2.1 , 0.7 , 1.3 ] ) ≈ [ 8.17 , 2.01 , 3.67 ] z + g = [2.1,\\ 0.7,\\ 1.3] \\Rightarrow \\exp([2.1, 0.7, 1.3]) ≈ [8.17, 2.01, 3.67] z+g=[2.1, 0.7, 1.3]⇒exp([2.1,0.7,1.3])≈[8.17,2.01,3.67]归一化:
y ≈ [ 0.58 , 0.14 , 0.26 ] y ≈ [0.58,\\ 0.14,\\ 0.26] y≈[0.58, 0.14, 0.26]
2. τ = 0.5 (更陡峭)
( z + g ) / 0.5 = [ 4.2 , 1.4 , 2.6 ] ⇒ exp ≈ [ 66.7 , 4.05 , 13.46 ] ⇒ y ≈ [ 0.78 , 0.05 , 0.17 ] (z+g)/0.5 = [4.2,\\ 1.4,\\ 2.6] \\Rightarrow \\exp ≈ [66.7,\\ 4.05,\\ 13.46] \\Rightarrow y ≈ [0.78,\\ 0.05,\\ 0.17] (z+g)/0.5=[4.2, 1.4, 2.6]⇒exp≈[66.7, 4.05, 13.46]⇒y≈[0.78, 0.05, 0.17]
3. τ = 0.1 (非常接近 one-hot)
( z + g ) / 0.1 = [ 21 , 7 , 13 ] ⇒ exp ≈ [ 1.3 e 9 , 1.1 e 3 , 4.4 e 5 ] ⇒ y ≈ [ 0.9995 , 0.000001 , 0.0004 ] (z+g)/0.1 = [21,\\ 7,\\ 13] \\Rightarrow \\exp ≈ [1.3e9,\\ 1.1e3,\\ 4.4e5] \\Rightarrow y ≈ [0.9995,\\ ~0.000001,\\ ~0.0004] (z+g)/0.1=[21, 7, 13]⇒exp≈[1.3e9, 1.1e3, 4.4e5]⇒y≈[0.9995, 0.000001, 0.0004]
4. τ = 0.01(几乎就是 one-hot)
( z + g ) / 0.01 = [ 210 , 70 , 130 ] ⇒ exp ≈ [ e 210 , e 70 , e 130 ] ⇒ y ≈ [ 1.0 , 0 , 0 ] (z+g)/0.01 = [210,\\ 70,\\ 130] \\Rightarrow \\exp ≈ [e^{210}, e^{70}, e^{130}] \\Rightarrow y ≈ [1.0,\\ 0,\\ 0] (z+g)/0.01=[210, 70, 130]⇒exp≈[e210,e70,e130]⇒y≈[1.0, 0, 0]
这时:
- 输出已经几乎就是 [ 1.0 , 0.0 , 0.0 ] [1.0,\\ 0.0,\\ 0.0] [1.0, 0.0, 0.0],即 one-hot;
- 但由于 没有真实 argmax 操作,依然保持了可导性!
对比表:
- 温度越低,分布越尖锐;
- 最终趋向于 argmax 的 one-hot,但保留了梯度;
- 这是 Gumbel-Softmax 的最大价值:在训练时可以平滑可导地模拟采样过程。
2.3 温度调节器
温度调度器(temperature annealing) 是使用 Gumbel-Softmax 时常见的一种机制,它能让模型:
- 初期探索更多 group(温度高,分布平滑);
- 后期收敛到明确分组(温度低,趋于 one-hot);
- 从而达到“软→硬”的逐步可导采样。
- 指数退火(Exponential decay)
τ t = max ( τ min , τ 0 ⋅ exp ( − k ⋅ t ) ) \\tau_t = \\max(\\tau_{\\min},\\ \\tau_0 \\cdot \\exp(-k \\cdot t)) τt=max(τmin, τ0⋅exp(−k⋅t))
- τ 0 \\tau_0 τ0:初始温度(比如 1.0)
- k k k:退火速率(如 0.01)
- t t t:当前 epoch
- τ min \\tau_{\\min} τmin:最终保持的最小温度,避免梯度消失
- 线性退火
τ t = max ( τ min , τ 0 − r ⋅ t ) \\tau_t = \\max(\\tau_{\\min},\\ \\tau_0 - r \\cdot t) τt=max(τmin, τ0−r⋅t)
- r r r:线性下降速度
代码
class TemperatureScheduler: def __init__(self, tau0=1.0, tau_min=0.1, decay_rate=0.03): self.tau0 = tau0 self.tau_min = tau_min self.decay_rate = decay_rate def get_tau(self, epoch): tau = self.tau0 * np.exp(-self.decay_rate * epoch) return max(self.tau_min, tau)
用法
scheduler = TemperatureScheduler(tau0=1.0, tau_min=0.1, decay_rate=0.05)for epoch in range(50): tau = scheduler.get_tau(epoch) logits = model(x) gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits))) y = F.softmax((logits + gumbel_noise) / tau, dim=-1) ...
2.4 梯度回传
Softmax
Forward:
y i = exp ( z i ) ∑ j exp ( z j ) y_i = \\frac{\\exp(z_i)}{\\sum_j \\exp(z_j)} yi=∑jexp(zj)exp(zi)
Backward:
Softmax 的梯度对输入 logits z i z_i zi 的导数是:
∂ y i ∂ z k = y i ( δ i k− y k ) \\frac{\\partial y_i}{\\partial z_k} = y_i(\\delta_{ik} - y_k) ∂zk∂yi=yi(δik−yk)
如果我们和交叉熵 loss L=− ∑ i y i true log y i L = -\\sum_i y_i^{\\text{true}} \\log y_i L=−∑iyitruelogyi 联合使用:
∂ L ∂ z k = y k − y k true \\frac{\\partial L}{\\partial z_k} = y_k - y_k^{\\text{true}} ∂zk∂L=yk−yktrue
Gumbel-Softmax
Gumbel-Softmax 的前向过程是:
y i = exp ( ( z i + g i ) / τ ) ∑ j exp ( ( z j + g j ) / τ ) y_i = \\frac{\\exp((z_i + g_i)/\\tau)}{\\sum_j \\exp((z_j + g_j)/\\tau)} yi=∑jexp((zj+gj)/τ)exp((zi+gi)/τ)
虽然加了随机性,但 g_i 被视为常量(不会回传)!
因此,在反向传播时:
- 梯度仍然是从 softmax 函数计算;
- 只是 softmax 的输入变成了 z ~ i = ( z i + g i ) / τ \\tilde{z}_i = (z_i + g_i)/\\tau z~i=(zi+gi)/τ。
∂ y i ∂ z k = 1 τ ⋅ y i ( δ i k− y k ) \\frac{\\partial y_i}{\\partial z_k} = \\frac{1}{\\tau} \\cdot y_i(\\delta_{ik} - y_k) ∂zk∂yi=τ1⋅yi(δik−yk)
跟 Softmax 几乎一样,但额外多了一个 1 τ \\frac{1}{\\tau} τ1 缩放因子,温度越低,梯度越大。
注意:g_i 不传梯度,只是扰动 forward 的 logits;反向传播只看 z 的梯度。
3. Straight-Through Gumbel-Softmax(硬采样)
- 前向输出 one-hot(使用
argmax
); - 反向仍使用 soft Gumbel 的梯度;
- 类似于
detach
技巧:
y_hard = one_hot(torch.argmax(y, dim=-1))y = (y_hard - y).detach() + y # 前向为硬,反向为软
3.1 作用
问题背景:
我们经常希望模型明确地做选择,比如:
- 哪个 token 属于哪个 group(GroupViT)
- 从 N 个动作中选一个(RL / controller)
- 选择某个 module、某个结构(NAS)
但如果你直接用 argmax:
- 前向可以选;
- 反向就挂了(不可导)。
ST Gumbel-Softmax 的解决方案:
你可以:
- 前向:使用 one-hot(离散采样)
- 反向:仍然使用 Gumbel-Softmax 的软梯度(保持可导)
场景举例
3.2 GroupViT
GroupViT 要做的事是:让每个 patch token 分配给一个明确的 group token
也就是:每个 patch → 只属于一个 group(语义聚类);然后对同组 patch 聚合成一个语义 group。
为什么不能直接 argmax?
- argmax 是硬选择:(想要)
- argmax 不可导:(不能训练)
如果把 Gumbel 换成普通 Softmax 会怎么样?
结果:模型会退化为“软分配”,语义不清晰
具体表现如下:
-
输出是概率分布;
-
每个 patch 会被“部分分配”到所有 group 上;
patch i = 0.4 ⋅ g 1 + 0.3 ⋅ g 2 + 0.3 ⋅ g 3 \\text{patch}_i = 0.4 \\cdot g_1 + 0.3 \\cdot g_2 + 0.3 \\cdot g_3 patchi=0.4⋅g1+0.3⋅g2+0.3⋅g3
-
聚类效果变差,语义边界模糊;
-
训练容易,但分组不清晰、不稳定;
总结句
GroupViT 使用 ST Gumbel-Softmax,是为了同时实现:
- 离散分组(清晰 one-hot)
- 可导训练(不依赖 argmax)
如果换成普通 softmax,会变成“模糊分组”,模型将失去“明确聚类语义区域”的能力,本质上不再是 GroupViT 的设计初衷了。