> 技术文档 > AI大模型基础:预训练与微调(迁移学习与微调策略)

AI大模型基础:预训练与微调(迁移学习与微调策略)

AI大模型基础:预训练与微调(迁移学习与微调策略)

🧑 博主简介:CSDN博客专家、CSDN平台优质创作者,高级开发工程师,数学专业,10年以上C/C++, C#,Java等多种编程语言开发经验,拥有高级工程师证书;擅长C/C++、C#等开发语言,熟悉Java常用开发技术,能熟练应用常用数据库SQL server,Oracle,mysql,postgresql等进行开发应用,熟悉DICOM医学影像及DICOM协议,业余时间自学JavaScript,Vue,qt,python等,具备多种混合语言开发能力。撰写博客分享知识,致力于帮助编程爱好者共同进步。欢迎关注、交流及合作,提供技术支持与解决方案。\\n技术合作请加本人wx(注明来自csdn):xt20160813


AI大模型基础:预训练与微调(迁移学习与微调策略)

预训练与微调是现代AI大模型(如BERT、GPT、ViT)的核心技术,基于迁移学习范式,通过在大规模数据集上预训练模型并在特定任务上微调,显著提升性能和效率。本文将深入讲解预训练与微调的原理、实现方法及在实际场景中的应用,适合对AI大模型感兴趣的读者。文章结构如下:

  1. 预训练与微调概述:定义、迁移学习背景及在NLP和CV中的重要性。
  2. 预训练原理与实现:任务设计、数据需求及代码示例。
  3. 微调策略:全参数微调、部分微调、参数高效微调(PEFT)等。
  4. 应用案例:NLP(文本分类)、CV(图像分类)及医学影像领域的实现。
  5. 流程图与图表:提供Mermaid流程图及性能对比图表。
  6. 总结与展望:总结关键点及未来发展趋势。

一、预训练与微调概述

1.1 定义与目标

  • 预训练:在大规模、无标注或弱标注数据上训练模型,学习通用特征(如语言模式、视觉特征),构建强大的初始参数。
  • 微调:在特定下游任务(如文本分类、图像分割)上调整预训练模型参数,适配任务需求。
  • 目标
    • 泛化能力:预训练捕捉通用知识,微调适配特定场景。
    • 高效性:减少从头训练成本,适合小数据集。
    • 高性能:提升任务精度,尤其在数据有限的领域(如医学影像)。

1.2 迁移学习背景

  • 迁移学习:将从一个任务(源任务)学到的知识应用于另一个任务(目标任务)。预训练与微调是迁移学习的典型实现。
  • 发展历程
    • 早期:特征提取(如SIFT、HOG)+简单分类器(如SVM)。
    • 深度学习时代:CNN(如VGG、ResNet)预训练于ImageNet,微调下游任务。
    • Transformer时代:BERT、GPT、ViT通过自监督预训练,革新NLP和CV。
  • 优势
    • 利用大规模数据(如Wikipedia、ImageNet)学习通用表示。
    • 小数据集也能实现高性能,适合医学影像等场景。

1.3 重要性

  • NLP:预训练(如BERT的MLM)捕获语义,微调适配情感分析、问答等任务。
  • CV:ViT预训练于ImageNet,微调用于肿瘤检测、器官分割。
  • 医学影像:数据稀缺,预训练模型(如ViT)通过迁移学习显著提升分类精度。

1.4 挑战

  • 计算成本:预训练需大量GPU/TPU资源。
  • 过拟合风险:微调时小数据集可能导致过拟合。
  • 任务适配:不同任务需不同微调策略(如全参数 vs. 部分微调)。
  • 可解释性:微调后模型行为难以解释,医学领域需谨慎。

二、预训练原理与实现

2.1 原理

预训练通过自监督或弱监督任务,在大规模数据上学习通用表示,无需特定任务的标注数据。

