ID3RSNet:单通道EEG跨被试疲劳检测可解释网络

ID3RSNet:单通道EEG跨被试疲劳检测可解释网络(Frontiers in Neuroscience 2025)

论文信息:


核心创新

问题定义:

  1. 多通道EEG设备复杂、成本高、佩戴不便
  2. 单通道EEG信号非平稳、个体差异大
  3. 深度学习模型”黑箱”问题,缺乏可解释性

核心方法:

  1. 残差收缩网络(RSBU):自适应特征重标定 + 软阈值去噪
  2. 权重冻结全连接层:抑制神经元负面影响
  3. ECAM可解释方法:可视化EEG频段激活模式

性能: 跨被试LOSOCV准确率 > 90%,同时提供神经生理学可解释证据


1. 问题背景

1.1 疲劳检测方法对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class DrowsinessDetectionMethods:
"""
疲劳检测方法分类
"""

METHODS = {
"行为检测": {
"数据源": ["面部表情", "眼睛状态", "嘴部动作"],
"优势": ["非接触", "成本低"],
"劣势": ["受光照影响", "依赖头部姿态"]
},
"车辆行为": {
"数据源": ["方向盘角度", "车道偏离", "速度变化"],
"优势": ["已有传感器", "无需额外设备"],
"劣势": ["个体差异大", "延迟高"]
},
"生理信号": {
"数据源": ["EEG", "ECG", "EOG", "EMG"],
"优势": ["直接反映状态", "实时性强"],
"劣势": ["设备复杂", "佩戴不便"]
}
}

@staticmethod
def why_eeg():
"""
为什么EEG是疲劳检测金标准

疲劳与大脑活动直接相关
"""
return {
"原因": "疲劳本质上是大脑活动状态改变",
"优势": [
"直接测量中枢神经系统",
"毫秒级时间分辨率",
"可检测早期疲劳信号"
],
"挑战": [
"信号非平稳",
"个体差异大",
"多通道设备佩戴复杂"
]
}

1.2 单通道EEG优势

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class SingleChannelEEG:
"""
单通道EEG优势与挑战
"""

ADVANTAGES = {
"成本": "设备成本降低80%",
"便携": "可穿戴设计,用户友好",
"采集": "信号采集简单,减少设置时间",
"移动": "用户活动受限小"
}

CHALLENGES = {
"信息量": "相比多通道信息量减少",
"非平稳": "EEG信号高度非平稳",
"个体差异": "跨被试泛化困难",
"信噪比": "信号噪声比低"
}

SOLUTION = "ID3RSNet: 残差收缩网络 + 可解释性"

2. 方法架构

2.1 网络整体架构

graph TB
    A[原始单通道EEG] --> B[BaseFE基础特征提取]
    B --> C[RSBU残差收缩单元]
    C --> D[GAP全局平均池化]
    D --> E[FC-WF权重冻结全连接]
    E --> F[疲劳/清醒分类]
    
    C --> G[ECAM类激活图]
    G --> H[可解释性分析]

2.2 基础特征提取器(BaseFE)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn

class BaseFeatureExtractor(nn.Module):
"""
基础特征提取器

从原始EEG提取频域特征
"""

def __init__(self, in_channels: int = 1, out_channels: int = 16):
super().__init__()

# 时域卷积
self.conv1d = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=64, stride=8),
nn.BatchNorm1d(out_channels),
nn.ELU(),
nn.Dropout(0.3),

nn.Conv1d(out_channels, out_channels*2, kernel_size=16, stride=2),
nn.BatchNorm1d(out_channels*2),
nn.ELU(),
nn.Dropout(0.3),
)

def forward(self, x):
"""
前向传播

Args:
x: 原始EEG信号 (B, 1, T)

Returns:
features: 频域特征 (B, C, L)
"""
return self.conv1d(x)

2.3 残差收缩单元(RSBU)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class ResidualShrinkageBlock(nn.Module):
"""
残差收缩单元

核心创新:
1. 注意力机制自适应特征重标定
2. 软阈值去噪消除噪声
"""

def __init__(self, in_channels: int, reduction: int = 4):
super().__init__()

# 卷积层
self.conv1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)

# 注意力模块(SE-Net风格)
self.global_pool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction),
nn.ReLU(),
nn.Linear(in_channels // reduction, in_channels),
nn.Sigmoid()
)

# 软阈值(可学习)
self.threshold = nn.Parameter(torch.zeros(1, in_channels, 1))

def soft_threshold(self, x, threshold):
"""
软阈值函数

y = sign(x) * max(|x| - τ, 0)
"""
return torch.sign(x) * torch.relu(torch.abs(x) - threshold)

def forward(self, x):
"""
前向传播
"""
identity = x

# 卷积
out = torch.relu(self.conv1(x))
out = self.conv2(out)

