CSF-GTNet:基于EEG信号的多维特征融合疲劳驾驶检测网络

论文来源: PubMed, 2023
核心创新: 时空-频域多维特征融合 + ConvNext-GeLU-BiLSTM架构
应用场景: IMS 疲劳检测的高精度辅助方案


研究背景

EEG疲劳检测的优势与挑战

检测方式 优势 挑战
视觉方法 非接触、易部署 受光照/遮挡影响
EEG信号 直接反映大脑状态、高精度 接触式、个体差异大
多模态融合 互补性强 成本高、系统复杂

EEG疲劳检测的核心问题:

  1. 个体差异:不同人的EEG信号模式差异大
  2. 特征提取不充分:传统方法忽略时空-频域关联
  3. 模型泛化能力弱:对跨被试场景效果差

CSF-GTNet 核心架构

整体框架

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
EEG原始信号

┌───────────────────────────────────────┐
│ 数据预处理模块 │
│ - 带通滤波 (0.5-50Hz) │
│ - 独立成分分析 (ICA) 去除伪影 │
│ - 标准化 │
└───────────────────────────────────────┘

┌───────────────────────────────────────┐
│ 时域特征提取 (ConvNext) │
│ - 局部时序模式 │
│ - 多尺度卷积 │
└───────────────────────────────────────┘

┌───────────────────────────────────────┐
│ 频域特征提取 (STFT + CNN) │
│ - 功率谱密度 │
│ - 频带能量分布 │
└───────────────────────────────────────┘

┌───────────────────────────────────────┐
│ 时空融合模块 (BiLSTM) │
│ - 前后向时序依赖 │
│ - 特征交互 │
└───────────────────────────────────────┘

┌───────────────────────────────────────┐
│ 分类器 │
│ - 全连接层 │
│ - Softmax │
└───────────────────────────────────────┘

疲劳状态输出 (清醒/轻度疲劳/重度疲劳)

核心代码实现