核心机制
  • 自监督学习(SSL)
    • NLP:掩码语言模型(MLM,如BERT)、自回归语言建模(如GPT)。
    • CV:图像分类(如ViT)、对比学习(如SimCLR)。
  • 数据规模:需大规模语料(如Wikipedia、BooksCorpus)或图像数据集(如ImageNet、JFT-300M)。
  • 模型结构
    • Transformer:基于自注意力机制,捕获全局依赖。
    • 多层设计:如BERT的12/24层,ViT的Patch嵌入+编码器。
预训练任务
  1. 掩码语言模型(MLM)
    • 随机掩盖输入词(15%),预测被掩盖词,学习双向上下文(BERT)。
    • 公式:最大化条件概率 P(wi∣w1:i−1,wi+1:n)P(w_i | w_{1:i-1}, w_{i+1:n})P(wiw1:i1,wi+1:n)
  2. 自回归语言建模
    • 预测下一个词,基于前文(GPT)。
    • 公式:最大化 P(wt∣w1:t−1)P(w_t | w_{1:t-1})P(wtw1:t1)
  3. 图像分类
    • ViT在ImageNet上预测类别,学习视觉特征。
    • 公式:最小化交叉熵损失L=−∑yilog⁡(y^i)L = -\\sum y_i \\log(\\hat{y}_i)L=yilog(y^i)
  4. 对比学习
    • 增强图像(如旋转、裁剪),使相同图像的表示靠近(SimCLR)。
数学基础
  • 自注意力
    Attention(Q,K,V)=softmax(QKTdk)V\\text{Attention}(Q, K, V) = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)VAttention(Q,K,V)=softmax(dkQKT)V
    其中 Q,K,VQ, K, VQ,K,V为查询、键、值向量,dkd_kdk 为键维度。
  • 损失函数
    • MLM:交叉熵损失,预测掩盖词。
    • 自回归:最大化序列似然。
    • 分类:交叉熵或对比损失。
优缺点
  • 优点:学习通用表示,减少下游任务数据需求。
  • 缺点:计算成本高,需大规模数据支持。
  • 适用场景:NLP(文本理解/生成)、CV(图像分类/分割)。

2.2 实现示例(Python)

以下以BERT的MLM预训练为例,使用Hugging Face模拟小规模预训练:

from transformers import BertTokenizer, BertForMaskedLMfrom transformers import DataCollatorForLanguageModeling, Trainer, TrainingArgumentsfrom datasets import load_datasetimport torch# 加载数据(示例:WikiText)dataset = load_dataset(\'wikitext\', \'wikitext-2-raw-v1\', split=\'train\')texts = [text for text in dataset[\'text\'] if len(text) > 0][:1000] # 取前1000条# 加载分词器和模型tokenizer = BertTokenizer.from_pretrained(\'bert-base-uncased\')model = BertForMaskedLM.from_pretrained(\'bert-base-uncased\')# 预处理数据def tokenize_function(examples): return tokenizer(examples[\'text\'], padding=\'max_length\', truncation=True, max_length=128)tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=[\'text\'])data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)# 设置训练参数training_args = TrainingArguments( output_dir=\'./pretrain_results\', num_train_epochs=1, per_device_train_batch_size=8, logging_steps=100, save_steps=500)# 训练模型trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=data_collator)trainer.train()

代码注释

  • load_dataset:加载WikiText数据集,模拟大规模语料。
  • BertTokenizer:将文本转换为输入ID和注意力掩码。
  • BertForMaskedLM:BERT模型,带MLM头。
  • DataCollatorForLanguageModeling:动态掩盖15%输入词,生成MLM任务。
  • Trainer:Hugging Face训练接口,简化预训练流程。
  • 注意:实际预训练需更大数据集(如Wikipedia)和更多计算资源。

三、微调策略

3.1 原理

微调通过调整预训练模型参数,适配特定下游任务,分为以下策略:

