心电图 ECG 分类:心律失常的深度学习自动诊断
心电图(ECG)与心律失常基础
心电图是记录心脏电活动的无创检查方法。每个心动周期包含P波(心房除极)、QRS波群(心室除极)和T波(心室复极)。心律失常指心脏节律或传导异常,常见类型包括:
- 正常窦性心律:心率60-100次/分钟,P波后跟随QRS,节律规整。
- 房性早搏:提前出现的异常P波,QRS形态通常正常。
- 室性早搏:宽大畸形的QRS波,前无相关P波。
- 心房颤动:P波消失,代之以细小f波,心室律绝对不齐。
- 心室颤动:波形杂乱,无法分辨QRS-T,是致命性心律失常。
ECG自动分类任务的目标是,根据输入的ECG信号片段,输出上述类别(或多类)的概率。
深度学习分类流程概览
基于深度学习的ECG分类遵循典型流水线:
- 数据获取与预处理:读取公开数据集(如MIT-BIH、PhysioNet/CinC),进行滤波、分割、归一化。
- 信号片段生成:以心拍为单位截取固定长度片段(例如R峰前后各取一定采样点)。
- 模型构建:设计并训练卷积神经网络(CNN)、循环神经网络(RNN)或混合模型。
- 训练与验证:划分训练集/验证集/测试集,采用类别加权处理不平衡数据。
- 评估与部署:使用准确率、召回率、混淆矩阵等指标评价模型,并可导出为轻量级推理服务。
数据预处理步骤
在深度学习项目中,预处理质量决定模型上限。对ECG信号主要执行以下操作:
读取与滤波
典型ECG采样率为360Hz或500Hz。使用wfdb或scipy读取信号后,应用带通滤波器滤除基线漂移和肌电噪声(如0.5-40Hz巴特沃斯带通)。
from scipy.signal import butter, filtfilt
def bandpass_filter(data, lowcut=0.5, highcut=40.0, fs=360, order=4):
nyq = 0.5 * fs
low = lowcut / nyq
high = highcut / nyq
b, a = butter(order, [low, high], btype='band')
return filtfilt(b, a, data)
R峰检测与心拍提取
利用neurokit2或biosppy的检测算法定位R峰,然后以R峰为中心截取窗口。例如,采样率360Hz时,取前90个采样点(0.25s)和后150个采样点(0.42s),得到固定长度240个采样点的心拍。
import neurokit2 as nk
signals, info = nk.ecg_process(signal, sampling_rate=360)
r_peaks = info['ECG_R_Peaks']
beats = []
window_before = 90 # samples
window_after = 150
for r in r_peaks[1:-1]: # 避免边界
if r - window_before >= 0 and r + window_after < len(signal):
beat = signal[r - window_before : r + window_after]
beats.append(beat)
归一化与标签编码
对每个心拍进行Z-score标准化,使均值0、标准差1。将类别标签映射为整数(如0代表正常,1代表室早),后续使用to_categorical转换为独热编码。
beats = np.array(beats)
beats = (beats - np.mean(beats, axis=1, keepdims=True)) / np.std(beats, axis=1, keepdims=True)
常见数据集介绍
- MIT-BIH心律失常数据库:包含48个半小时双通道记录,带有专家标注的心拍类型(正常、室早、室上早等)。
- PhysioNet/CinC 2017:单导联短时ECG,标签为正常、房颤、其他节律、噪声四类。
- PTB-XL:大型多导联ECG数据集,提供多种诊断标签,适合多任务学习。
在本教程中,我们以MIT-BIH为例,使用前导联MLII的数据。
模型设计与实现
我们构建一个1D CNN模型,因其计算高效且对波形特征敏感。模型包含三个卷积块,每个块后接批归一化和最大池化,最终通过全连接层输出类别概率。
使用Keras定义模型
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, BatchNormalization, Flatten, Dense, Dropout
def create_cnn_model(input_shape=(240, 1), num_classes=5):
model = Sequential([
Conv1D(32, kernel_size=7, activation='relu', padding='same', input_shape=input_shape),
BatchNormalization(),
MaxPooling1D(pool_size=2),
Conv1D(64, kernel_size=5, activation='relu', padding='same'),
BatchNormalization(),
MaxPooling1D(pool_size=2),
Conv1D(128, kernel_size=3, activation='relu', padding='same'),
BatchNormalization(),
MaxPooling1D(pool_size=2),
Flatten(),
Dense(128, activation='relu'),
Dropout(0.5),
Dense(num_classes, activation='softmax')
])
return model
model = create_cnn_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
输入形状 (240, 1) 对应240个采样点,单通道。可根据实际窗口长度调整。
处理类别不平衡
心律失常数据集常存在严重不平衡(例如正常心跳远多于室性早搏)。在训练时使用class_weight计算每个类别的权重,或采用焦点损失(Focal Loss)。
from sklearn.utils.class_weight import compute_class_weight
class_weights = compute_class_weight('balanced', classes=np.unique(y_train_int), y=y_train_int)
class_weight_dict = dict(enumerate(class_weights))
训练模型
将数据划分为训练集(70%)、验证集(15%)、测试集(15%)。采用早停法防止过拟合。
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
checkpoint = ModelCheckpoint('best_ecg_model.h5', monitor='val_accuracy', save_best_only=True)
history = model.fit(
X_train, y_train_cat,
validation_data=(X_val, y_val_cat),
epochs=50,
batch_size=64,
class_weight=class_weight_dict,
callbacks=[early_stop, checkpoint]
)
评估与可视化
分类报告与混淆矩阵
在测试集上计算精确率、召回率、F1值,并绘制混淆矩阵以观察各类间的混淆情况。
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
y_pred = model.predict(X_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_test_cat, axis=1)
print(classification_report(y_true_classes, y_pred_classes, target_names=class_names))
cm = confusion_matrix(y_true_classes, y_pred_classes)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
训练曲线
绘制损失和准确率随epoch的变化,检查过拟合情况。
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.legend()
plt.title('Loss Curve')
plt.subplot(1,2,2)
plt.plot(history.history['accuracy'], label='Train Acc')
plt.plot(history.history['val_accuracy'], label='Val Acc')
plt.legend()
plt.title('Accuracy Curve')
plt.tight_layout()
plt.show()
进阶技巧
使用更复杂的架构
- RNN/LSTM:利用ECG信号的时序依赖,可在CNN后接双向LSTM捕获长程模式。
- 残差网络(ResNet):对于深层模型,残差连接有助于梯度流动,可参考ECG专用ResNet变体。
- 注意力机制:在多导联或长时间序列中,注意力模块可突出关键区域。
数据增强
对ECG信号应用轻微的时间扭曲、缩放、噪声添加等,可提升模型泛化能力。
def add_noise(beat, noise_factor=0.05):
noise = np.random.randn(len(beat))
return beat + noise_factor * noise
# 示例:添加噪声后数据加入训练集
多导联处理
如果有12导联数据,可将每个导联视为一个通道。输入形状变为 (时间步长, 导联数),例如 (240, 12)。模型需调整第一层以接受多通道输入。
模型部署思路
训练好的模型可转换为TensorFlow Lite或ONNX格式,部署到移动设备或嵌入式系统。同时需注意推理延迟和模型大小,可通过量化压缩。
import tensorflow as tf
# 转换为TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('ecg_classifier.tflite', 'wb') as f:
f.write(tflite_model)
总结
本教程完整展示了从ECG信号预处理到深度学习模型构建、训练、评估的端到端流程。初学者可基于MIT-BIH数据集快速复现实验,并在此基础上尝试更复杂的网络结构和数据增强策略。掌握该方法后,可进一步探索多标签分类、回归(如心率估计)等心脏病智能诊断任务。