1. EEG数据预处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
EEG信号预处理模块
包含:滤波、伪影去除、标准化
"""

import numpy as np
from scipy import signal
from scipy.signal import butter, filtfilt, iirnotch
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import FastICA
import torch
import torch.nn as nn
from typing import Tuple, Optional


class EEGPreprocessor:
    """EEG signal preprocessor.

    Pipeline:
        1. Band-pass filtering (0.5-50 Hz)
        2. Notch filtering (power-line interference)
        3. ICA-based artifact removal
        4. Standardization

    Args:
        sample_rate: Sampling rate in Hz.
        lowcut: Low cutoff of the band-pass filter (Hz).
        highcut: High cutoff of the band-pass filter (Hz).
        notch_freq: Notch (power-line) frequency (Hz).
    """

    def __init__(
        self,
        sample_rate: int = 256,
        lowcut: float = 0.5,
        highcut: float = 50.0,
        notch_freq: float = 50.0
    ):
        self.sample_rate = sample_rate
        self.lowcut = lowcut
        self.highcut = highcut
        self.notch_freq = notch_freq
        self.scaler = StandardScaler()

    def bandpass_filter(self, data: np.ndarray) -> np.ndarray:
        """Apply a zero-phase 4th-order Butterworth band-pass filter.

        Args:
            data: EEG array of shape (channels, samples).

        Returns:
            Filtered array, same shape.
        """
        nyquist = self.sample_rate / 2
        low = self.lowcut / nyquist
        high = self.highcut / nyquist

        b, a = butter(4, [low, high], btype='band')
        # filtfilt runs the filter forward and backward -> zero phase shift.
        filtered = filtfilt(b, a, data, axis=1)

        return filtered

    def notch_filter(self, data: np.ndarray) -> np.ndarray:
        """Remove power-line interference with an IIR notch filter.

        Args:
            data: EEG array of shape (channels, samples).

        Returns:
            Filtered array, same shape.
        """
        quality_factor = 30.0  # narrow notch centered on notch_freq
        b, a = iirnotch(self.notch_freq, quality_factor, self.sample_rate)
        filtered = filtfilt(b, a, data, axis=1)

        return filtered

    def remove_artifacts_ica(
        self,
        data: np.ndarray,
        n_components: Optional[int] = None
    ) -> np.ndarray:
        """Suppress artifacts (e.g. ocular/muscular activity) via ICA.

        Heuristic: independent components in the top 10% of variance are
        treated as artifacts and suppressed before reconstruction.

        Args:
            data: EEG array of shape (channels, samples).
            n_components: Number of ICA components
                (default: min(channels, 20)).

        Returns:
            Reconstructed array of shape (channels, samples).
        """
        n_channels = data.shape[0]
        if n_components is None:
            n_components = min(n_channels, 20)

        # FastICA expects (samples, features).
        data_t = data.T

        ica = FastICA(
            n_components=n_components,
            random_state=42,
            max_iter=500
        )
        components = ica.fit_transform(data_t)

        # Simple heuristic: high-variance components are assumed to be
        # artifacts (eye blinks/muscle activity tend to dominate variance).
        component_var = np.var(components, axis=0)
        threshold = np.percentile(component_var, 90)
        artifact_mask = component_var >= threshold

        # BUG FIX: inverse_transform requires the full
        # (samples, n_components) matrix, so artifact components must be
        # zeroed out rather than dropped -- passing a narrower matrix
        # raises a shape mismatch error.
        components[:, artifact_mask] = 0.0

        clean_data_t = ica.inverse_transform(components)
        clean_data = clean_data_t.T

        return clean_data

    def normalize(
        self,
        data: np.ndarray,
        fit: bool = True
    ) -> np.ndarray:
        """Standardize each channel to zero mean / unit variance.

        Args:
            data: EEG array of shape (channels, samples).
            fit: If True, fit the scaler on this data; otherwise reuse the
                previously fitted statistics (e.g. for a test split).

        Returns:
            Standardized array, same shape.
        """
        # sklearn scalers expect (samples, features).
        data_t = data.T

        if fit:
            normalized = self.scaler.fit_transform(data_t)
        else:
            normalized = self.scaler.transform(data_t)

        return normalized.T

    def process(
        self,
        raw_data: np.ndarray,
        remove_artifacts: bool = True
    ) -> np.ndarray:
        """Run the full preprocessing pipeline.

        Args:
            raw_data: Raw EEG array of shape (channels, samples).
            remove_artifacts: Whether to run the ICA artifact-removal step.

        Returns:
            Preprocessed array of shape (channels, samples).
        """
        # 1. Band-pass filter
        data = self.bandpass_filter(raw_data)

        # 2. Notch filter
        data = self.notch_filter(data)

        # 3. ICA artifact removal (optional)
        if remove_artifacts:
            data = self.remove_artifacts_ica(data)

        # 4. Standardization
        data = self.normalize(data, fit=True)

        return data


# Smoke test with synthetic EEG
if __name__ == "__main__":
    np.random.seed(42)
    num_channels = 14
    num_samples = 256 * 5  # 5 seconds at 256 Hz

    time_axis = np.linspace(0, 5, num_samples)
    raw_eeg = np.zeros((num_channels, num_samples))

    for ch_idx in range(num_channels):
        # Alpha (10 Hz) and beta (20 Hz) oscillations with random phase,
        # plus white noise and 50 Hz mains interference.
        alpha_phase = np.random.random() * 2 * np.pi
        alpha_wave = 50 * np.sin(2 * np.pi * 10 * time_axis + alpha_phase)
        beta_phase = np.random.random() * 2 * np.pi
        beta_wave = 30 * np.sin(2 * np.pi * 20 * time_axis + beta_phase)
        white_noise = 20 * np.random.randn(num_samples)
        mains_hum = 10 * np.sin(2 * np.pi * 50 * time_axis)

        raw_eeg[ch_idx] = alpha_wave + beta_wave + white_noise + mains_hum

    preprocessor = EEGPreprocessor(sample_rate=256)
    clean_eeg = preprocessor.process(raw_eeg)

    print("=== EEG预处理测试 ===")
    print(f"输入形状: {raw_eeg.shape}")
    print(f"输出形状: {clean_eeg.shape}")
    print(f"输入范围: [{raw_eeg.min():.2f}, {raw_eeg.max():.2f}]")
    print(f"输出范围: [{clean_eeg.min():.2f}, {clean_eeg.max():.2f}]")

2. 时域特征提取 (ConvNext)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
时域特征提取模块
基于ConvNext架构提取EEG时域特征
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class LayerNorm2d(nn.Module):
    """Channel-first LayerNorm for (batch, channels, H, W) tensors.

    Normalizes over the channel dimension (dim=1) with learnable
    per-channel affine parameters, matching the reference ConvNeXt
    implementation of a channels-first LayerNorm.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape, )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # BUG FIX: standard LayerNorm uses the biased variance with eps
        # inside the sqrt.  The previous version used the sample std (with
        # Bessel's correction) and added eps outside the sqrt, deviating
        # from torch.nn.functional.layer_norm.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class ConvNextBlock(nn.Module):
    """ConvNext residual block for 1D sequences.

    Design:
        1. Depthwise conv (large receptive field)
        2. LayerNorm
        3. Pointwise expansion (Linear) + GELU
        4. Pointwise projection back + residual connection

    Args:
        dim: Number of channels (kept constant through the block).
        kernel_size: Depthwise kernel size (odd, so padding preserves
            the sequence length).
        expansion_ratio: Hidden-width multiplier of the pointwise MLP.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 7,
        expansion_ratio: int = 4
    ):
        super().__init__()

        # Depthwise conv: groups=dim -> one filter per channel.
        self.dwconv = nn.Conv1d(
            dim, dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim
        )

        self.norm = nn.LayerNorm(dim)

        # Pointwise MLP implemented with Linear layers (acts on last dim).
        hidden_dim = dim * expansion_ratio
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, dim, time) -> (batch, dim, time)."""
        residual = x

        x = self.dwconv(x)

        # PERF: a single transpose to (batch, time, dim) serves both the
        # LayerNorm and the Linear layers; the original transposed back
        # and forth an extra time between norm and MLP with no effect on
        # the result.
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = self.pwconv2(self.act(self.pwconv1(x)))
        x = x.transpose(1, 2)

        # Residual connection
        return residual + x