1. 全参数微调
  • 调整模型所有参数,适合数据充足场景。
  • 优点:充分利用预训练知识,性能最佳。
  • 缺点:计算成本高,易过拟合(小数据集)。
2. 部分微调
  • 冻结部分层(如低层编码器),仅微调顶层或分类头。
  • 优点:降低计算成本,减少过拟合风险。
  • 缺点:可能损失部分预训练知识。
3. 参数高效微调(PEFT)
  • 方法
    • LoRA(Low-Rank Adaptation):在权重矩阵上添加低秩更新,调整少量参数。
    • Prompt Tuning:添加可训练的提示向量,冻结模型参数。
    • Adapter:在每层插入小型适配器模块。
  • 公式(LoRA)
    W=W0+ΔW,ΔW=BAW = W_0 + \\Delta W, \\quad \\Delta W = BAW=W0+ΔW,ΔW=BA
    其中 W0W_0W0为预训练权重,ΔW\\Delta WΔW为低秩更新,B,AB, AB,A为小矩阵。
  • 优点:参数量少(<1%),适合资源受限场景。
  • 缺点:性能略低于全参数微调。
优缺点
  • 优点:灵活适配任务,降低训练成本。
  • 缺点:需根据任务选择策略,调试复杂。
  • 适用场景:小数据集(如医学影像)、资源受限环境。

3.2 实现示例(Python)

以下以BERT全参数微调(文本分类)和LoRA微调为例:

from transformers import BertTokenizer, BertForSequenceClassificationfrom peft import LoraConfig, get_peft_modelfrom transformers import Trainer, TrainingArgumentsimport torchfrom torch.utils.data import Dataset# 自定义数据集class TextDataset(Dataset): def __init__(self, texts, labels, tokenizer, max_len=128): self.texts = texts self.labels = labels self.tokenizer = tokenizer self.max_len = max_len def __len__(self): return len(self.texts) def __getitem__(self, idx): text = str(self.texts[idx]) label = self.labels[idx] encoding = self.tokenizer( text, add_special_tokens=True, max_length=self.max_len, padding=\'max_length\', truncation=True, return_tensors=\'pt\' ) return { \'input_ids\': encoding[\'input_ids\'].flatten(), \'attention_mask\': encoding[\'attention_mask\'].flatten(), \'labels\': torch.tensor(label, dtype=torch.long) }# 数据准备(示例:情感分析)texts = [\"I love this movie!\", \"This movie is terrible.\"]labels = [1, 0] # 1: 正向,0: 负向tokenizer = BertTokenizer.from_pretrained(\'bert-base-uncased\')dataset = TextDataset(texts, labels, tokenizer)# 全参数微调model_full = BertForSequenceClassification.from_pretrained(\'bert-base-uncased\', num_labels=2)training_args_full = TrainingArguments( output_dir=\'./full_finetune_results\', num_train_epochs=3, per_device_train_batch_size=8, logging_steps=10, save_steps=100)trainer_full = Trainer(model=model_full, args=training_args_full, train_dataset=dataset)trainer_full.train()# LoRA微调model_lora = BertForSequenceClassification.from_pretrained(\'bert-base-uncased\', num_labels=2)lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=[\"query\", \"value\"])model_lora = get_peft_model(model_lora, lora_config)training_args_lora = TrainingArguments( output_dir=\'./lora_finetune_results\', num_train_epochs=3, per_device_train_batch_size=8, logging_steps=10, save_steps=100)trainer_lora = Trainer(model=model_lora, args=training_args_lora, train_dataset=dataset)trainer_lora.train()# 推理(以全参数微调为例)model_full.eval()text = \"This is a great film!\"inputs = tokenizer(text, return_tensors=\'pt\', max_length=128, padding=True, truncation=True)outputs = model_full(**inputs)predictions = torch.argmax(outputs.logits, dim=-1)print(\"全参数微调预测:\", \"正向\" if predictions.item() == 1 else \"负向\")

