人工智能概念:常用的模型压缩技术(剪枝、量化、知识蒸馏)
文章目录
一、模型压缩概述
1.1 什么是模型压缩?
模型压缩是一类通过减少模型参数数量、降低计算复杂度,从而在资源受限设备上高效部署深度学习模型的技术。其核心目标是在模型性能损失最小化的前提下,显著减小模型体积、降低内存占用、提升推理速度,以适应移动端、嵌入式设备等资源受限场景的需求。
1.2 为什么需要模型压缩?
随着Transformer等大模型的兴起,模型参数规模呈指数级增长。例如,原始BERT-base模型参数量约110M,推理时不仅占用大量内存,还需要较高的计算资源,难以直接部署在手机、摄像头等边缘设备上。此外,大模型的高推理延迟也无法满足实时性要求较高的业务场景(如实时推荐、语音助手)。
模型压缩的意义在于:
- 降低存储成本:减小模型文件大小,节省存储空间。
- 提升推理速度:减少计算量,降低延迟,满足实时性需求。
- 降低部署门槛:使模型能够在算力有限的边缘设备上运行。
- 减少能耗:降低推理过程中的能量消耗,适合移动设备。
1.3 四种主流模型压缩技术
目前,业界常用的模型压缩技术主要有四类:
二、模型量化:用低精度换高效能
2.1 量化的数学原理
量化的核心是将高精度浮点数(如float32)映射到低精度整数(如int8),核心公式涉及缩放因子(scale) 和零点(zero_point) 的计算。
-
基本定义
- 设浮点数范围为 [ x min , x max ] [x_{\\text{min}}, x_{\\text{max}}] [xmin,xmax],对应整数范围为 [ q min , q max ] [q_{\\text{min}}, q_{\\text{max}}] [qmin,qmax](如int8的 [ − 128 , 127 ] [-128, 127] [−128,127])。
- 缩放因子 s s s:控制浮点数到整数的比例映射。
- 零点 z z z:确保映射的偏移量(使0附近的浮点数能准确映射)。
-
核心公式
s =x max − x min q max − q min (1) s = \\frac{x_{\\text{max}} - x_{\\text{min}}}{q_{\\text{max}} - q_{\\text{min}}} \\tag{1} s=qmax−qminxmax−xmin(1)
z = q min − round ( x min s ) (2) z = q_{\\text{min}} - \\text{round}\\left(\\frac{x_{\\text{min}}}{s}\\right) \\tag{2} z=qmin−round(sxmin)(2)
q = clip ( round ( x s + z ) , q min , q max ) (3) q = \\text{clip}\\left(\\text{round}\\left(\\frac{x}{s} + z\\right), q_{\\text{min}}, q_{\\text{max}}\\right) \\tag{3} q=clip(round(sx+z),qmin,qmax)(3)- 式(1):计算缩放因子,将浮点数范围映射到整数范围。
- 式(2):计算零点,确保 x min x_{\\text{min}} xmin能映射到 q min q_{\\text{min}} qmin。
- 式(3):将浮点数 x x x量化为整数 q q q,并裁剪到整数范围内。
-
反量化公式(推理时还原)
x recon = s ⋅ ( q − z ) (4) x_{\\text{recon}} = s \\cdot (q - z) \\tag{4} xrecon=s⋅(q−z)(4)
2.2 量化计算示例
以float32到int8的量化为例,假设某层权重的浮点数范围为 [−1.2,3.6] [-1.2, 3.6] [−1.2,3.6],计算过程如下:
步骤1:确定范围
- 浮点数: x min = − 1.2 x_{\\text{min}} = -1.2 xmin=−1.2, x max = 3.6 x_{\\text{max}} = 3.6 xmax=3.6
- int8整数: q min = − 128 q_{\\text{min}} = -128 qmin=−128, q max = 127 q_{\\text{max}} = 127 qmax=127,范围长度 127 − ( − 128 ) = 255 127 - (-128) = 255 127−(−128)=255
步骤2:计算缩放因子 s s s
s = 3.6 − ( − 1.2 ) 255 = 4.8 255 ≈ 0.0188 s = \\frac{3.6 - (-1.2)}{255} = \\frac{4.8}{255} \\approx 0.0188 s=2553.6−(−1.2)=2554.8≈0.0188
步骤3:计算零点 z z z
z = − 128 − round ( − 1.2 0.0188 ) = − 128 − round ( − 63.83 ) = − 128 + 64 = − 64 z = -128 - \\text{round}\\left(\\frac{-1.2}{0.0188}\\right) = -128 - \\text{round}(-63.83) = -128 + 64 = -64 z=−128−round(0.0188−1.2)=−128−round(−63.83)=−128+64=−64
步骤4:量化单个浮点数
例如量化 x=0.5 x = 0.5 x=0.5:
q = round ( 0.5 0.0188 + ( − 64 ) ) = round ( 26.59 − 64 ) = round ( − 37.41 ) = − 37 q = \\text{round}\\left(\\frac{0.5}{0.0188} + (-64)\\right) = \\text{round}(26.59 - 64) = \\text{round}(-37.41) = -37 q=round(0.01880.5+(−64))=round(26.59−64)=round(−37.41)=−37
- 量化结果: q = − 37 q = -37 q=−37(在int8范围内)。
步骤5:反量化验证
x recon = 0.0188 × ( − 37 − ( − 64 ) ) = 0.0188 × 27 ≈ 0.5076 ≈ 0.5 x_{\\text{recon}} = 0.0188 \\times (-37 - (-64)) = 0.0188 \\times 27 \\approx 0.5076 \\approx 0.5 xrecon=0.0188×(−37−(−64))=0.0188×27≈0.5076≈0.5
- 误差: ∣ 0.5076 − 0.5 ∣ = 0.0076 |0.5076 - 0.5| = 0.0076 ∣0.5076−0.5∣=0.0076,精度损失较小。
2.3 量化相关API详解
- PyTorch量化API
torch.quantization.quantize_dynamic
model
:待量化模型-
qconfig_spec
:指定需量化的层类型(如{torch.nn.Linear}
)-
dtype
:目标数据类型(如torch.qint8
)torch.quantization.prepare
model
:待准备模型-
qconfig
:量化配置(如torch.quantization.get_default_qconfig(\'fbgemm\')
)torch.quantization.convert
model
:经prepare
处理的模型torch.quantization.QConfig
activation
:激活量化方式(如FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
)-
weight
:权重量化方式- TensorFlow量化API
tf.quantization.quantize
input
:待量化张量-
min_range
/max_range
:输入范围-
T
:目标类型(如tf.int8
)tf.keras.layers.experimental.QuantizationAwareTraining
input_shape
:输入形状-
num_bits
:量化位数- ONNX Runtime量化API
onnxruntime.quantization.quantize_dynamic
input_model
:输入ONNX模型路径-
output_model
:输出量化模型路径-
op_types_to_quantize
:需量化的算子类型(如[\'MatMul\', \'Add\']
)- 量化注意事项
- 动态量化适合CPU端部署,GPU量化建议使用TensorRT的INT8校准工具。
- 量化对模型精度的影响与任务相关:图像分类通常比目标检测更耐量化,文本分类比NER更耐量化。
- 混合精度量化(如部分层用float16,部分用int8)可在精度和速度间取得更好平衡。
三、知识蒸馏:让小模型学会大模型的“智慧”
3.1 知识蒸馏的数学原理
知识蒸馏的核心是通过KL散度衡量学生模型与教师模型的输出差异,结合硬标签损失优化学生模型。
-
软标签生成
教师模型的logits经过温度 T T T调整后生成软标签:
p i = exp ( z i / T ) ∑ j exp ( z j / T ) (5) p_i = \\frac{\\exp(z_i / T)}{\\sum_j \\exp(z_j / T)} \\tag{5} pi=∑jexp(zj/T)exp(zi/T)(5)- z i z_i zi:教师模型对第 i i i类的logits输出。
- T T T:温度参数( T > 1 T>1 T>1使分布更平滑,保留更多知识)。
-
KL散度损失(软标签损失)
衡量学生软标签 q q q与教师软标签 p p p的差异:
L KL = ∑ i p i log ( p i q i ) (6) L_{\\text{KL}} = \\sum_i p_i \\log\\left(\\frac{p_i}{q_i}\\right) \\tag{6} LKL=i∑pilog(qipi)(6)- 当 T = 1 T=1 T=1时,KL散度退化为交叉熵损失。
-
总损失函数
L total = α ⋅ L KL + ( 1 − α ) ⋅ L CE (7) L_{\\text{total}} = \\alpha \\cdot L_{\\text{KL}} + (1-\\alpha) \\cdot L_{\\text{CE}} \\tag{7} Ltotal=α⋅LKL+(1−α)⋅LCE(7)- L CE L_{\\text{CE}} LCE:学生模型与真实标签的交叉熵损失(硬标签损失)。
- α \\alpha α:软标签损失的权重(通常取0.5-0.9)。
3.2 蒸馏计算示例
以三分类任务为例,演示损失计算过程:
步骤1:模型输出
- 教师模型logits: z teacher = [ 3.0 , 1.0 , 0.2 ] z_{\\text{teacher}} = [3.0, 1.0, 0.2] zteacher=[3.0,1.0,0.2]
- 学生模型logits: z student = [ 2.5 , 0.8 , 0.1 ] z_{\\text{student}} = [2.5, 0.8, 0.1] zstudent=[2.5,0.8,0.1]
- 真实标签: y = [ 1 , 0 , 0 ] y = [1, 0, 0] y=[1,0,0](第0类)
步骤2:生成软标签( T = 2.0 T=2.0 T=2.0)
- 教师软标签:
p = [ exp ( 3 / 2 ) ∑ , exp ( 1 / 2 ) ∑ , exp ( 0.2 / 2 ) ∑ ] ≈ [ 0.721 , 0.215 , 0.064 ] p = \\left[ \\frac{\\exp(3/2)}{\\sum}, \\frac{\\exp(1/2)}{\\sum}, \\frac{\\exp(0.2/2)}{\\sum} \\right] \\approx [0.721, 0.215, 0.064]p=[∑exp(3/2),∑exp(1/2),∑exp(0.2/2)]≈[0.721,0.215,0.064] - 学生软标签:
q = [ exp ( 2.5 / 2 ) ∑ , exp ( 0.8 / 2 ) ∑ , exp ( 0.1 / 2 ) ∑ ] ≈ [ 0.659 , 0.257 , 0.084 ] q = \\left[ \\frac{\\exp(2.5/2)}{\\sum}, \\frac{\\exp(0.8/2)}{\\sum}, \\frac{\\exp(0.1/2)}{\\sum} \\right] \\approx [0.659, 0.257, 0.084]q=[∑exp(2.5/2),∑exp(0.8/2),∑exp(0.1/2)]≈[0.659,0.257,0.084]
步骤3:计算损失
- KL散度损失( L KL L_{\\text{KL}} LKL):
L KL = 0.721 log ( 0.721 / 0.659 ) + 0.215 log ( 0.215 / 0.257 ) + 0.064 log ( 0.064 / 0.084 ) ≈ 0.018 L_{\\text{KL}} = 0.721\\log(0.721/0.659) + 0.215\\log(0.215/0.257) + 0.064\\log(0.064/0.084) \\approx 0.018LKL=0.721log(0.721/0.659)+0.215log(0.215/0.257)+0.064log(0.064/0.084)≈0.018 - 硬标签损失( L CE L_{\\text{CE}} LCE):
L CE = − log ( q 0 ) ≈ − log ( 0.659 ) ≈ 0.418 L_{\\text{CE}} = -\\log(q_0) \\approx -\\log(0.659) \\approx 0.418LCE=−log(q0)≈−log(0.659)≈0.418 - 总损失( α = 0.7 \\alpha=0.7 α=0.7):
L total = 0.7 × 0.018 + 0.3 × 0.418 ≈ 0.138 L_{\\text{total}} = 0.7 \\times 0.018 + 0.3 \\times 0.418 \\approx 0.138Ltotal=0.7×0.018+0.3×0.418≈0.138
3.3 知识蒸馏相关API详解
- Hugging Face Transformers API
transformers.Trainer
model
:学生模型-
args
:训练参数(如TrainingArguments
)-
compute_loss
:自定义损失函数(融合KL散度和交叉熵)transformers.DistilBertForSequenceClassification
PreTrainedModel
,可直接加载预训练权重(如distilbert-base-uncased
)- PyTorch蒸馏工具
torch.nn.KLDivLoss
reduction
:损失聚合方式(如\'batchmean\'
)-
log_target
:是否目标为对数形式torch.nn.CrossEntropyLoss
weight
:类别权重-
reduction
:损失聚合方式- 专用蒸馏库
HuggingFace/transformers
中的蒸馏工具from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
knowledge-distillation-pytorch
from kd import KnowledgeDistillationLoss
(融合KL散度和硬标签损失)四、模型剪枝:移除冗余参数,保留核心能力
4.1 剪枝的数学原理
剪枝通过评估参数重要性移除冗余权重,常用L1范数衡量重要性(值越小越冗余)。
-
L1范数重要性评估
对于权重矩阵 W ∈ R m × n W \\in \\mathbb{R}^{m \\times n} W∈Rm×n,单个权重 w i j w_{ij} wij的重要性为:
I ( w i j ) = ∣ w i j ∣ (8) \\mathcal{I}(w_{ij}) = |w_{ij}| \\tag{8} I(wij)=∣wij∣(8) -
全局剪枝阈值计算
若剪枝比例为 r r r,则阈值 θ \\theta θ满足:
∑ i , j I ( ∣ w i j ∣ < θ ) 总参数数 = r (9) \\frac{\\sum_{i,j} \\mathbb{I}(|w_{ij}| < \\theta)}{\\text{总参数数}} = r \\tag{9} 总参数数∑i,jI(∣wij∣<θ)=r(9)- I ( ⋅ ) \\mathbb{I}(\\cdot) I(⋅)为指示函数,满足条件时取1。
4.2 剪枝计算示例
以3x3权重矩阵为例,剪枝30%的参数:
步骤1:原始权重矩阵
W = [ 0.1 − 0.02 0.05 − 0.3 0.01 0.2 0.03 − 0.04 0.08 ] W = \\begin{bmatrix} 0.1 & -0.02 & 0.05 \\\\ -0.3 & 0.01 & 0.2 \\\\ 0.03 & -0.04 & 0.08 \\end{bmatrix} W= 0.1−0.30.03−0.020.01−0.040.050.20.08
步骤2:计算L1范数(重要性)
∣ I ( W ) ∣ = [ 0.1 0.02 0.05 0.3 0.01 0.2 0.03 0.04 0.08 ] |\\mathcal{I}(W)| = \\begin{bmatrix} 0.1 & 0.02 & 0.05 \\\\ 0.3 & 0.01 & 0.2 \\\\ 0.03 & 0.04 & 0.08 \\end{bmatrix} ∣I(W)∣= 0.10.30.030.020.010.040.050.20.08
步骤3:排序并确定阈值
将所有权重按L1范数升序排列: 0.01,0.02,0.03,0.04,0.05,0.08,0.1,0.2,0.3 0.01, 0.02, 0.03, 0.04, 0.05, 0.08, 0.1, 0.2, 0.3 0.01,0.02,0.03,0.04,0.05,0.08,0.1,0.2,0.3
总参数9个,剪枝30%即移除3个参数,阈值 θ=0.03 \\theta=0.03 θ=0.03(第3小的值)。
步骤4:剪枝后矩阵(小于 θ \\theta θ的权重置0)
W pruned = [ 0.1 0 0.05 − 0.3 0 0.2 0 − 0.04 0.08 ] W_{\\text{pruned}} = \\begin{bmatrix} 0.1 & 0 & 0.05 \\\\ -0.3 & 0 & 0.2 \\\\ 0 & -0.04 & 0.08 \\end{bmatrix} Wpruned= 0.1−0.3000−0.040.050.20.08
- 稀疏度:3/9=33.3%(接近目标30%)。
4.3 模型剪枝相关API详解
- PyTorch剪枝API
torch.nn.utils.prune.l1_unstructured
module
:待剪枝模块(如model.bert.encoder.layer[0].attention.self.query
)-
name
:待剪枝参数名(如\'weight\'
)-
amount
:剪枝比例(如0.3
)torch.nn.utils.prune.global_unstructured
parameters
:待剪枝参数列表(如[(module, \'weight\')]
)-
pruning_method
:剪枝方法(如prune.L1Unstructured
)-
amount
:剪枝比例torch.nn.utils.prune.remove
module
:已剪枝模块-
name
:剪枝参数名torch.nn.utils.prune.ln_structured
n
:剪枝维度(如0
表示按输出通道剪枝)-
amount
:剪枝比例-
pruning_method
:重要性评估方法(如\'l1_unstructured\'
)- TensorFlow剪枝API
tfmot.sparsity.keras.prune_low_magnitude
model
:待剪枝模型-
pruning_schedule
:剪枝调度(如PolynomialDecay
)tfmot.sparsity.keras.PolynomialDecay
initial_sparsity
:初始稀疏度-
final_sparsity
:目标稀疏度-
num_steps
:总步数- 第三方剪枝工具
TorchPrune
PruneTorch
4.4 剪枝注意事项
- 非结构化剪枝生成稀疏矩阵,需硬件支持(如NVIDIA的Sparse Tensor Core)才能加速,否则可能变慢。
- 结构化剪枝(如按通道)生成密集矩阵,无需特殊硬件,但剪枝比例过高会导致精度大幅下降。
- 剪枝后需微调模型(fine-tuning),恢复因剪枝丢失的性能(通常微调3-5个epoch即可)。
五、总结
模型压缩技术通过数学原理与工程实现的结合,在精度与效率间取得平衡,其核心API为工业界部署提供了便捷工具:
- 量化:通过
torch.quantization.quantize_dynamic
等API实现低精度转换,适合追求极致部署效率的场景,API使用简单但需注意精度权衡。 - 蒸馏:基于
KLDivLoss
与CrossEntropyLoss
组合,或使用DistilBERT
等预训练蒸馏模型,适合需要保留高精度的小模型场景。 - 剪枝:通过
global_unstructured
等API移除冗余参数,适合对模型大小敏感且可接受一定部署复杂度的场景。
实际应用中,可组合多种技术(如“剪枝+量化”)进一步提升压缩效果,例如先剪枝移除30%冗余参数,再量化为int8,可在精度损失5%以内实现模型体积缩减80%以上,推动大模型在边缘设备的落地。