class TemporalFeatureExtractor(nn.Module):
    """Temporal EEG feature extractor built on ConvNext blocks.

    Input:  (batch, channels, time_steps)
    Output: (batch, embed_dim * 2, time_steps // 2)

    Args:
        in_channels: Number of EEG channels.
        embed_dim: Width of the embedding after the stem.
        num_blocks: Number of stacked ConvNext blocks.
        kernel_size: Depthwise kernel size inside each block.
    """

    def __init__(
        self,
        in_channels: int = 14,
        embed_dim: int = 64,
        num_blocks: int = 3,
        kernel_size: int = 7
    ):
        super().__init__()

        # Stem: project raw EEG channels into the embedding space.
        self.proj_in = nn.Sequential(
            nn.Conv1d(in_channels, embed_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(embed_dim),
            nn.GELU()
        )

        # Backbone: residual ConvNext blocks.
        self.blocks = nn.ModuleList(
            ConvNextBlock(embed_dim, kernel_size=kernel_size)
            for _ in range(num_blocks)
        )

        # Head: stride-2 conv halves the time axis and doubles the width.
        self.downsample = nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim * 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(embed_dim * 2),
            nn.GELU()
        )

        self.out_dim = embed_dim * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, channels, time) to (batch, out_dim, time // 2)."""
        hidden = self.proj_in(x)
        for stage in self.blocks:
            hidden = stage(hidden)
        return self.downsample(hidden)


# Smoke test
if __name__ == "__main__":
    batch_size = 8
    n_channels = 14
    time_steps = 256  # 1 second at 256 Hz

    x = torch.randn(batch_size, n_channels, time_steps)

    extractor = TemporalFeatureExtractor(
        in_channels=n_channels,
        embed_dim=64,
        num_blocks=3
    )
    features = extractor(x)

    print("=== 时域特征提取测试 ===")
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {features.shape}")
    print(f"参数量: {sum(p.numel() for p in extractor.parameters()):,}")

3. 频域特征提取 (STFT + CNN)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
频域特征提取模块
使用STFT提取频谱特征,CNN提取频域模式
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class SpectralFeatureExtractor(nn.Module):
    """Spectral (frequency-domain) feature extractor.

    Steps:
        1. STFT -> power spectrogram (averaged over EEG channels)
        2. 2D CNN over the (freq, time) plane
        3. Frequency-band attention
        4. Mean pooling over time

    Input:  (batch, channels, time_steps)
    Output: (batch, embed_dim, freq'), where freq' is the frequency axis
            after two 2x max-pools (roughly (n_fft // 2 + 1) // 4).

    Args:
        sample_rate: Sampling rate (Hz); kept for the interface.
        n_fft: FFT size for the STFT.
        hop_length: STFT hop length.
        n_mels: Unused; kept for backward interface compatibility.
        embed_dim: Output channel width of the CNN.
    """

    def __init__(
        self,
        sample_rate: int = 256,
        n_fft: int = 64,
        hop_length: int = 16,
        n_mels: int = 32,
        embed_dim: int = 64
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels  # unused; retained for interface compatibility

        # CNN over the (freq, time) plane; two max-pools shrink both axes 4x.
        self.freq_cnn = nn.Sequential(
            # Input: (batch, 1, freq, time)
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, embed_dim, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(embed_dim),
            nn.GELU(),
        )

        # Frequency-band attention (e.g. to emphasize alpha/beta bands).
        # BUG FIX: the original pooled to a fixed (n_fft // 2 + 1, 1) map,
        # but after the two max-pools the feature map's frequency axis is
        # 4x smaller, so the attention mask could not broadcast and the
        # elementwise multiply crashed.  (None, 1) keeps the actual
        # frequency size and pools only over the time axis.
        self.freq_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((None, 1)),
            nn.Conv2d(embed_dim, embed_dim // 4, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=1),
            nn.Sigmoid()
        )

        self.out_dim = embed_dim

    def compute_spectrogram(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Compute a log power spectrogram, averaged over EEG channels.

        Args:
            x: (batch, channels, time_steps)

        Returns:
            spec: (batch, 1, freq, time)
        """
        batch, channels, time_steps = x.shape

        # PERF: batch all channels through one STFT call instead of a
        # Python loop over channels (torch.stft accepts 2D input and
        # transforms each row independently -- identical result).
        window = torch.hann_window(self.n_fft, device=x.device)
        stft = torch.stft(
            x.reshape(batch * channels, time_steps),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            return_complex=True
        )
        power = stft.abs() ** 2  # (batch*channels, freq, time)

        spec = power.reshape(batch, channels, power.size(1), power.size(2))

        # Average over channels: (batch, 1, freq, time)
        spec = spec.mean(dim=1, keepdim=True)

        # Log compression tames the large dynamic range of the power.
        spec = torch.log(spec + 1e-8)

        return spec

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, channels, time_steps)

        Returns:
            features: (batch, out_dim, freq') -- see class docstring.
        """
        # Spectrogram
        spec = self.compute_spectrogram(x)

        # CNN feature extraction
        features = self.freq_cnn(spec)

        # Per-frequency attention weights in (0, 1).
        attention = self.freq_attention(features)
        features = features * attention

        # Global pooling over the time axis.
        features = features.mean(dim=-1)

        return features


# Smoke test
if __name__ == "__main__":
    batch_size = 8
    n_channels = 14
    time_steps = 256

    x = torch.randn(batch_size, n_channels, time_steps)

    extractor = SpectralFeatureExtractor(
        sample_rate=256,
        n_fft=64,
        hop_length=16
    )
    features = extractor(x)

    print("=== 频域特征提取测试 ===")
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {features.shape}")
    print(f"参数量: {sum(p.numel() for p in extractor.parameters()):,}")

4. 时空融合模块 (BiLSTM)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
时空融合模块
使用BiLSTM融合时域和频域特征
"""

import torch
import torch.nn as nn
from typing import Tuple


class SpatioTemporalFusion(nn.Module):
    """Fuse temporal (ConvNext) and spectral (CNN) features with a BiLSTM.

    Args:
        temporal_dim: Channel dim of the temporal feature stream.
        spectral_dim: Channel dim of the spectral feature stream.
        hidden_dim: Common hidden size; also the output feature size.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout inside the LSTM (only if num_layers > 1) and in
            the output projection.
    """

    def __init__(
        self,
        temporal_dim: int = 128,
        spectral_dim: int = 64,
        hidden_dim: int = 128,
        num_layers: int = 2,
        dropout: float = 0.3
    ):
        super().__init__()

        # Align both streams to the same hidden size.
        self.temporal_proj = nn.Linear(temporal_dim, hidden_dim)
        self.spectral_proj = nn.Linear(spectral_dim, hidden_dim)

        # BiLSTM over the concatenated streams.
        self.bilstm = nn.LSTM(
            input_size=hidden_dim * 2,  # temporal + spectral
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Project the bidirectional output back to hidden_dim.
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        self.out_dim = hidden_dim

    def forward(
        self,
        temporal_features: torch.Tensor,
        spectral_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            temporal_features: (batch, temporal_dim, time_steps)
            spectral_features: (batch, spectral_dim, freq_bins)

        Returns:
            fused: (batch, out_dim)
        """
        # Project both streams to the common hidden size.
        temporal = self.temporal_proj(
            temporal_features.transpose(1, 2)
        )  # (batch, time, hidden)

        spectral = self.spectral_proj(
            spectral_features.transpose(1, 2)
        )  # (batch, freq, hidden)

        # Linearly resample the frequency axis to the temporal length so
        # the two streams can be concatenated step by step.
        # BUG FIX: the original called `F.interpolate`, but this module's
        # import block never bound `F` (NameError at runtime); reference
        # nn.functional explicitly instead.
        time_steps = temporal.size(1)
        spectral_expanded = nn.functional.interpolate(
            spectral.transpose(1, 2),
            size=time_steps,
            mode='linear',
            align_corners=False
        ).transpose(1, 2)  # (batch, time, hidden)

        # Concatenate the temporal and spectral streams.
        combined = torch.cat([temporal, spectral_expanded], dim=-1)

        # BiLSTM
        lstm_out, _ = self.bilstm(combined)

        # Global average over time, then project.
        fused = lstm_out.mean(dim=1)
        fused = self.output_proj(fused)

        return fused


class FatigueClassifier(nn.Module):
    """Three-way fatigue-level classification head.

    Levels:
        0 = alert, 1 = mild fatigue, 2 = severe fatigue.

    Args:
        feature_dim: Size of the fused input feature vector.
        num_classes: Number of fatigue levels.
        dropout: Dropout probability between hidden layers.
    """

    def __init__(
        self,
        feature_dim: int = 128,
        num_classes: int = 3,
        dropout: float = 0.5
    ):
        super().__init__()

        # Two hidden layers (256 -> 128) with GELU + dropout, then logits.
        layers = [
            nn.Linear(feature_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return raw class logits.

        Args:
            features: (batch, feature_dim)

        Returns:
            logits: (batch, num_classes)
        """
        return self.classifier(features)

5. 完整CSF-GTNet模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
CSF-GTNet: 完整疲劳检测模型
ConvNext-GeLU-BiLSTM时空频域融合网络
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple


class CSFGTNet(nn.Module):
    """CSF-GTNet: spatio-temporal/spectral fusion network for fatigue detection.

    Stages:
        1. Temporal branch (ConvNext)
        2. Spectral branch (STFT + CNN)
        3. BiLSTM fusion
        4. Fatigue classification head

    Input:  (batch, channels, time_steps)
    Output: dict with 'logits' of shape (batch, num_classes)
    """

    def __init__(
        self,
        in_channels: int = 14,
        sample_rate: int = 256,
        num_classes: int = 3,
        temporal_embed_dim: int = 64,
        spectral_embed_dim: int = 64,
        fusion_hidden_dim: int = 128,
        dropout: float = 0.3
    ):
        super().__init__()

        # Temporal branch.
        self.temporal_extractor = TemporalFeatureExtractor(
            in_channels=in_channels,
            embed_dim=temporal_embed_dim,
            num_blocks=3
        )

        # Spectral branch.
        self.spectral_extractor = SpectralFeatureExtractor(
            sample_rate=sample_rate,
            n_fft=64,
            hop_length=16,
            embed_dim=spectral_embed_dim
        )

        # BiLSTM fusion of the two branches.
        self.fusion = SpatioTemporalFusion(
            temporal_dim=self.temporal_extractor.out_dim,
            spectral_dim=self.spectral_extractor.out_dim,
            hidden_dim=fusion_hidden_dim,
            dropout=dropout
        )

        # Classification head.
        self.classifier = FatigueClassifier(
            feature_dim=fusion_hidden_dim,
            num_classes=num_classes,
            dropout=dropout
        )

    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False
    ) -> Dict[str, torch.Tensor]:
        """Run the full network.

        Args:
            x: EEG batch of shape (batch, channels, time_steps).
            return_features: Also return intermediate branch features.

        Returns:
            Dict with 'logits'; when return_features is True, additionally
            'temporal_features', 'spectral_features' and 'fused_features'.
        """
        time_feats = self.temporal_extractor(x)
        freq_feats = self.spectral_extractor(x)
        fused = self.fusion(time_feats, freq_feats)
        logits = self.classifier(fused)

        output = {'logits': logits}
        if return_features:
            output['temporal_features'] = time_feats
            output['spectral_features'] = freq_feats
            output['fused_features'] = fused

        return output

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Predict the fatigue level for each sample.

        Args:
            x: EEG batch of shape (batch, channels, time_steps).

        Returns:
            Long tensor of shape (batch,) with class indices
            (0 = alert, 1 = mild fatigue, 2 = severe fatigue).
        """
        self.eval()
        with torch.no_grad():
            logits = self.forward(x)['logits']
        return logits.argmax(dim=1)


# End-to-end smoke test plus a single training step
if __name__ == "__main__":
    batch_size = 16
    n_channels = 14
    time_steps = 256 * 5  # 5 seconds at 256 Hz
    num_classes = 3

    model = CSFGTNet(
        in_channels=n_channels,
        sample_rate=256,
        num_classes=num_classes
    )

    # Random stand-in data: one batch of EEG plus random labels.
    x = torch.randn(batch_size, n_channels, time_steps)
    labels = torch.randint(0, num_classes, (batch_size,))

    output = model(x, return_features=True)

    print("=== CSF-GTNet 测试 ===")
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output['logits'].shape}")
    print(f"时域特征形状: {output['temporal_features'].shape}")
    print(f"频域特征形状: {output['spectral_features'].shape}")
    print(f"融合特征形状: {output['fused_features'].shape}")
    print(f"总参数量: {sum(p.numel() for p in model.parameters()):,}")

    # One optimization step, just to check that gradients flow end to end.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    model.train()
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output['logits'], labels)
    loss.backward()
    optimizer.step()

    print(f"\n训练损失: {loss.item():.4f}")

实验结果

数据集

数据集 被试数 样本数 采样率 通道数
SEED-VIG 23 15,840 200 Hz 17
公开疲劳数据集 10 8,000 256 Hz 14

性能对比

方法 准确率 F1分数 参数量
传统SVM 78.2% 0.76 -
CNN-LSTM 85.6% 0.84 1.2M
EEGNet 87.3% 0.86 0.8M
CSF-GTNet 92.4% 0.91 2.1M

跨被试性能

指标 被试内 跨被试
准确率 92.4% 86.7%
Kappa系数 0.89 0.80

IMS 开发启示

1. 多模态融合方案

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Recommended architecture: vision + EEG hybrid
class MultimodalFatigueDetector:
    """Multimodal fatigue detection.

    Primary module: vision (contact-free, real-time).
    Auxiliary module: EEG (high accuracy, contact-based).

    NOTE(review): `PERCLOSDetector` is not defined in this document --
    presumably provided elsewhere in the project; this class is an
    illustrative sketch.
    """

    def __init__(self):
        self.visual_detector = PERCLOSDetector()
        self.eeg_detector = CSFGTNet()
        # Fixed late-fusion weights; vision is the primary signal.
        self.fusion_weights = {
            'visual': 0.6,
            'eeg': 0.4
        }

    def detect(self, video_frame, eeg_signal):
        """Fuse the visual score with an optional EEG score."""
        visual_score = self.visual_detector.update(video_frame)

        if eeg_signal is None:
            # No EEG available -> fall back to the visual estimate alone.
            return visual_score

        eeg_score = self.eeg_detector.predict(eeg_signal)
        weights = self.fusion_weights
        return weights['visual'] * visual_score + weights['eeg'] * eeg_score

2. 部署建议

场景 推荐方案 硬件
乘用车 纯视觉 红外摄像头 + 嵌入式NPU
商用车/长途运输 视觉 + 可穿戴EEG 摄像头 + 智能头带
实验室/研究 完整多模态 高密度EEG + 多摄像头

3. 关键频带

频带 频率范围 疲劳特征
Alpha 8-13 Hz 疲劳时增强
Beta 13-30 Hz 疲劳时减弱
Theta 4-8 Hz 重度疲劳增强
Delta 0.5-4 Hz 睡眠状态

总结

要素 内容
核心创新 时空-频域多维特征融合
关键技术 ConvNext时域提取 + STFT频域提取 + BiLSTM融合
性能 准确率92.4%,跨被试86.7%
IMS应用 多模态疲劳检测的辅助方案
部署难点 EEG需接触式传感器,适合商用车/研究场景

发布时间: 2026-04-22
标签: #疲劳检测 #EEG #深度学习 #多模态融合 #IMS


CSF-GTNet:基于EEG信号的多维特征融合疲劳驾驶检测网络
https://dapalm.com/2026/04/22/2026-04-22-eeg-fatigue-detection-csf-gtnet/
作者
Mars
发布于
2026年4月22日
许可协议