AI大模型基础实践:微调BERT模型处理医学文本数据:原理、实现与应用
🧑 博主简介:CSDN博客专家、CSDN平台优质创作者,高级开发工程师,数学专业,10年以上C/C++, C#,Java等多种编程语言开发经验,拥有高级工程师证书;擅长C/C++、C#等开发语言,熟悉Java常用开发技术,能熟练应用常用数据库SQL server,Oracle,mysql,postgresql等进行开发应用,熟悉DICOM医学影像及DICOM协议,业余时间自学JavaScript,Vue,qt,python等,具备多种混合语言开发能力。撰写博客分享知识,致力于帮助编程爱好者共同进步。欢迎关注、交流及合作,提供技术支持与解决方案。\\n技术合作请加本人wx(注明来自csdn):xt20160813
微调BERT模型处理医学文本数据:原理、实现与应用
本文深入探讨如何使用Hugging Face的Transformers库微调BERT模型,处理医学影像中的文本数据(如DICOM元数据的自然语言处理),并结合医学影像分类任务(如肺结节检测、乳腺癌诊断、脑肿瘤分类),实现从文本到诊断的端到端辅助系统。内容涵盖BERT的原理、预处理流程、模型微调、推理与评估,辅以详细的Python代码、流程图。本文特别关注医学文本的挑战(如专业术语、数据稀缺),提出优化策略,并结合决策树和随机森林增强可解释性,展望多模态融合应用,适合自然语言处理(NLP)和医学影像领域的从业者。
一、前言摘要
医学影像分析不仅是图像处理,还涉及DICOM元数据中的文本信息(如放射学报告、患者病史),这些文本蕴含丰富的诊断线索。BERT(Bidirectional Encoder Representations from Transformers)作为一种强大的预训练语言模型,凭借其双向上下文理解能力,广泛应用于医学文本分类、实体识别和信息提取。本文基于Hugging Face的Transformers库,系统讲解如何微调BERT处理DICOM元数据,实现医学文本分类(如疾病诊断、报告结构化)。结合决策树和随机森林,增强模型的可解释性和临床适用性。内容涵盖数据预处理、特征提取、模型训练、评估与优化,辅以详细代码、解决医学文本的挑战(如术语复杂性、数据不平衡)。本文旨在为研究者和开发者提供理论与实践的全面参考,推动AI在医学领域的落地。
二、项目概述
2.1 项目目标
- 功能:微调BERT模型,从DICOM元数据或放射学报告中提取诊断信息,分类疾病(如肺结节良恶性、乳腺癌类型)。
- 意义:自动处理医学文本,辅助放射科医生,提高诊断效率,减少人工解读负担。
- 目标:
- 掌握BERT的预处理、微调和推理流程。
- 实现高精度和高召回率的文本分类,优先减少漏诊。
- 结合决策树/随机森林,提升模型可解释性。
- 比较BERT与传统机器学习方法在医学文本任务中的性能。
2.2 数据集
- MIMIC-CXR:
- 包含377,110张胸部X光影像及其DICOM元数据和放射学报告。
- 文本数据:自由文本报告,包含诊断、病史和建议。
- 挑战:术语复杂、文本长度不一、标注噪声。
- LUNA16(元数据):
- 包含CT扫描的DICOM元数据,记录扫描参数、患者信息和初步诊断。
- 挑战:元数据结构化程度高,但诊断信息有限。
- i2b2(临床文本):
- 包含去标识化的临床记录,标注疾病类别(如癌症相关)。
- 挑战:小数据集,需迁移学习。
- 数据挑战:
- 术语复杂性:医学报告包含专业术语(如“结节”“浸润”),需领域适配。
- 数据稀缺:标注文本有限,需预训练模型。
- 类不平衡:恶性病例少,需加权损失或过采样。
2.3 技术栈
- Hugging Face Transformers:加载预训练BERT,简化微调流程。
- PyTorch:深度学习框架,支持BERT和决策树实现。
- pydicom:解析DICOM元数据,提取文本字段。
- scikit-learn:实现决策树/随机森林,评估指标。
- NLTK/Spacy:医学文本预处理(分词、实体识别)。
- Matplotlib/Chart.js:可视化性能(混淆矩阵、ROC曲线)。
2.4 医学文本处理挑战
- 术语复杂性:需领域特定预训练(如BioBERT、ClinicalBERT)。
- 文本多样性:报告长度从几十到上千词,需截断或分块。
- 可解释性:医生需理解模型预测依据,决策树可提供支持。
- 高召回需求:漏诊成本高,需优化召回率。
三、原理
3.1 BERT(Bidirectional Encoder Representations from Transformers)
BERT是基于Transformer编码器的预训练语言模型,通过双向上下文建模,捕获复杂的语义关系,适合医学文本分类。
3.1.1 架构
- 输入表示:
- 文本分词为Token(使用WordPiece分词器)。
- 每个Token转换为嵌入向量:
Input Embedding=Token Embedding+Position Embedding+Segment Embedding\\text{Input Embedding} = \\text{Token Embedding} + \\text{Position Embedding} + \\text{Segment Embedding}Input Embedding=Token Embedding+Position Embedding+Segment Embedding - 特殊Token:
[CLS]
:分类任务的聚合表示。[SEP]
:分隔句子或文本段。
- Transformer编码器:
- 多层自注意力机制和前馈网络:
zl′=MultiHeadAttention(LN(zl−1))+zl−1z\'_l = \\text{MultiHeadAttention}(\\text{LN}(z_{l-1})) + z_{l-1}zl′=MultiHeadAttention(LN(zl−1))+zl−1
zl=FFN(LN(zl′))+zl′z_l = \\text{FFN}(\\text{LN}(z\'_l)) + z\'_lzl=FFN(LN(zl′))+zl′
其中,LN为Layer Normalization,FFN为前馈网络。 - 自注意力机制:
Attention(Q,K,V)=softmax(QKTdk)V\\text{Attention}(Q, K, V) = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)VAttention(Q,K,V)=softmax(dkQKT)V
其中,Q,K,VQ, K, VQ,K,V为查询、键和值向量,dkd_kdk为维度。
- 多层自注意力机制和前馈网络:
- 分类头:
- 使用
[CLS]
Token的输出,经过全连接层和Softmax:
y=softmax(WzL[CLS]+b)y = \\text{softmax}(W z_L^{\\text{[CLS]}} + b)y=softmax(WzL[CLS]+b)
其中,zL[CLS]z_L^{\\text{[CLS]}}zL[CLS]为最后一层CLS Token。
- 使用
3.1.2 预训练与微调
- 预训练:
- 掩码语言模型(MLM):随机掩盖15%的Token,预测原始词。
- 下一句预测(NSP):判断两句话是否连续。
- 数据:BooksCorpus和Wikipedia(通用BERT),PubMed(BioBERT)。
- 微调:
- 替换分类头,适配下游任务(如二分类:良性/恶性)。
- 调整参数:全参数微调或LoRA(低秩适配)。
- 医学适用性:
- BioBERT/ClinicalBERT在PubMed或MIMIC数据上预训练,适配医学术语。
- 适合放射学报告分类、实体识别和信息提取。
3.2 决策树与随机森林
决策树和随机森林在医学文本处理中可结合BERT提取的特征,增强可解释性。
3.2.1 决策树原理
- 结构:树状模型,通过特征阈值递归分割数据。
- 分裂准则:Gini指数或信息增益:
Gini=1−∑i=1Cpi2\\text{Gini} = 1 - \\sum_{i=1}^C p_i^2Gini=1−i=1∑Cpi2
Information Gain=H(parent)−∑childNchildNH(child)\\text{Information Gain} = H(\\text{parent}) - \\sum_{child} \\frac{N_{child}}{N} H(\\text{child})Information Gain=H(parent)−child∑NNchildH(child) - 适用性:适合BERT输出的高维特征向量。
3.2.2 随机森林
- 原理:集成多棵决策树,通过投票或平均输出结果。
- 优势:
- 减少过拟合:随机特征选择和Bagging。
- 可解释性:特征重要性分析。
- 数学基础:
- 特征重要性:
Importance(f)=∑nodeΔGini(f,node)\\text{Importance}(f) = \\sum_{\\text{node}} \\Delta \\text{Gini}(f, \\text{node})Importance(f)=node∑ΔGini(f,node)
- 特征重要性:
3.2.3 医学文本中的应用
- 特征提取:使用BERT的
[CLS]
Token输出作为特征。 - 分类:结合深度特征,提升鲁棒性和可解释性。
3.3 迁移学习与LoRA
- 预训练:
- BERT:BooksCorpus+Wikipedia。
- BioBERT:PubMed文献,适配医学领域。
- ClinicalBERT:MIMIC-III临床记录。
- 微调:
- 全参数微调:更新所有参数,适合大数据集。
- LoRA:仅更新低秩矩阵:
W=W0+ΔW,ΔW=BA,B∈Rd×r,A∈Rr×kW = W_0 + \\Delta W, \\quad \\Delta W = BA, \\quad B \\in \\mathbb{R}^{d \\times r}, A \\in \\mathbb{R}^{r \\times k}W=W0+ΔW,ΔW=BA,B∈Rd×r,A∈Rr×k
其中,rrr为低秩参数。
- 优势:LoRA降低计算成本,适配医学小数据集。
3.4 评估指标
- 混淆矩阵:计算真阳性(TP)、假阳性(FP)、真阴性(TN)、假阴性(FN)。
- 指标:
- 准确率:Accuracy=TP+TNTP+TN+FP+FN\\text{Accuracy} = \\frac{TP+TN}{TP+TN+FP+FN}Accuracy=TP+TN+FP+FNTP+TN
- 精确率:Precision=TPTP+FP\\text{Precision} = \\frac{TP}{TP+FP}Precision=TP+FPTP
- 召回率:Recall=TPTP+FN\\text{Recall} = \\frac{TP}{TP+FN}Recall=TP+FNTP
- F1分数:F1=2⋅Precision⋅RecallPrecision+Recall\\text{F1} = 2 \\cdot \\frac{\\text{Precision} \\cdot \\text{Recall}}{\\text{Precision} + \\text{Recall}}F1=2⋅Precision+RecallPrecision⋅Recall
- ROC曲线与AUC:量化区分能力。
四、数据预处理
4.1 预处理流程
针对DICOM元数据和放射学报告,预处理包括以下步骤:
- 提取文本:
- DICOM:提取
StudyDescription
、SeriesDescription
等字段。 - 放射学报告:从MIMIC-CXR提取自由文本。
- DICOM:提取
- 文本清洗:
- 去除标点、数字、特殊字符。
- 统一大小写,处理拼写错误。
- 分词与编码:
- 使用BERT分词器(WordPiece),将文本转为Token ID。
- 截断或填充到固定长度(如512)。
- 数据增强:
- 同义词替换、随机插入,增加数据多样性。
- 使用
nlpaug
实现医学文本增强。
- 数据集划分:
- 80%训练,10%验证,10%测试,分层采样确保类平衡。
4.2 流程图
以下为医学文本预处理的Mermaid流程图:
#mermaid-svg-IP0fqWc13Wb8e7FR {font-family:\"trebuchet ms\",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-IP0fqWc13Wb8e7FR .error-icon{fill:#552222;}#mermaid-svg-IP0fqWc13Wb8e7FR .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-IP0fqWc13Wb8e7FR .edge-thickness-normal{stroke-width:2px;}#mermaid-svg-IP0fqWc13Wb8e7FR .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-IP0fqWc13Wb8e7FR .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-IP0fqWc13Wb8e7FR .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-IP0fqWc13Wb8e7FR .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-IP0fqWc13Wb8e7FR .marker{fill:#333333;stroke:#333333;}#mermaid-svg-IP0fqWc13Wb8e7FR .marker.cross{stroke:#333333;}#mermaid-svg-IP0fqWc13Wb8e7FR svg{font-family:\"trebuchet ms\",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-IP0fqWc13Wb8e7FR .label{font-family:\"trebuchet ms\",verdana,arial,sans-serif;color:#333;}#mermaid-svg-IP0fqWc13Wb8e7FR .cluster-label text{fill:#333;}#mermaid-svg-IP0fqWc13Wb8e7FR .cluster-label span{color:#333;}#mermaid-svg-IP0fqWc13Wb8e7FR .label text,#mermaid-svg-IP0fqWc13Wb8e7FR span{fill:#333;color:#333;}#mermaid-svg-IP0fqWc13Wb8e7FR .node rect,#mermaid-svg-IP0fqWc13Wb8e7FR .node circle,#mermaid-svg-IP0fqWc13Wb8e7FR .node ellipse,#mermaid-svg-IP0fqWc13Wb8e7FR .node polygon,#mermaid-svg-IP0fqWc13Wb8e7FR .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-IP0fqWc13Wb8e7FR .node .label{text-align:center;}#mermaid-svg-IP0fqWc13Wb8e7FR .node.clickable{cursor:pointer;}#mermaid-svg-IP0fqWc13Wb8e7FR .arrowheadPath{fill:#333333;}#mermaid-svg-IP0fqWc13Wb8e7FR .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-IP0fqWc13Wb8e7FR .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-IP0fqWc13Wb8e7FR .edgeLabel{background-color:#e8e8e8;text-align:center;}#mermaid-svg-IP0fqWc13Wb8e7FR .edgeLabel rect{opacity:0.5;background-color:#e8e8e8;fill:#e8e8e8;}#mermaid-svg-IP0fqWc13Wb8e7FR .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-IP0fqWc13Wb8e7FR .cluster text{fill:#333;}#mermaid-svg-IP0fqWc13Wb8e7FR .cluster span{color:#333;}#mermaid-svg-IP0fqWc13Wb8e7FR div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:\"trebuchet ms\",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-IP0fqWc13Wb8e7FR :root{--mermaid-font-family:\"trebuchet ms\",verdana,arial,sans-serif;}原始文本: DICOM元数据/放射学报告提取文本: pydicom解析DICOM文本清洗: 去除标点, 统一大小写分词与编码: BERT WordPiece数据增强: 同义词替换, 随机插入标准化: 截断/填充到512数据集划分: 训练, 验证, 测试
说明:
- A:DICOM元数据或MIMIC-CXR报告。
- B:使用
pydicom
提取文本字段。 - C:清洗噪声(如标点、拼写错误)。
- D:BERT分词器转为Token ID。
- E:增强数据多样性。
- F:统一长度,适配BERT输入。
- G:分层划分,确保类平衡。
4.3 代码实现
以下为MIMIC-CXR报告的预处理代码:
import pydicomimport pandas as pdimport numpy as npfrom transformers import BertTokenizerfrom torch.utils.data import Datasetimport nlpaug.augmenter.word as nawimport torch# 提取DICOM元数据def extract_dicom_text(dicom_path): ds = pydicom.dcmread(dicom_path) fields = [\'StudyDescription\', \'SeriesDescription\', \'PatientComments\'] text = \' \'.join([str(getattr(ds, field, \'\')) for field in fields if getattr(ds, field, \'\')]) return text# 文本清洗def clean_text(text): import re text = re.sub(r\'[^\\w\\s]\', \'\', text.lower()) # 去除标点,统一小写 text = re.sub(r\'\\s+\', \' \', text).strip() # 去除多余空格 return text# 数据增强aug = naw.SynonymAug(aug_src=\'wordnet\', aug_p=0.3)# 自定义数据集class MedicalTextDataset(Dataset): def __init__(self, dicom_dir, annotations_file, tokenizer_name=\'bert-base-uncased\', max_length=512): self.dicom_dir = dicom_dir self.annotations = pd.read_csv(annotations_file) self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name) self.max_length = max_length def __len__(self): return len(self.annotations) def __getitem__(self, idx): # 提取文本 dicom_path = os.path.join(self.dicom_dir, self.annotations.iloc[idx][\'dicom_id\']) text = extract_dicom_text(dicom_path) text = clean_text(text) # 数据增强 text = aug.augment(text)[0] # 分词与编码 encoding = self.tokenizer( text, max_length=self.max_length, padding=\'max_length\', truncation=True, return_tensors=\'pt\' ) label = self.annotations.iloc[idx][\'label\'] # 0: 良性,1: 恶性 return { \'input_ids\': encoding[\'input_ids\'].squeeze(), \'attention_mask\': encoding[\'attention_mask\'].squeeze(), \'label\': torch.tensor(label, dtype=torch.long) }# 加载数据集dataset = MedicalTextDataset( dicom_dir=\'path/to/mimic-cxr\', annotations_file=\'annotations.csv\', tokenizer_name=\'dmis-lab/biobert-v1.1\')
代码注释:
extract_dicom_text
:从DICOM提取描述字段,拼接为文本。clean_text
:去除标点、空格,统一格式。SynonymAug
:使用WordNet替换同义词,增强数据。BertTokenizer
:BioBERT分词器,适配医学术语。max_length=512
:BERT最大输入长度,截断或填充。
五、模型实现
5.1 BERT实现(Hugging Face)
使用Hugging Face的transformers
库加载BioBERT,结合LoRA微调,适配医学文本分类。
5.1.1 代码实现
from transformers import BertForSequenceClassification, BertTokenizerfrom peft import LoraConfig, get_peft_modelimport torchimport torch.nn as nnfrom torch.utils.data import DataLoaderfrom sklearn.metrics import accuracy_score# 加载预训练BioBERTtokenizer = BertTokenizer.from_pretrained(\'dmis-lab/biobert-v1.1\')model = BertForSequenceClassification.from_pretrained(\'dmis-lab/biobert-v1.1\', num_labels=2)# LoRA微调lora_config = LoraConfig( r=8, # 低秩矩阵维度 lora_alpha=16, # 缩放因子 target_modules=[\"query\", \"value\"], # 微调自注意力模块 lora_dropout=0.1)model = get_peft_model(model, lora_config)# 数据加载器dataloader = DataLoader(dataset, batch_size=16, shuffle=True)# 训练设置device = torch.device(\'cuda\' if torch.cuda.is_available() else \'cpu\')model = model.to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)# 训练循环num_epochs = 5for epoch in range(num_epochs): model.train() running_loss = 0.0 for batch in dataloader: input_ids = batch[\'input_ids\'].to(device) attention_mask = batch[\'attention_mask\'].to(device) labels = batch[\'label\'].to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() print(f\'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}\')# 推理model.eval()predictions, true_labels = [], []with torch.no_grad(): for batch in dataloader: input_ids = batch[\'input_ids\'].to(device) attention_mask = batch[\'attention_mask\'].to(device) labels = batch[\'label\'].to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits preds = torch.argmax(outputs, dim=1) predictions.extend(preds.cpu().numpy()) true_labels.extend(labels.cpu().numpy())print(\"准确率:\", accuracy_score(true_labels, predictions))
代码注释:
BertTokenizer
:BioBERT分词器,适配医学术语。BertForSequenceClassification
:加载BioBERT,分类头设为2类。LoraConfig
:低秩适配,减少微调参数。CrossEntropyLoss
:适合分类任务。Adam
:学习率2e-5,适配预训练模型。
5.2 决策树与随机森林实现
结合BERT提取的特征,使用随机森林分类。
5.2.1 代码实现
from sklearn.ensemble import RandomForestClassifierfrom sklearn.metrics import accuracy_score, classification_reportimport numpy as np# 提取BERT特征model.eval()features, labels = [], []with torch.no_grad(): for batch in dataloader: input_ids = batch[\'input_ids\'].to(device) attention_mask = batch[\'attention_mask\'].to(device) labels_batch = batch[\'label\'].to(device) outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :] # 提取CLS Token features.extend(outputs.cpu().numpy()) labels.extend(labels_batch.cpu().numpy())features = np.array(features)labels = np.array(labels)# 随机森林分类rf = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42)rf.fit(features, labels)# 推理rf_predictions = rf.predict(features)print(\"随机森林准确率:\", accuracy_score(labels, rf_predictions))print(\"分类报告:\\n\", classification_report(labels, rf_predictions, target_names=[\'良性\', \'恶性\']))# 特征重要性importances = rf.feature_importances_indices = np.argsort(importances)[::-1][:10]print(\"Top 10 特征重要性:\", importances[indices])
代码注释:
last_hidden_state[:, 0, :]
:提取CLS Token作为特征。RandomForestClassifier
:100棵树,最大深度10。feature_importances_
:分析特征重要性。
六、评估与优化
6.1 评估方法
- 交叉验证:5折分层K折。
- 混淆矩阵:计算TP、FP、FN、TN。
- ROC曲线与AUC:量化区分能力。
6.2 代码实现
from sklearn.metrics import confusion_matrix, roc_curve, aucimport matplotlib.pyplot as plt# 混淆矩阵cm = confusion_matrix(true_labels, predictions)print(\"混淆矩阵:\\n\", cm)print(\"分类报告:\\n\", classification_report(true_labels, predictions, target_names=[\'良性\', \'恶性\']))# ROC曲线model.eval()probs = []with torch.no_grad(): for batch in dataloader: input_ids = batch[\'input_ids\'].to(device) attention_mask = batch[\'attention_mask\'].to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())fpr, tpr, _ = roc_curve(true_labels, probs)roc_auc = auc(fpr, tpr)plt.figure()plt.plot(fpr, tpr, color=\'#FF6384\', lw=2, label=f\'ROC曲线 (AUC = {roc_auc:.2f})\')plt.plot([0, 1], [0, 1], color=\'navy\', lw=2, linestyle=\'--\')plt.xlim([0.0, 1.0])plt.ylim([0.0, 1.05])plt.xlabel(\'假阳性率 (FPR)\')plt.ylabel(\'真阳性率 (TPR)\')plt.title(\'BioBERT ROC曲线(医学文本分类)\')plt.legend(loc=\"lower right\")plt.show()
代码注释:
confusion_matrix
:计算分类性能。roc_curve
:绘制ROC曲线。auc
:量化模型区分能力。
6.3 优化策略
- 类不平衡:加权损失、SMOTE过采样。
- 正则化:Dropout、权重衰减。
- 超参数调优:网格搜索学习率(1e-5到5e-5)、批大小(8-32)。
- 领域适配:在MIMIC-CXR上继续预训练BioBERT。
6.4 图表:BioBERT与随机森林性能对比
{ \"type\": \"line\", \"data\": { \"labels\": [\"2折\", \"3折\", \"5折\", \"10折\"], \"datasets\": [ { \"label\": \"BioBERT 召回率\", \"data\": [0.88, 0.90, 0.91, 0.90], \"borderColor\": \"#36A2EB\", \"fill\": false }, { \"label\": \"随机森林 召回率\", \"data\": [0.83, 0.85, 0.86, 0.85], \"borderColor\": \"#FF6384\", \"fill\": false } ] }, \"options\": { \"title\": { \"display\": true, \"text\": \"BioBERT与随机森林召回率对比(医学文本分类)\" }, \"scales\": { \"x\": { \"title\": { \"display\": true, \"text\": \"交叉验证折数\" } }, \"y\": { \"title\": { \"display\": true, \"text\": \"召回率\" }, \"ticks\": { \"min\": 0.8, \"max\": 1.0 } } } }}
说明:
- X轴:交叉验证折数。
- Y轴:召回率。
- 数据:假设BioBERT优于随机森林。
七、可解释性分析
7.1 Attention可视化(BERT)
可视化BERT的自注意力权重,显示模型关注的文本区域。
from transformers import BertTokenizer, BertModelimport matplotlib.pyplot as plt# 加载模型和分词器tokenizer = BertTokenizer.from_pretrained(\'dmis-lab/biobert-v1.1\')model = BertModel.from_pretrained(\'dmis-lab/biobert-v1.1\').to(device)# 示例文本text = \"Lung nodule detected, likely malignant.\"inputs = tokenizer(text, return_tensors=\'pt\', padding=True, truncation=True).to(device)# 获取注意力权重model.eval()with torch.no_grad(): outputs = model(**inputs, output_attentions=True) attentions = outputs.attentions[-1][0, :, 0, :].cpu().numpy() # 最后一层CLS Token的注意力# 可视化tokens = tokenizer.convert_ids_to_tokens(inputs[\'input_ids\'][0])plt.figure(figsize=(10, 8))plt.imshow(attentions, cmap=\'viridis\')plt.xticks(range(len(tokens)), tokens, rotation=45)plt.yticks(range(12), [f\'Head {i+1}\' for i in range(12)])plt.colorbar()plt.title(\'BioBERT Attention Weights(CLS Token)\')plt.show()
说明:
attentions[-1]
:最后一层注意力权重。CLS Token
:显示模型对输入Token的关注度。
7.2 随机森林特征重要性
分析BERT特征在随机森林中的重要性。
八、总结与展望
8.1 总结
- 成果:
- 实现基于BioBERT的医学文本分类器。
- 结合随机森林,提升可解释性。
- 完成MIMIC-CXR文本预处理、训练和评估。
- 关键点:
- LoRA降低微调成本。
- 召回率优先优化。
- Attention可视化和特征重要性增强信任。
8.2 展望
- 多模态融合:结合影像(ViT)和文本(BERT)。
- 实体识别:提取报告中的关键实体(如“结节”“恶性”)。
- 自动化报告生成:从影像到诊断报告。
- 可解释性:结合SHAP值深入分析。