# 注意力权重
b, c, _ = out.size()
attention = self.global_pool(out).view(b, c)
attention = self.fc(attention).view(b, c, 1)

# 特征重标定
out = out * attention

# 软阈值去噪
out = self.soft_threshold(out, torch.abs(self.threshold))

# 残差连接
out = out + identity

return out

2.4 权重冻结全连接层(FC-WF)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class WeightFrozenFC(nn.Module):
"""
权重冻结全连接层

抑制神经元负面影响
"""

def __init__(self, in_features: int, out_features: int, freeze_ratio: float = 0.3):
super().__init__()

self.fc = nn.Linear(in_features, out_features)
self.freeze_ratio = freeze_ratio

# 记录需要冻结的权重索引
self.frozen_indices = None

def analyze_weights(self):
"""
分析权重重要性,冻结负面影响权重
"""
with torch.no_grad():
weights = self.fc.weight.data
importance = torch.abs(weights).mean(dim=0)

# 冻结最不重要的权重
num_freeze = int(len(importance) * self.freeze_ratio)
_, indices = torch.topk(importance, num_freeze, largest=False)
self.frozen_indices = indices

def forward(self, x):
"""
前向传播(冻结指定权重)
"""
if self.frozen_indices is not None and self.training:
with torch.no_grad():
self.fc.weight.data[self.frozen_indices] = 0

return self.fc(x)

2.5 完整ID3RSNet

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class ID3RSNet(nn.Module):
"""
ID3RSNet完整网络

可解释残差收缩网络
"""

def __init__(self, num_classes: int = 2):
super().__init__()

# 基础特征提取
self.base_fe = BaseFeatureExtractor(in_channels=1, out_channels=32)

# 残差收缩模块
self.rs_block1 = ResidualShrinkageBlock(64, reduction=4)
self.rs_block2 = ResidualShrinkageBlock(64, reduction=4)

# 全局平均池化
self.gap = nn.AdaptiveAvgPool1d(1)

# 权重冻结FC
self.fc = WeightFrozenFC(64, num_classes, freeze_ratio=0.3)

# 存储特征图用于可解释性
self.feature_maps = None

def forward(self, x):
"""
前向传播

Args:
x: 原始EEG (B, 1, T)

Returns:
logits: 分类结果 (B, num_classes)
"""
# 特征提取
x = self.base_fe(x)

# 残差收缩
x = self.rs_block1(x)
x = self.rs_block2(x)

# 保存特征图
self.feature_maps = x

# 全局池化
x = self.gap(x).squeeze(-1)

# 分类
x = self.fc(x)

return x

def get_ecam(self, class_idx: int = 1):
"""
获取ECAM(EEG-based Class Activation Map)

可解释性分析
"""
if self.feature_maps is None:
raise ValueError("需要先运行forward")

# 获取FC权重
fc_weights = self.fc.fc.weight.data[class_idx]

# 计算激活图
feature_maps = self.feature_maps.detach()
cam = torch.zeros(feature_maps.size(0), feature_maps.size(2))

for i, w in enumerate(fc_weights):
cam += w * feature_maps[:, i, :]

return cam


# 使用示例
if __name__ == "__main__":
# 创建模型
model = ID3RSNet(num_classes=2)

# 模拟EEG数据(采样率200Hz,5秒)
batch_size = 4
seq_length = 1000 # 5秒 * 200Hz
eeg_signal = torch.randn(batch_size, 1, seq_length)

# 前向传播
output = model(eeg_signal)
print(f"输出形状: {output.shape}")

# 获取可解释性
cam = model.get_ecam(class_idx=1) # 疲劳类
print(f"CAM形状: {cam.shape}")

3. 实验结果

3.1 数据集

数据集 被试数 状态数 通道数
DROZY 14 3 多通道
ULg 22 2 多通道

3.2 性能对比

方法 LOSOCV准确率 参数量 可解释性
CNN-LSTM 82.3% 1.2M
EEGNet 85.7% 0.8M
DeepCNN 88.2% 1.5M
ID3RSNet 91.5% 0.6M

3.3 可解释性分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class EEGInterpretability:
"""
EEG可解释性分析

基于ECAM识别疲劳相关频段
"""

# EEG频段定义
FREQUENCY_BANDS = {
"Delta": (0.5, 4),
"Theta": (4, 8),
"Alpha": (8, 13),
"Beta": (13, 30),
"Gamma": (30, 100)
}

@staticmethod
def analyze_cam(cam: np.ndarray, fs: float = 200):
"""
分析CAM激活频段

疲劳状态通常表现为:
- Alpha波增强(8-13Hz)
- Theta波增强(4-8Hz)
- Beta波减弱(13-30Hz)
"""
# FFT分析
fft = np.fft.fft(cam)
freqs = np.fft.fftfreq(len(cam), 1/fs)