代码注释

  • TextDataset:处理文本和标签,适配BERT输入格式。
  • BertForSequenceClassification:预训练BERT,添加分类头(2类)。
  • LoraConfig:配置LoRA参数,r=8控制低秩矩阵维度,target_modules指定微调自注意力层。
  • get_peft_model:应用LoRA,冻结大部分参数,仅训练低秩更新。
  • Trainer:Hugging Face接口,简化微调流程。
  • 注意:LoRA显著减少参数量(约0.1%),适合小数据集或低资源场景。

四、应用案例

4.1 NLP:文本分类

  • 任务:情感分析(正向/负向),如电影评论分类。
  • 预训练模型:BERT,学习通用语义表示。
  • 微调策略
    • 全参数微调:适配大数据集(如IMDB)。
    • LoRA:适合小数据集(如医疗报告情感分析)。
  • 代码:见3.2实现,微调BERT进行二分类。

4.2 CV:图像分类

  • 任务:肿瘤检测(如乳腺癌X光片分类)。
  • 预训练模型:ViT,预训练于ImageNet。
  • 微调策略
    • 全参数微调:适配大型医学影像数据集。
    • Adapter:添加小型适配器,适合小数据集。
  • 实现示例
from transformers import ViTImageProcessor, ViTForImageClassificationfrom peft import LoraConfig, get_peft_modelfrom PIL import Imageimport torch# 加载数据(示例:医学影像)image = Image.open(\"tumor_image.jpg\").convert(\'RGB\')labels = [1] # 1: 恶性,0: 良性# 预处理processor = ViTImageProcessor.from_pretrained(\'google/vit-base-patch16-224\')inputs = processor(images=image, return_tensors=\'pt\')# LoRA微调model = ViTForImageClassification.from_pretrained(\'google/vit-base-patch16-224\', num_labels=2)lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=[\"query\", \"value\"])model = get_peft_model(model, lora_config)# 训练(假设数据集)# 类似TextDataset,省略训练代码# 推理model.eval()outputs = model(**inputs)predictions = torch.argmax(outputs.logits, dim=-1)print(\"预测结果:\", \"恶性\" if predictions.item() == 1 else \"良性\")

代码注释

  • ViTImageProcessor:处理图像,分块并归一化。
  • ViTForImageClassification:ViT模型,带分类头。
  • LoraConfig:应用LoRA微调,减少参数量。

4.3 医学影像领域

  • 任务:肺癌CT分类、脑部MRI分割。
  • 挑战:样本少、类别不平衡、误诊成本高。
  • 解决方案
    • 预训练:ViT在ImageNet或医学影像数据集(如LIDC-IDRI)上预训练。
    • 微调:使用LoRA或Adapter,适配小数据集,关注召回率(减少漏诊)。
  • 优势:迁移学习利用通用视觉特征,提升小数据集性能。

五、流程图与图表

5.1 预训练与微调流程图

以下是流程图:

