X 光片分类:胸部疾病的自动识别

FreeGuideOnline 最新 2026-06-20

胸部X光片(Chest X-ray)智能分类实战教程

从数据加载到模型部署,使用深度学习实现胸部疾病的自动识别


一、引言:为什么需要X光片自动分类?

胸部X光检查是诊断肺炎、气胸、肺结节等疾病最常用的影像学方法。然而,全球放射科医生的缺口巨大,诊断耗时且存在一定主观性。自动分类技术可辅助医生快速筛查、优先处理危重病例,缓解医疗资源压力。

本教程面向零基础学习者,将系统讲解如何使用卷积神经网络(CNN)完成胸部X光片的疾病分类。你将掌握:

  • 医学影像数据集的特点与预处理方法
  • 构建可解释的深度学习分类模型
  • 模型性能评估及实际部署注意事项

二、数据集解析:NIH Chest X-ray Dataset

我们选用公开的NIH Chest X-ray Dataset,它包含112,120张来自30,805名患者的正面胸部X光片,用14种疾病标签(如肺不张、心脏肥大、渗出、浸润、肿块等)进行弱标记。

关键特点:

  • 图像分辨率:1024×1024像素(下采样后常用224×224或256×256)
  • 标签为多标签分类(一张片子可能同时存在多种病变)
  • 标签来源于放射学报告的自然语言处理,存在一定噪声
  • 训练/测试划分:官方提供按患者划分的文件train_val_list.txttest_list.txt

数据获取:

# 下载数据集(约42GB),需在Kaggle平台申请权限
kaggle datasets download -d nih-chest-xrays/data
unzip data.zip -d nih_data

三、环境搭建与关键库

推荐使用PyTorch或TensorFlow/Keras,本教程以PyTorch为例。

pip install torch torchvision pandas numpy matplotlib scikit-learn opencv-python tqdm

导入基础模块:

import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import cv2
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

四、数据预处理与自定义Dataset

4.1 读取标签 NIH标签文件为Data_Entry_2017.csv,包含图像名和对应的“Finding Labels”字段(多个疾病以“|”分隔)。

df = pd.read_csv('nih_data/Data_Entry_2017.csv')
# 将Finding Labels转为14维的one-hot编码
all_labels = ['Atelectasis','Cardiomegaly','Effusion','Infiltration',
              'Mass','Nodule','Pneumonia','Pneumothorax',
              'Consolidation','Edema','Emphysema','Fibrosis',
              'Pleural_Thickening','Hernia']
for c in all_labels:
    df[c] = df['Finding Labels'].apply(lambda x: 1 if c in x else 0)

4.2 自定义Dataset类

class ChestXrayDataset(Dataset):
    def __init__(self, dataframe, image_dir, transforms=None, label_cols=all_labels):
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms
        self.label_cols = label_cols

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['Image Index']
        img_path = os.path.join(self.image_dir, img_name)
        # 读取灰度图并转为3通道(适配预训练模型)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        
        if self.transforms:
            img = self.transforms(img)
        
        labels = torch.tensor(self.df.iloc[idx][self.label_cols].values.astype(np.float32))
        return img, labels

4.3 数据增强与标准化 考虑到类别不平衡和图像差异,采用:

data_transforms = {
    'train': transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
}

五、模型构建:从经典CNN到现代EfficientNet

5.1 主干网络选择 迁移学习是医疗影像任务的标准做法。推荐使用在ImageNet上预训练的DenseNet121EfficientNet-B0,它们在参数量和性能间取得良好平衡。

model = models.densenet121(pretrained=True)
# 替换最后的全连接分类层以适配14个输出
num_ftrs = model.classifier.in_features
model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(num_ftrs, 14),
    nn.Sigmoid()  # 多标签分类用Sigmoid激活
)

5.2 损失函数与优化器 多标签分类使用二元交叉熵损失BCEWithLogitsLoss或直接使用BCELoss配合Sigmoid)。如果使用Sigmoid在模型中,则:

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

若在损失函数中集成Sigmoid,可改为:

criterion = nn.BCEWithLogitsLoss()
# 模型最后一层不要Sigmoid

六、训练循环与Focal Loss优化类别不平衡

NIH数据集中部分疾病(如“Hernia”)阳性样本极少。普通BCELoss会使模型偏向负类。引入Focal Loss可降低易分类样本的权重,聚焦难分样本。

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        bce_loss = nn.BCELoss(reduction='none')(inputs, targets)
        pt = torch.where(targets==1, inputs, 1-inputs)
        focal_weight = (1 - pt) ** self.gamma
        if self.alpha >= 0:
            alpha_t = torch.where(targets==1, self.alpha, 1-self.alpha)
            focal_weight = alpha_t * focal_weight
        loss = focal_weight * bce_loss
        return loss.mean()