# 各频段能量
band_energy = {}
for band_name, (low, high) in EEGInterpretability.FREQUENCY_BANDS.items():
band_mask = (np.abs(freqs) >= low) & (np.abs(freqs) < high)
band_energy[band_name] = np.sum(np.abs(fft[band_mask]))

# 疲劳特征
fatigue_indicator = (
band_energy["Theta"] + band_energy["Alpha"]
) / band_energy["Beta"]

return {
"band_energy": band_energy,
"fatigue_indicator": fatigue_indicator,
"interpretation": EEGInterpretability._interpret(fatigue_indicator)
}

@staticmethod
def _interpret(indicator: float) -> str:
"""
解释疲劳指标
"""
if indicator > 2.0:
return "高度疲劳:Theta+Alpha显著增强,Beta减弱"
elif indicator > 1.5:
return "中度疲劳:Theta+Alpha增强,Beta正常"
elif indicator > 1.0:
return "轻度疲劳:Alpha略增强"
else:
return "清醒状态:Beta主导"

4. IMS 应用启示

4.1 单通道EEG设备选型

设备 通道数 采样率 无线 价格
NeuroSky MindWave 1 512Hz 蓝牙 $99
Muse S 4 256Hz 蓝牙 $399
Emotiv EPOC X 14 256Hz 蓝牙 $849
推荐:NeuroSky 1 512Hz 低成本

4.2 实时疲劳检测系统

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class RealTimeDrowsinessDetector:
"""
实时疲劳检测系统
"""

def __init__(self, model_path: str, fs: int = 200):
# 加载模型
self.model = ID3RSNet(num_classes=2)
self.model.load_state_dict(torch.load(model_path))
self.model.eval()

self.fs = fs
self.window_size = 5 * fs # 5秒窗口
self.buffer = []

def process_sample(self, sample: float) -> dict:
"""
处理单个采样点

Args:
sample: EEG采样值

Returns:
result: 检测结果
"""
# 添加到缓冲区
self.buffer.append(sample)

# 保持窗口大小
if len(self.buffer) > self.window_size:
self.buffer.pop(0)

# 窗口未满
if len(self.buffer) < self.window_size:
return {"status": "BUFFERING"}

# 推理
window = np.array(self.buffer)
window_tensor = torch.FloatTensor(window).unsqueeze(0).unsqueeze(0)

with torch.no_grad():
logits = self.model(window_tensor)
prob = torch.softmax(logits, dim=1)
drowsy_prob = prob[0, 1].item()

# 可解释性
cam = self.model.get_ecam(class_idx=1)
interpretation = EEGInterpretability.analyze_cam(cam.numpy(), self.fs)

return {
"status": "DROWSY" if drowsy_prob > 0.5 else "ALERT",
"drowsy_probability": drowsy_prob,
"interpretation": interpretation
}

4.3 与DMS融合方案

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class IMSDrowsinessFusion:
"""
IMS疲劳检测融合

EEG + 摄像头DMS
"""

def __init__(self):
self.eeg_detector = RealTimeDrowsinessDetector("id3rsnet.pth")
self.camera_detector = None # 摄像头DMS

def fuse_decision(self, eeg_result: dict, camera_result: dict) -> dict:
"""
融合决策

Args:
eeg_result: EEG检测结果
camera_result: 摄像头检测结果

Returns:
final_decision: 最终决策
"""
# EEG疲劳概率
eeg_prob = eeg_result.get("drowsy_probability", 0)

# 摄像头疲劳指标
camera_perclos = camera_result.get("perclos", 0)
camera_yawn = camera_result.get("yawn_count", 0)

# 加权融合
camera_prob = 0.5 * camera_perclos + 0.5 * min(camera_yawn / 3, 1.0)

final_prob = 0.6 * eeg_prob + 0.4 * camera_prob

# 决策
if final_prob > 0.7:
level = "严重疲劳"
action = "建议立即停车休息"
elif final_prob > 0.4:
level = "轻度疲劳"
action = "建议休息或喝咖啡"
else:
level = "清醒"
action = "正常驾驶"

return {
"fatigue_level": level,
"fatigue_probability": final_prob,
"recommendation": action,
"eeg_contribution": eeg_prob,
"camera_contribution": camera_prob
}

5. 总结

方面 内容
核心创新 残差收缩 + 可解释性
输入 单通道原始EEG
性能 LOSOCV准确率91.5%
可解释性 ECAM可视化疲劳频段
部署 支持低成本可穿戴设备

参考链接:

  1. 论文原文: https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2024.1508747/full
  2. DROZY数据集: https://zenodo.org/record/1154610
  3. NeuroSky MindWave: https://neurosky.com/products/mindwave/

ID3RSNet:单通道EEG跨被试疲劳检测可解释网络
https://dapalm.com/2026/06/10/2026-06-10-ID3RSNet-Single-Channel-EEG-Drowsiness/
作者
Mars
发布于
2026年6月10日
许可协议