#mermaid-svg-64uEgUVwrPnLUSoT {font-family:\"trebuchet ms\",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-64uEgUVwrPnLUSoT .error-icon{fill:#552222;}#mermaid-svg-64uEgUVwrPnLUSoT .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-64uEgUVwrPnLUSoT .edge-thickness-normal{stroke-width:2px;}#mermaid-svg-64uEgUVwrPnLUSoT .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-64uEgUVwrPnLUSoT .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-64uEgUVwrPnLUSoT .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-64uEgUVwrPnLUSoT .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-64uEgUVwrPnLUSoT .marker{fill:#333333;stroke:#333333;}#mermaid-svg-64uEgUVwrPnLUSoT .marker.cross{stroke:#333333;}#mermaid-svg-64uEgUVwrPnLUSoT svg{font-family:\"trebuchet ms\",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-64uEgUVwrPnLUSoT .label{font-family:\"trebuchet ms\",verdana,arial,sans-serif;color:#333;}#mermaid-svg-64uEgUVwrPnLUSoT .cluster-label text{fill:#333;}#mermaid-svg-64uEgUVwrPnLUSoT .cluster-label span{color:#333;}#mermaid-svg-64uEgUVwrPnLUSoT .label text,#mermaid-svg-64uEgUVwrPnLUSoT span{fill:#333;color:#333;}#mermaid-svg-64uEgUVwrPnLUSoT .node rect,#mermaid-svg-64uEgUVwrPnLUSoT .node circle,#mermaid-svg-64uEgUVwrPnLUSoT .node ellipse,#mermaid-svg-64uEgUVwrPnLUSoT .node polygon,#mermaid-svg-64uEgUVwrPnLUSoT .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-64uEgUVwrPnLUSoT .node .label{text-align:center;}#mermaid-svg-64uEgUVwrPnLUSoT .node.clickable{cursor:pointer;}#mermaid-svg-64uEgUVwrPnLUSoT .arrowheadPath{fill:#333333;}#mermaid-svg-64uEgUVwrPnLUSoT .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-64uEgUVwrPnLUSoT .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-64uEgUVwrPnLUSoT .edgeLabel{background-color:#e8e8e8;text-align:center;}#mermaid-svg-64uEgUVwrPnLUSoT .edgeLabel rect{opacity:0.5;background-color:#e8e8e8;fill:#e8e8e8;}#mermaid-svg-64uEgUVwrPnLUSoT .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-64uEgUVwrPnLUSoT .cluster text{fill:#333;}#mermaid-svg-64uEgUVwrPnLUSoT .cluster span{color:#333;}#mermaid-svg-64uEgUVwrPnLUSoT div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:\"trebuchet ms\",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-64uEgUVwrPnLUSoT :root{--mermaid-font-family:\"trebuchet ms\",verdana,arial,sans-serif;}数据准备: 大规模语料/图像预训练: 自监督任务(MLM/分类)预训练模型: BERT/GPT/ViT微调: 全参数/LoRA/Adapter下游任务: 分类/生成/分割输出: 预测/生成结果

说明

  • A(数据准备):大规模文本(Wikipedia)或图像(ImageNet)。
  • B(预训练):MLM(BERT)、自回归(GPT)、分类(ViT)。
  • C(预训练模型):生成通用表示的模型。
  • D(微调):全参数、LoRA或Adapter适配任务。
  • E(下游任务):分类(情感分析、肿瘤检测)、生成(对话)、分割(器官)。
  • F(输出):任务特定结果,如类别标签或生成文本。

5.2 图表:微调策略性能对比

以下为全参数微调与LoRA在分类任务上的性能对比折线图(假设数据)。
AI大模型基础:预训练与微调(迁移学习与微调策略)

说明

  • 图表类型:折线图,比较全参数微调与LoRA的准确率。
  • X轴:数据集大小(100、1000、10000样本)。
  • Y轴:准确率,范围0.7-1.0。
  • 数据:假设数据,显示全参数微调在大数据集上更优,LoRA在小数据集上接近。
  • 医学意义:LoRA适合医学影像小数据集,平衡性能与资源。

六、总结与展望

6.1 总结

  • 预训练:通过自监督任务(如MLM、自回归)学习通用表示,降低下游任务数据需求。
  • 微调
    • 全参数微调:适合大数据集,性能最佳。
    • 部分微调/PEFT(如LoRA):适合小数据集或低资源场景。
  • 应用
    • NLP:情感分析、医学报告分类。
    • CV:肿瘤检测、器官分割。
    • capped at 100 samples for demonstration; actual pretraining requires much larger datasets.

6.2 展望

  • 高效预训练:探索更高效的自监督任务(如MAE for ViT),减少数据需求。
  • 自动化微调:开发自动化微调框架,动态选择最佳策略(如AutoML)。
  • 多模态迁移:结合文本和图像预训练,适配医学影像+报告任务。
  • 可解释性:结合SHAP或注意力可视化,解释微调后模型行为。