训练循环(单epoch):

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(loader, desc='Training'):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    return running_loss / len(loader.dataset)

七、模型评估:AUC与ROC曲线

多标签分类的主要指标是每个类别的ROC AUC,以及宏观/加权平均AUC。

def evaluate(model, loader, device):
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            outputs = model(imgs).cpu().numpy()
            preds.extend(outputs)
            targets.extend(labels.numpy())
    preds = np.array(preds)
    targets = np.array(targets)
    aucs = []
    for i in range(targets.shape[1]):
        aucs.append(roc_auc_score(targets[:, i], preds[:, i]))
    return np.mean(aucs), aucs

绘制ROC曲线:

from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 8))
for i, disease in enumerate(all_labels):
    fpr, tpr, _ = roc_curve(targets[:, i], preds[:, i])
    plt.plot(fpr, tpr, label=f'{disease} (AUC={aucs[i]:.2f})')
plt.plot([0,1],[0,1],'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves per Disease')
plt.legend(loc='lower right')
plt.show()

八、模型解释性:Grad-CAM可视化

在医疗领域,模型决策过程必须透明。使用Grad-CAM查看网络关注区域。

def grad_cam(model, img_tensor, target_layer=model.features.denseblock4):
    model.eval()
    gradients = []
    activations = []

    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0])

    def forward_hook(module, input, output):
        activations.append(output)

    hook1 = target_layer.register_forward_hook(forward_hook)
    hook2 = target_layer.register_backward_hook(backward_hook)

    # 前向传播并保留梯度
    output = model(img_tensor.unsqueeze(0))
    # 取某个疾病的得分,比如Pneumonia的索引
    class_idx = all_labels.index('Pneumonia')
    model.zero_grad()
    output[0, class_idx].backward(retain_graph=True)

    # 计算权重
    pooled_grad = torch.mean(gradients[0], dim=[0,2,3], keepdim=True)
    cam = torch.sum(pooled_grad * activations[0], dim=1).squeeze().cpu().detach().numpy()
    cam = np.maximum(cam, 0)  # ReLU
    cam = cv2.resize(cam, (224,224))
    cam = (cam - cam.min()) / (cam.max() - cam.min())
    return cam

# 使用:
# img, _ = dataset[0]
# cam = grad_cam(model, img)
# 叠加显示

九、部署为轻量级推理服务

训练完成后,将模型转为TorchScript,使用Flask提供REST API。

# 导出模型
example = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model.cpu(), example)
traced_model.save("cxr_classifier.pt")

# Flask服务
from flask import Flask, request, jsonify
import io
from PIL import Image

app = Flask(__name__)
model = torch.jit.load('cxr_classifier.pt')
model.eval()

def preprocess(image_bytes):
    img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    img = data_transforms['val'](img).unsqueeze(0)
    return img

@app.route('/predict', methods=['POST'])
def predict():
    img = preprocess(request.files['image'].read())
    with torch.no_grad():
        output = model(img)
    prob = output[0].tolist()
    return jsonify({disease: p for disease, p in zip(all_labels, prob)})

if __name__ == '__main__':
    app.run(debug=False, host='0.0.0.0', port=5000)

十、常见问题与调优建议

  1. 数据泄露:务必按患者ID划分数据集,避免同一患者的图像出现在训练集和测试集中。
  2. 标签不平衡:除了Focal Loss,还可采用重采样、调整类别权重。
  3. 影像增强:直方图均衡化、CLAHE可提高对比度,肺野分割可去除非解剖区域干扰。
  4. 多标签相关性:疾病间存在共现关系,可尝试图神经网络或标签嵌入建模依赖。
  5. 外部验证:在CheXpert、MIMIC-CXR等数据集上测试泛化能力。

十一、总结与进阶方向

本教程带你走完了一个完整的胸部X光片多标签分类流程。从数据预处理、DenseNet模型微调、Focal Loss应对不平衡,到解释性分析和API部署,你已具备实际落地的能力。

进阶方向:

  • 采用Vision Transformer(ViT)或ConvNeXt等更新架构
  • 加入病灶检测(如Faster R-CNN)进行区域级分析
  • 使用半监督学习利用大量无标签数据
  • 结合临床报告进行多模态学习

愿技术为医疗公平贡献一份力量。