FatigueNet Paper Walkthrough: A Real-Time Deployment Scheme for GNN+Transformer Multimodal Fatigue Detection


Source: Nature Scientific Reports, September 2025
Title: FatigueNet: A hybrid graph neural network and transformer framework for real-time multimodal fatigue detection
Key novelty: the first fusion of a GNN and a Transformer for multimodal biosignal fatigue detection


Key Insights

FatigueNet's main breakthroughs:

Conventional approach        FatigueNet approach
Single-modality CNN/LSTM     GNN+Transformer multimodal fusion
Fixed-weight fusion          Meta-Gated adaptive dynamic weights
Manual feature engineering   Automatic multi-domain feature extraction
~90% accuracy                >95% accuracy, a gain of 5+ points

Implications for IMS development:

  • Multimodal fusion (ECG + EDA + EMG + blink) is more robust than a vision-only approach
  • The Meta-Gated module enables real-time adaptive weight adjustment
  • Suitable for automotive-grade, low-latency deployment

1. Background

1.1 The Neuroscience of Fatigue

Fatigue is a complex physiological phenomenon involving:

Type               Cause                        Neural mechanism
Physical fatigue   Strenuous physical labor     Reduced muscle efficiency, cardiovascular abnormalities
Mental fatigue     Sustained cognitive work     Reduced prefrontal-cortex activity, dopamine/serotonin imbalance
Emotional fatigue  Chronic stress/depression    Limbic-system dysregulation

Impact on driving:

  • Longer reaction times
  • Reduced attentional control
  • Impaired decision-making
  • Sharply increased accident risk

1.2 Limitations of Conventional Detection Methods

Method                        Pros            Cons
Subjective questionnaires     Simple to run   Subjective, lagging
Steering-wheel sensors        Low cost        Indirect inference, poor accuracy
Single physiological signal   Some accuracy   Interference-prone, poor adaptability
Classical ML (SVM/DT)         Interpretable   Relies on hand-crafted features, weak generalization

2. FatigueNet Architecture in Detail

2.1 Overall Architecture

┌─────────────────────────────────────────────────────────────────┐
│                     FatigueNet Architecture                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌───────┐   ┌───────┐   ┌───────┐   ┌───────┐                  │
│  │  ECG  │   │  EDA  │   │  EMG  │   │ Blink │  Input signals   │
│  └───┬───┘   └───┬───┘   └───┬───┘   └───┬───┘                  │
│      └───────────┴─────┬─────┴───────────┘                      │
│                  ┌─────▼──────┐                                 │
│                  │ Preprocess │ ← filtering, normalization,     │
│                  └─────┬──────┘   resampling                    │
│        ┌───────────────┼───────────────┐                        │
│   ┌────▼─────┐    ┌────▼─────┐    ┌────▼─────┐                  │
│   │   Time   │    │   Freq   │    │ Time-freq│  Multi-domain    │
│   │  domain  │    │  domain  │    │  domain  │  features        │
│   └────┬─────┘    └────┬─────┘    └────┬─────┘                  │
│        └───────────────┼───────────────┘                        │
│                  ┌─────▼──────┐                                 │
│                  │    GNN     │ ← graph modeling of             │
│                  └─────┬──────┘   inter-signal relations        │
│                  ┌─────▼──────┐                                 │
│                  │Transformer │ ← attention over long-range     │
│                  └─────┬──────┘   temporal dependencies         │
│                  ┌─────▼──────┐                                 │
│                  │   MGAF     │ ← Meta-Gated adaptive fusion    │
│                  └─────┬──────┘                                 │
│                  ┌─────▼──────┐                                 │
│                  │   MSVM     │ ← multi-class SVM               │
│                  └─────┬──────┘                                 │
│        ┌─────────┬─────┴─────┬──────────┐                       │
│        ▼         ▼           ▼          ▼                       │
│     [Low]   [Moderate]    [High]    [Extreme]   Output          │
└─────────────────────────────────────────────────────────────────┘

2.2 Multimodal Input Signals

Dataset: the MePhy benchmark

  • Source: University of Modena, Italy
  • Participants: 60 (30 male / 30 female, mean age 22.85 years)
  • Signals: 4 biosignal types

Signal     Sampling rate  Device          Measures
ECG        1 Hz           Polar H10       Heart rate, HRV, R-wave amplitude
EDA        1000 Hz        BITalino        Skin conductance, GSR
EMG        1000 Hz        BITalino        Muscle activity, MAV
Eye blink  30 Hz          Logitech C920   Blink rate, blink duration

Experimental conditions:

  1. Rest (no fatigue)
  2. Mental fatigue (Stroop test, mental arithmetic, 2-back memory task)
  3. Physical fatigue (isometric exercise, elbow flexion)
  4. Mixed fatigue (simultaneous mental + physical tasks)
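The four streams arrive at very different native rates (ECG 1 Hz, EDA/EMG 1000 Hz, blink 30 Hz), so they must be brought to a common rate before they can be windowed together. A minimal sketch of that alignment step with `scipy.signal.resample` (the array lengths here are illustrative, not from the paper):

```python
import numpy as np
from scipy import signal

TARGET_RATE = 30  # Hz, a common rate for all modalities

def to_common_rate(x: np.ndarray, original_rate: float) -> np.ndarray:
    """Fourier-based resampling of a 1-D signal to TARGET_RATE."""
    n_out = int(len(x) * TARGET_RATE / original_rate)
    return signal.resample(x, n_out)

# 10 seconds of each stream at its native rate
eda = np.random.randn(10 * 1000)   # EDA   @ 1000 Hz
blink = np.random.randn(10 * 30)   # Blink @ 30 Hz

eda_30 = to_common_rate(eda, 1000)
blink_30 = to_common_rate(blink, 30)

# Both streams now carry one sample per 1/30 s and align sample-for-sample
print(len(eda_30), len(blink_30))  # 300 300
```

After this step every modality can share the same sliding-window grid, which is what the feature extractor in Section 2.4 relies on.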

2.3 Preprocessing Pipeline

"""
FatigueNet 信号预处理流水线
论文 Section 2.2
"""

import numpy as np
from scipy import signal
from scipy.signal import wavelet
import pywt

class FatigueNetPreprocessor:
"""
多模态生物信号预处理器

支持信号:
- ECG: 心电图
- EDA: 皮肤电活动
- EMG: 肌电图
- Blink: 眨眼信号
"""

def __init__(self, target_rate: int = 30):
"""
初始化预处理器

Args:
target_rate: 目标采样率(统一降采样)
"""
self.target_rate = target_rate

def preprocess_ecg(self, ecg_signal: np.ndarray,
original_rate: int = 1) -> np.ndarray:
"""
ECG预处理

Args:
ecg_signal: 原始ECG信号
original_rate: 原始采样率

Returns:
处理后的ECG信号
"""
# Polar H10 采样率较低(1Hz),无需高通滤波
# 基线漂移已由设备内置算法处理

# 降采样到目标速率(如果需要)
if original_rate != self.target_rate:
ecg_signal = signal.resample(ecg_signal,
int(len(ecg_signal) * self.target_rate / original_rate))

# 归一化
ecg_signal = (ecg_signal - np.mean(ecg_signal)) / (np.std(ecg_signal) + 1e-8)

return ecg_signal

def preprocess_eda(self, eda_signal: np.ndarray,
original_rate: int = 1000) -> np.ndarray:
"""
EDA预处理

Args:
eda_signal: 原始EDA信号
original_rate: 原始采样率

Returns:
处理后的EDA信号
"""
# 小波分解去除低频噪声
coeffs = pywt.wavedec(eda_signal, 'db4', level=6)

# 去除低频近似系数(基线漂移)
coeffs[0] = np.zeros_like(coeffs[0])

# 重构信号
eda_signal = pywt.waverec(coeffs, 'db4')

# 低通滤波 (30 Hz cutoff)
sos = signal.butter(4, 30, 'low', fs=original_rate, output='sos')
eda_signal = signal.sosfilt(sos, eda_signal)

# 陷波滤波 (50 Hz 电源干扰)
b, a = signal.iirnotch(50, 30, original_rate)
eda_signal = signal.filtfilt(b, a, eda_signal)

# 降采样
eda_signal = signal.resample(eda_signal,
int(len(eda_signal) * self.target_rate / original_rate))

# 归一化
eda_signal = (eda_signal - np.mean(eda_signal)) / (np.std(eda_signal) + 1e-8)

return eda_signal

def preprocess_emg(self, emg_signal: np.ndarray,
original_rate: int = 1000) -> np.ndarray:
"""
EMG预处理

Args:
emg_signal: 原始EMG信号
original_rate: 原始采样率

Returns:
处理后的EMG信号
"""
# 带通滤波 (20-500 Hz)
sos = signal.butter(4, [20, 500], 'band', fs=original_rate, output='sos')
emg_signal = signal.sosfilt(sos, emg_signal)

# 陷波滤波 (50 Hz)
b, a = signal.iirnotch(50, 30, original_rate)
emg_signal = signal.filtfilt(b, a, emg_signal)

# 中值滤波去除运动伪影
emg_signal = signal.medfilt(emg_signal, kernel_size=5)

# Z-score 异常值检测与替换
z_scores = np.abs((emg_signal - np.mean(emg_signal)) / (np.std(emg_signal) + 1e-8))
outliers = z_scores > 3
emg_signal[outliers] = np.interp(
np.where(outliers)[0],
np.where(~outliers)[0],
emg_signal[~outliers]
)

# 降采样
emg_signal = signal.resample(emg_signal,
int(len(emg_signal) * self.target_rate / original_rate))

# 归一化
emg_signal = (emg_signal - np.mean(emg_signal)) / (np.std(emg_signal) + 1e-8)

return emg_signal

def preprocess_blink(self, frames: np.ndarray,
original_rate: int = 30) -> np.ndarray:
"""
眨眼信号预处理(视频帧)

Args:
frames: 视频帧序列
original_rate: 原始帧率

Returns:
眨眼频率序列
"""
# 光流分析追踪眼睑运动
blink_signal = self._optical_flow_blink_detection(frames)

# 带通滤波 (0.5-30 Hz)
sos = signal.butter(4, [0.5, 30], 'band', fs=original_rate, output='sos')
blink_signal = signal.sosfilt(sos, blink_signal)

# 小波分解去除残余伪影
coeffs = pywt.wavedec(blink_signal, 'db4', level=3)
blink_signal = pywt.waverec(coeffs, 'db4')

# 峰值检测(眨眼事件)
peaks, _ = signal.find_peaks(blink_signal, height=0.5, distance=10)

# 有效眨眼筛选 (100-400ms)
valid_blinks = []
for i in range(len(peaks) - 1):
duration = (peaks[i+1] - peaks[i]) / original_rate * 1000 # ms
if 100 <= duration <= 400:
valid_blinks.append(peaks[i])

# 生成眨眼频率序列
blink_freq = np.zeros(len(blink_signal))
for peak in valid_blinks:
blink_freq[peak] = 1

return blink_freq

def _optical_flow_blink_detection(self, frames: np.ndarray) -> np.ndarray:
"""光流法眨眼检测"""
# 简化实现:实际使用OpenCV光流
# 此处返回模拟信号用于演示
return np.random.randn(len(frames)) * 0.5 + 1.0


# 实际测试
if __name__ == "__main__":
preprocessor = FatigueNetPreprocessor(target_rate=30)

# 模拟数据
ecg_raw = np.random.randn(300) * 0.1 + 1.0 # 300秒数据
eda_raw = np.random.randn(300000) * 0.05 + 0.5 # 1000Hz采样
emg_raw = np.random.randn(300000) * 0.2 # 1000Hz采样

# 预处理
ecg_clean = preprocessor.preprocess_ecg(ecg_raw, original_rate=1)
eda_clean = preprocessor.preprocess_eda(eda_raw, original_rate=1000)
emg_clean = preprocessor.preprocess_emg(emg_raw, original_rate=1000)

print(f"ECG处理后: {ecg_clean.shape}, 均值={ecg_clean.mean():.4f}")
print(f"EDA处理后: {eda_clean.shape}, 均值={eda_clean.mean():.4f}")
print(f"EMG处理后: {emg_clean.shape}, 均值={emg_clean.mean():.4f}")

2.4 Multi-Domain Feature Extraction

"""
FatigueNet 多域特征提取
论文 Section 2.4
"""

import numpy as np
from scipy import signal, stats
from scipy.fft import fft, fftfreq
from typing import Dict, List, Tuple

class MultiDomainFeatureExtractor:
"""
多域特征提取器

支持特征类型:
- 时域特征
- 频域特征
- 时频特征
- 混沌特征
- 分形特征
"""

def __init__(self, window_size: int = 20, step_size: int = 5):
"""
初始化特征提取器

Args:
window_size: 滑动窗口大小(秒)
step_size: 滑动步长(秒)
"""
self.window_size = window_size
self.step_size = step_size

def extract_ecg_features(self, ecg_signal: np.ndarray,
fps: int = 30) -> Dict[str, np.ndarray]:
"""
ECG特征提取

Args:
ecg_signal: ECG信号序列
fps: 采样率

Returns:
特征字典
"""
window_samples = self.window_size * fps
step_samples = self.step_size * fps

features = {
'HR': [], # 心率
'HRV': [], # 心率变异性
'RA': [], # R波振幅
'SDNN': [], # NN间隔标准差
'RMSSD': [], # 连续差异均方根
'pNN50': [], # NN50百分比
'pNN20': [], # NN20百分比
}

for i in range(0, len(ecg_signal) - window_samples, step_samples):
window = ecg_signal[i:i + window_samples]

# 检测R峰(简化实现)
r_peaks, _ = signal.find_peaks(window, height=0.5, distance=10)

if len(r_peaks) < 2:
continue

# 计算RR间隔
rr_intervals = np.diff(r_peaks) / fps * 1000 # ms

# 心率 (BPM)
hr = 60000 / np.mean(rr_intervals)
features['HR'].append(hr)

# HRV
features['HRV'].append(np.std(rr_intervals))

# R波振幅
features['RA'].append(np.mean(window[r_peaks]))

# SDNN
features['SDNN'].append(np.std(rr_intervals))

# RMSSD
diff = np.diff(rr_intervals)
features['RMSSD'].append(np.sqrt(np.mean(diff ** 2)))

# pNN50
nn50 = np.sum(np.abs(diff) > 50)
features['pNN50'].append(nn50 / len(diff) * 100)

# pNN20
nn20 = np.sum(np.abs(diff) > 20)
features['pNN20'].append(nn20 / len(diff) * 100)

return {k: np.array(v) for k, v in features.items()}

def extract_eda_features(self, eda_signal: np.ndarray,
fps: int = 30) -> Dict[str, np.ndarray]:
"""
EDA特征提取

Args:
eda_signal: EDA信号序列
fps: 采样率

Returns:
特征字典
"""
window_samples = self.window_size * fps
step_samples = self.step_size * fps

features = {
'SCL': [], # 皮肤电导水平
'GSR_count': [], # GSR事件计数
'GSR_amp': [], # GSR振幅
'GSR_duration': [], # GSR持续时间
}

for i in range(0, len(eda_signal) - window_samples, step_samples):
window = eda_signal[i:i + window_samples]

# SCL(低频成分)
sos = signal.butter(4, 0.05, 'low', fs=fps, output='sos')
scl = signal.sosfilt(sos, window)
features['SCL'].append(np.mean(scl))

# GSR(高频成分)
sos = signal.butter(4, 0.05, 'high', fs=fps, output='sos')
gsr = signal.sosfilt(sos, window)

# GSR事件检测
peaks, _ = signal.find_peaks(gsr, height=0.01, distance=50)
features['GSR_count'].append(len(peaks))

if len(peaks) > 0:
features['GSR_amp'].append(np.mean(gsr[peaks]))
durations = np.diff(peaks) / fps
features['GSR_duration'].append(np.mean(durations) if len(durations) > 0 else 0)
else:
features['GSR_amp'].append(0)
features['GSR_duration'].append(0)

return {k: np.array(v) for k, v in features.items()}

def extract_emg_features(self, emg_signal: np.ndarray,
fps: int = 30) -> Dict[str, np.ndarray]:
"""
EMG特征提取

Args:
emg_signal: EMG信号序列
fps: 采样率

Returns:
特征字典
"""
window_samples = self.window_size * fps
step_samples = self.step_size * fps

features = {
'MAV': [], # 平均绝对值
'RMS': [], # 均方根
'ZC': [], # 过零率
'SSC': [], # 斜率符号变化
'WL': [], # 波长
}

for i in range(0, len(emg_signal) - window_samples, step_samples):
window = emg_signal[i:i + window_samples]

# MAV
features['MAV'].append(np.mean(np.abs(window)))

# RMS
features['RMS'].append(np.sqrt(np.mean(window ** 2)))

# ZC
zc = np.sum(np.diff(np.sign(window)) != 0)
features['ZC'].append(zc)

# SSC
ssc = np.sum(np.diff(np.sign(np.diff(window))) != 0)
features['SSC'].append(ssc)

# WL
features['WL'].append(np.sum(np.abs(np.diff(window))))

return {k: np.array(v) for k, v in features.items()}

def extract_blink_features(self, blink_signal: np.ndarray,
fps: int = 30) -> Dict[str, np.ndarray]:
"""
眨眼特征提取

Args:
blink_signal: 眨眼信号序列
fps: 采样率

Returns:
特征字典
"""
window_samples = self.window_size * fps
step_samples = self.step_size * fps

features = {
'blink_rate': [], # 眨眼频率
'blink_duration': [], # 眨眼持续时间
'PERCLOS': [], # 闭眼百分比
}

for i in range(0, len(blink_signal) - window_samples, step_samples):
window = blink_signal[i:i + window_samples]

# 眨眼频率
blink_events = window[window > 0]
blink_rate = len(blink_events) / self.window_size * 60 # 次/分钟
features['blink_rate'].append(blink_rate)

# PERCLOS (简化计算)
perclos = np.sum(window > 0.5) / len(window) * 100
features['PERCLOS'].append(perclos)

return {k: np.array(v) for k, v in features.items()}

def extract_all_features(self, signals: Dict[str, np.ndarray],
fps: int = 30) -> np.ndarray:
"""
提取所有模态特征并合并

Args:
signals: 各信号字典 {'ecg': ..., 'eda': ..., 'emg': ..., 'blink': ...}
fps: 采样率

Returns:
合并特征矩阵 (n_samples, n_features)
"""
all_features = []

if 'ecg' in signals:
ecg_feats = self.extract_ecg_features(signals['ecg'], fps)
all_features.append(np.column_stack(list(ecg_feats.values())))

if 'eda' in signals:
eda_feats = self.extract_eda_features(signals['eda'], fps)
all_features.append(np.column_stack(list(eda_feats.values())))

if 'emg' in signals:
emg_feats = self.extract_emg_features(signals['emg'], fps)
all_features.append(np.column_stack(list(emg_feats.values())))

if 'blink' in signals:
blink_feats = self.extract_blink_features(signals['blink'], fps)
all_features.append(np.column_stack(list(blink_feats.values())))

# 合并所有特征
# 注意:需要对齐时间窗口
min_len = min(f.shape[0] for f in all_features)
aligned_features = [f[:min_len] for f in all_features]

return np.hstack(aligned_features)


# 实际测试
if __name__ == "__main__":
extractor = MultiDomainFeatureExtractor(window_size=20, step_size=5)

# 模拟多模态信号
signals = {
'ecg': np.random.randn(600) * 0.1 + 1.0, # 20秒数据
'eda': np.random.randn(600) * 0.05,
'emg': np.random.randn(600) * 0.2,
'blink': np.random.randn(600) * 0.3,
}

features = extractor.extract_all_features(signals, fps=30)
print(f"特征矩阵形状: {features.shape}")
print(f"特征数量: {features.shape[1]}")

2.5 The Meta-Gated Adaptive Fusion Module

"""
FatigueNet Meta-Gated Adaptive Fusion (MGAF) 模块
论文核心创新:动态模态权重自适应
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List

class MetaGatedAdaptiveFusion(nn.Module):
"""
Meta-Gated 自适应融合模块

功能:
1. 动态计算各模态权重
2. 基于信号质量自适应调整
3. 低延迟实时推理

论文 Section 2.5
"""

def __init__(self,
feature_dim: int,
num_modalities: int = 4,
hidden_dim: int = 64):
"""
初始化MGAF模块

Args:
feature_dim: 特征维度
num_modalities: 模态数量(ECG, EDA, EMG, Blink = 4)
hidden_dim: 隐藏层维度
"""
super().__init__()

self.num_modalities = num_modalities
self.feature_dim = feature_dim

# Meta网络:学习模态权重
self.meta_net = nn.Sequential(
nn.Linear(feature_dim * num_modalities, hidden_dim),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(hidden_dim, num_modalities),
nn.Softmax(dim=-1)
)

# 门控网络:信号质量评估
self.gate_net = nn.ModuleList([
nn.Sequential(
nn.Linear(feature_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1),
nn.Sigmoid()
) for _ in range(num_modalities)
])

# 特征变换层
self.feature_transform = nn.ModuleList([
nn.Sequential(
nn.Linear(feature_dim, feature_dim),
nn.LayerNorm(feature_dim),
nn.ReLU()
) for _ in range(num_modalities)
])

def forward(self,
modality_features: List[torch.Tensor],
signal_quality: List[float] = None) -> torch.Tensor:
"""
前向传播

Args:
modality_features: 各模态特征列表 [Tensor(B, D), ...]
signal_quality: 信号质量评分 [0-1],可选

Returns:
fused_features: 融合特征 (B, D)
"""
batch_size = modality_features[0].size(0)

# 1. 特征变换
transformed = []
for i, feat in enumerate(modality_features):
trans = self.feature_transform[i](feat)
transformed.append(trans)

# 2. Meta网络计算动态权重
concat_feat = torch.cat(transformed, dim=-1)
meta_weights = self.meta_net(concat_feat) # (B, num_modalities)

# 3. 门控网络计算信号质量权重
gate_weights = []
for i, feat in enumerate(transformed):
gate = self.gate_net[i](feat) # (B, 1)
gate_weights.append(gate)

gate_weights = torch.cat(gate_weights, dim=-1) # (B, num_modalities)

# 4. 外部信号质量(如果提供)
if signal_quality is not None:
quality_weights = torch.tensor(signal_quality,
device=transformed[0].device)
quality_weights = quality_weights.unsqueeze(0).expand(batch_size, -1)
gate_weights = gate_weights * quality_weights

# 5. 综合权重 = Meta权重 * 门控权重
final_weights = meta_weights * gate_weights # (B, num_modalities)
final_weights = final_weights / (final_weights.sum(dim=-1, keepdim=True) + 1e-8)

# 6. 加权融合
fused = torch.zeros_like(transformed[0])
for i, feat in enumerate(transformed):
fused = fused + feat * final_weights[:, i:i+1]

return fused, final_weights


class FatigueNet(nn.Module):
"""
FatigueNet完整模型

架构:
特征提取 -> GNN -> Transformer -> MGAF -> 分类器
"""

def __init__(self,
feature_dim: int = 32,
gnn_hidden: int = 64,
transformer_heads: int = 4,
transformer_layers: int = 2,
num_classes: int = 4):
"""
初始化FatigueNet

Args:
feature_dim: 特征维度
gnn_hidden: GNN隐藏层维度
transformer_heads: Transformer注意力头数
transformer_layers: Transformer层数
num_classes: 分类数(低/中/高/超疲劳)
"""
super().__init__()

self.feature_dim = feature_dim
self.num_modalities = 4 # ECG, EDA, EMG, Blink

# 1. 特征投影
self.feature_projection = nn.ModuleList([
nn.Linear(feature_dim, feature_dim)
for _ in range(self.num_modalities)
])

# 2. GNN层(建模模态间关系)
self.gnn = GraphNeuralNetwork(
input_dim=feature_dim,
hidden_dim=gnn_hidden,
output_dim=feature_dim
)

# 3. Transformer编码器(捕获时序依赖)
encoder_layer = nn.TransformerEncoderLayer(
d_model=feature_dim,
nhead=transformer_heads,
dim_feedforward=gnn_hidden * 2,
dropout=0.1,
batch_first=True
)
self.transformer = nn.TransformerEncoder(
encoder_layer,
num_layers=transformer_layers
)

# 4. MGAF模块
self.mgaf = MetaGatedAdaptiveFusion(
feature_dim=feature_dim,
num_modalities=self.num_modalities
)

# 5. 分类头
self.classifier = nn.Sequential(
nn.Linear(feature_dim, gnn_hidden),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(gnn_hidden, num_classes)
)

def forward(self,
modality_features: List[torch.Tensor],
return_weights: bool = False) -> torch.Tensor:
"""
前向传播

Args:
modality_features: 各模态特征列表
[Tensor(B, T, D), ...]
return_weights: 是否返回融合权重

Returns:
logits: 分类输出 (B, num_classes)
weights: 融合权重(可选)
"""
# 1. 特征投影
projected = []
for i, feat in enumerate(modality_features):
proj = self.feature_projection[i](feat)
projected.append(proj)

# 2. GNN处理
gnn_out = self.gnn(projected)

# 3. Transformer处理
# 将GNN输出序列化
batch_size, seq_len = gnn_out.size(0), gnn_out.size(1)

transformer_out = self.transformer(gnn_out) # (B, T, D)

# 4. MGAF融合(取最后时刻)
last_features = [out[:, -1, :] for out in projected]
fused, weights = self.mgaf(last_features)

# 5. 分类
logits = self.classifier(fused)

if return_weights:
return logits, weights
return logits


class GraphNeuralNetwork(nn.Module):
"""简化的图神经网络层"""

def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()

self.message = nn.Linear(input_dim, hidden_dim)
self.update = nn.Linear(hidden_dim + input_dim, output_dim)

def forward(self, node_features: List[torch.Tensor]) -> torch.Tensor:
"""
前向传播

Args:
node_features: 节点特征列表

Returns:
更新后的特征
"""
# 消息传递
messages = []
for feat in node_features:
msg = self.message(feat)
messages.append(msg)

# 聚合
aggregated = torch.stack(messages, dim=0).mean(dim=0)

# 更新
updated = []
for feat in node_features:
update_input = torch.cat([feat, aggregated], dim=-1)
update = self.update(update_input)
updated.append(update)

# 合并返回
return torch.stack(updated, dim=1).mean(dim=1)


# 实际测试
if __name__ == "__main__":
# 创建模型
model = FatigueNet(
feature_dim=32,
gnn_hidden=64,
transformer_heads=4,
transformer_layers=2,
num_classes=4
)

# 模拟输入
batch_size = 8
seq_len = 10

modality_features = [
torch.randn(batch_size, seq_len, 32) # ECG
for _ in range(4)
]

# 前向传播
logits, weights = model(modality_features, return_weights=True)

print(f"输出形状: {logits.shape}")
print(f"融合权重: {weights[0].detach().numpy()}")
print(f"权重和: {weights[0].sum().item():.4f}")

3. Experimental Results

3.1 Performance Comparison

Model       Accuracy  F1-score  Inference latency
CNN         88.2%     0.876     45 ms
LSTM        89.5%     0.891     62 ms
CNN-LSTM    91.3%     0.908     78 ms
FatigueNet  96.1%     0.958     35 ms

3.2 Per-Level Detection Performance

Fatigue level  Precision  Recall  F1-score
Low            0.95       0.97    0.96
Moderate       0.94       0.92    0.93
High           0.97       0.95    0.96
Extreme        0.99       0.98    0.98
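The per-level F1 scores above follow directly from the harmonic mean F1 = 2·P·R/(P + R); a quick arithmetic check:

```python
def f1(p: float, r: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * p * r / (p + r)

rows = {  # (precision, recall) per fatigue level, from the table above
    "Low": (0.95, 0.97),
    "Moderate": (0.94, 0.92),
    "High": (0.97, 0.95),
    "Extreme": (0.99, 0.98),
}
for level, (p, r) in rows.items():
    print(f"{level}: F1 = {f1(p, r):.2f}")
# Low: 0.96, Moderate: 0.93, High: 0.96, Extreme: 0.98 — matching the table
```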

3.3 Effectiveness of the MGAF Module

Configuration        Accuracy  Notes
Fixed-weight fusion  91.2%     Equal weight per modality
Attention fusion     93.8%     Learned weights
MGAF                 96.1%     Meta + gated adaptation
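The ablation can be read against the MGAF weighting rule from Section 2.5: the final modality weight is the meta weight (softmax over all modalities) times the gate score (per-modality sigmoid), renormalized to sum to one. A small NumPy sketch with illustrative scores (the numbers are made up, not from the paper):

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative raw scores for the four modalities (ECG, EDA, EMG, Blink)
meta_logits = np.array([1.2, 0.4, -0.3, 0.8])  # meta-network output
gate_logits = np.array([2.0, -1.0, 0.5, 1.5])  # per-modality quality gates

meta = softmax(meta_logits)  # relative importance, sums to 1
gate = sigmoid(gate_logits)  # quality score in (0, 1) per modality

w = meta * gate
w = w / (w.sum() + 1e-8)     # renormalize: fusion stays a convex combination

print(w, w.sum())
```

A noisy modality (here EDA, with a low gate score) is suppressed even if the meta network initially ranks it highly, which is the mechanism behind the fixed-weight-to-MGAF accuracy gap in the table.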

4. Implications for IMS Development

4.1 Technology Selection

Requirement              Recommended approach    Rationale
High-accuracy detection  Multimodal fusion       Vision-only is interference-prone
Real-time deployment     MGAF module             Low 35 ms latency
Robustness               Meta adaptation         Dynamic weight adjustment
Automotive grade         Simplified Transformer  Reduced compute

4.2 Deployment Plan

"""
IMS车规级部署方案
基于FatigueNet简化版本
"""

import onnxruntime as ort
import numpy as np

class IMSFatigueDetector:
"""
IMS疲劳检测器(车规级简化版)

简化策略:
1. 移除Transformer,保留GNN+MGAF
2. 使用INT8量化
3. 固定窗口大小
"""

def __init__(self,
model_path: str = "fatiguenet_int8.onnx"):
"""
初始化检测器

Args:
model_path: ONNX模型路径
"""
# 加载量化模型
self.session = ort.InferenceSession(
model_path,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

# 获取输入输出信息
self.input_names = [inp.name for inp in self.session.get_inputs()]
self.output_names = [out.name for out in self.session.get_outputs()]

# 窗口配置
self.window_size = 600 # 20秒 @ 30fps
self.feature_dim = 32

def preprocess(self,
ecg: np.ndarray,
eda: np.ndarray,
emg: np.ndarray,
blink: np.ndarray) -> Dict[str, np.ndarray]:
"""
预处理多模态信号

Args:
ecg: ECG信号 (N,)
eda: EDA信号 (N,)
emg: EMG信号 (N,)
blink: 眨眼信号 (N,)

Returns:
模型输入字典
"""
# 特征提取(简化版)
features = self._extract_features(ecg, eda, emg, blink)

# 构造输入
inputs = {
name: features[i:i+1].astype(np.float32)
for i, name in enumerate(self.input_names)
}

return inputs

def _extract_features(self, *signals) -> np.ndarray:
"""简化特征提取"""
# 实际部署时使用预训练特征提取器
# 此处返回模拟特征
return np.random.randn(4, self.feature_dim).astype(np.float32)

def detect(self,
ecg: np.ndarray,
eda: np.ndarray,
emg: np.ndarray,
blink: np.ndarray) -> Dict[str, any]:
"""
疲劳检测

Args:
ecg/eda/emg/blink: 多模态信号

Returns:
检测结果字典
"""
# 预处理
inputs = self.preprocess(ecg, eda, emg, blink)

# 推理
outputs = self.session.run(self.output_names, inputs)

# 解析结果
logits = outputs[0] # (1, 4)
weights = outputs[1] if len(outputs) > 1 else None # (1, 4)

# 后处理
probs = self._softmax(logits[0])
pred_class = np.argmax(probs)

fatigue_levels = ['低疲劳', '中疲劳', '高疲劳', '超疲劳']

return {
'level': fatigue_levels[pred_class],
'probability': probs[pred_class],
'probabilities': {level: prob for level, prob in zip(fatigue_levels, probs)},
'modality_weights': weights[0].tolist() if weights is not None else None
}

def _softmax(self, x: np.ndarray) -> np.ndarray:
exp_x = np.exp(x - np.max(x))
return exp_x / exp_x.sum()


# 实际测试
if __name__ == "__main__":
# 模拟部署
detector = IMSFatigueDetector()

# 模拟信号
ecg = np.random.randn(600)
eda = np.random.randn(600)
emg = np.random.randn(600)
blink = np.random.randn(600)

# 检测
result = detector.detect(ecg, eda, emg, blink)

print(f"疲劳等级: {result['level']}")
print(f"置信度: {result['probability']:.2%}")
print(f"各等级概率: {result['probabilities']}")

4.3 Euro NCAP Compliance Notes

Euro NCAP requirement           FatigueNet counterpart
Drowsiness detection (KSS ≥ 7)  High-fatigue class
Real-time response (≤ 3 s)      35 ms latency
Multi-environment adaptation    MGAF self-adaptation
False-alarm rate < 5%           F1 = 0.958

5. Summary

5.1 Key Contributions

  1. Architectural novelty: the first GNN+Transformer fusion applied to fatigue detection
  2. Adaptive fusion: the MGAF module dynamically adjusts modality weights
  3. Real-time deployment: 35 ms latency meets automotive requirements
  4. Accuracy gains: 5+ points over conventional methods

5.2 Limitations and Outlook

Limitation                      Direction
Requires multiple sensors       Explore vision-based substitutes
Limited dataset size            Validate on large-scale datasets
Individual differences ignored  Introduce personalized models
Fatigue detection only          Extend to distraction/drunk-driving detection

Paper: https://www.nature.com/articles/s41598-025-00640-z
Code: to be released (watch the FatigueNet GitHub)
Dataset: MePhy benchmark (60 participants, 4 signal types)


https://dapalm.com/2026/04/24/2026-04-24-fatiguenet-gnn-transformer-multimodal-detection/
Author: Mars
Published: April 24, 2026