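T-SCAF Multi-Task Learning for DMS: A Unified Framework for Fatigue Detection and Face Recognition (Paper Walkthrough and Code Reproduction)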

Paper Information

  • Title: Multi-Task Learning for Fatigue Detection and Face Recognition of Drivers via Tree-Style Space-Channel Attention Fusion Network
  • Authors: Gao Z., Chen X., Li N., et al.
  • Affiliation: Huaqiao University
  • Published: May 2024
  • Link: arXiv:2405.07845

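Core Idea

T-SCAF proposes a tree-style multi-task learning architecture: a shared backbone feeds task-specific branches built on space-channel attention fusion, so a single network performs driver fatigue detection and face recognition at the same time.

Key contributions:

  1. A tree-style multi-task modeling approach
  2. LASE-Net, a space-channel attention fusion module
  3. A technique for training a multi-task model from single-task datasets
  4. State-of-the-art (SOTA) performance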

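Background

Limitations of Conventional Approaches

Parallel multi-model pipeline:

Input image ─┬─→ Fatigue detection model ─→ Fatigue state
             └─→ Face recognition model   ─→ Driver ID

Problems:

  • The two models independently extract similar features from the same face image
  • Computation is wasted on duplicated work
  • Deployment cost is high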

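Advantages of Multi-Task Learning

Input image ─→ Shared backbone ─┬─→ Fatigue detection branch ─→ Fatigue state
                                └─→ Face recognition branch   ─→ Driver ID

Advantages:

  • Low-level feature extraction is shared
  • Fewer parameters
  • Faster inference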
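The parameter saving can be made concrete with a little arithmetic. The sketch below assumes a ResNet-18 trunk of about 11.2M parameters (the standard torchvision model without its final classifier) and a two-layer classification head; both are illustrative assumptions, not figures from the paper.

```python
# Back-of-the-envelope comparison: two independent models vs. one shared trunk.
# BACKBONE_PARAMS and the head layout are illustrative assumptions.

BACKBONE_PARAMS = 11_176_512  # ResNet-18 without its 1000-way classifier


def head_params(in_features: int, hidden: int, num_classes: int) -> int:
    """Parameters of Linear(in, hidden) -> Linear(hidden, classes), with biases."""
    return (in_features * hidden + hidden) + (hidden * num_classes + num_classes)


fatigue_head = head_params(512, 256, 2)     # fatigue / normal
identity_head = head_params(512, 256, 100)  # 100 enrolled drivers (assumed)

separate = 2 * BACKBONE_PARAMS + fatigue_head + identity_head
shared = BACKBONE_PARAMS + fatigue_head + identity_head

print(f"two independent models: {separate / 1e6:.2f}M parameters")
print(f"shared backbone:        {shared / 1e6:.2f}M parameters")
print(f"saving:                 {100 * (1 - shared / separate):.1f}%")
```

Because the heads are tiny compared with the trunk, sharing the trunk removes close to half of the total parameters in this toy setting.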

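Method

Tree-Style Multi-Task Architecture

                    Input image
                         ↓
             ┌───────────────────────┐
             │ Shared backbone (CNN) │  ← root node
             └───────────────────────┘
                         ↓
            ┌─────────────────────────┐
            │ LASE-Net branch modules │
            └─────────────────────────┘
            ↓                          ↓
┌────────────────────────┐ ┌───────────────────────┐
│ Fatigue detection head │ │ Face recognition head │
└────────────────────────┘ └───────────────────────┘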

1. Shared Backbone

import torch
import torch.nn as nn
import torchvision.models as models


class SharedBackbone(nn.Module):
    """
    Shared backbone network.

    Extracts face features with a ResNet. The global average pooling layer
    is kept, so the output is already pooled to (B, D, 1, 1); downstream
    spatial attention therefore sees a 1x1 map.
    """

    def __init__(
        self,
        backbone_name: str = 'resnet18',
        pretrained: bool = True
    ):
        super().__init__()

        # Load a ResNet, optionally with ImageNet weights. The `weights=`
        # argument (torchvision >= 0.13) replaces the deprecated `pretrained=`.
        if backbone_name == 'resnet18':
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet18(weights=weights)
            self.feature_dim = 512
        elif backbone_name == 'resnet50':
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
            self.feature_dim = 2048
        else:
            raise ValueError(f"Unknown backbone: {backbone_name}")

        # Drop only the final fully connected classifier
        self.features = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: input images, shape=(B, 3, H, W)

        Returns:
            features: shared features, shape=(B, D, 1, 1)
        """
        return self.features(x)


# Test
if __name__ == "__main__":
    backbone = SharedBackbone('resnet18')
    x = torch.randn(4, 3, 224, 224)
    features = backbone(x)
    print(f"Feature shape: {features.shape}")  # (4, 512, 1, 1)

2. LASE-Net: Space-Channel Attention Fusion

Core structure:

Input features
      ↓
┌─────────────────────────────────────────────────┐
│                 LASE-Net module                 │
├─────────────────────────────────────────────────┤
│ LANet (channel attention) ─┐                    │
│ SENet (spatial attention) ─┼─→ Fusion ─→ Output │
│ Residual connection       ─┘                    │
└─────────────────────────────────────────────────┘
import torch
import torch.nn as nn


class LANet(nn.Module):
    """
    Channel attention module (Local Attention Network).

    Re-weights features along the channel dimension.
    """

    def __init__(self, in_channels: int, reduction: int = 16):
        super().__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared bottleneck MLP. It has no final Sigmoid: the sigmoid is
        # applied once in forward(), after the two branches are summed.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction),
            nn.ReLU(),
            nn.Linear(in_channels // reduction, in_channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input features, shape=(B, C, H, W)

        Returns:
            channel_attended: features weighted by channel attention
        """
        B, C, _, _ = x.shape

        # Average-pooling branch
        avg_out = self.avg_pool(x).view(B, C)
        avg_out = self.fc(avg_out)

        # Max-pooling branch
        max_out = self.max_pool(x).view(B, C)
        max_out = self.fc(max_out)

        # Fuse the two branches into per-channel weights in (0, 1)
        channel_weights = torch.sigmoid(avg_out + max_out).view(B, C, 1, 1)

        return x * channel_weights


class SENet(nn.Module):
    """
    Spatial attention module (Spatial Enhancement Network).

    Re-weights features by spatial position, in the style of CBAM's spatial
    attention. Not to be confused with Squeeze-and-Excitation Networks
    (reference 2), which are a channel attention design.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input features, shape=(B, C, H, W)

        Returns:
            spatial_attended: features weighted by spatial attention
        """
        # Mean and max over the channel dimension
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)

        # Stack the two maps: (B, 2, H, W)
        concat = torch.cat([avg_out, max_out], dim=1)

        # Spatial attention weights: (B, 1, H, W)
        spatial_weights = self.conv(concat)

        return x * spatial_weights


class LASENet(nn.Module):
    """
    LASE-Net: space-channel attention fusion module.

    Combines the LANet and SENet branches.
    """

    def __init__(
        self,
        in_channels: int,
        reduction: int = 16,
        kernel_size: int = 7
    ):
        super().__init__()

        # Channel attention branch
        self.lanet = LANet(in_channels, reduction)

        # Spatial attention branch
        self.senet = SENet(kernel_size)

        # Fusion layer: 1x1 conv that maps 2C channels back to C
        self.fusion = nn.Conv2d(in_channels * 2, in_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: input features, shape=(B, C, H, W)

        Returns:
            fused: fused features, same shape as x
        """
        # Channel attention
        channel_attended = self.lanet(x)

        # Spatial attention
        spatial_attended = self.senet(x)

        # Fusion
        concat = torch.cat([channel_attended, spatial_attended], dim=1)
        fused = self.fusion(concat)

        # Residual connection
        return fused + x


# Test
if __name__ == "__main__":
    lase_net = LASENet(512)
    x = torch.randn(4, 512, 7, 7)
    output = lase_net(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
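To see where the cost of such a block sits, the parameters of the three learnable pieces (channel MLP, spatial convolution, 1×1 fusion convolution) can be counted by hand. The helper below is a sketch that mirrors the layer shapes of the `LASENet` code in this section; the function name is made up for illustration.

```python
def lase_net_params(channels: int, reduction: int = 16, kernel_size: int = 7) -> dict:
    """Count learnable parameters of a LASENet-style block, piece by piece."""
    hidden = channels // reduction
    # Channel branch: shared MLP Linear(C, C/r) -> Linear(C/r, C), with biases
    channel_mlp = (channels * hidden + hidden) + (hidden * channels + channels)
    # Spatial branch: Conv2d(2, 1, k) over the stacked [mean, max] maps, with bias
    spatial_conv = 2 * kernel_size * kernel_size + 1
    # Fusion: 1x1 Conv2d(2C, C), with bias
    fusion_conv = 2 * channels * channels + channels
    return {
        'channel_mlp': channel_mlp,
        'spatial_conv': spatial_conv,
        'fusion_conv': fusion_conv,
        'total': channel_mlp + spatial_conv + fusion_conv,
    }


counts = lase_net_params(512)
for name, value in counts.items():
    print(f"{name:>12}: {value:,}")
```

For C = 512 the fusion convolution accounts for about 94% of the block (524,800 of 558,211 parameters); the two attention branches themselves are almost free.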

3. Tree-Style Multi-Task Model

# Assumes SharedBackbone and LASENet from the sections above are in scope
import torch
import torch.nn as nn


class TaskHead(nn.Module):
    """
    Task head.

    Performs the classification/regression for one specific task.
    """

    def __init__(
        self,
        in_features: int,
        num_classes: int,
        task_type: str = 'classification'
    ):
        super().__init__()

        self.task_type = task_type

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(x)


class TSCAFModel(nn.Module):
    """
    T-SCAF: tree-style space-channel attention fusion model.

    Multi-task learning framework:
    - Task 1: fatigue detection
    - Task 2: face recognition
    """

    def __init__(
        self,
        backbone_name: str = 'resnet18',
        num_fatigue_classes: int = 2,     # fatigue / normal
        num_identity_classes: int = 100,  # number of drivers
        pretrained: bool = True
    ):
        super().__init__()

        # Shared backbone
        self.backbone = SharedBackbone(backbone_name, pretrained)
        feature_dim = self.backbone.feature_dim

        # Fatigue detection branch
        self.fatigue_lase = LASENet(feature_dim)
        self.fatigue_head = TaskHead(
            feature_dim, num_fatigue_classes, 'classification'
        )

        # Face recognition branch
        self.identity_lase = LASENet(feature_dim)
        self.identity_head = TaskHead(
            feature_dim, num_identity_classes, 'classification'
        )

    def forward(
        self,
        x: torch.Tensor,
        task: str = 'both'
    ) -> dict:
        """
        Forward pass.

        Args:
            x: input images, shape=(B, 3, H, W)
            task: 'fatigue' / 'identity' / 'both'

        Returns:
            outputs: {
                'fatigue_logits': tensor,
                'identity_logits': tensor
            }
            (only the keys of the requested task(s) are present)
        """
        # Shared feature extraction
        shared_features = self.backbone(x)

        outputs = {}

        # Fatigue detection
        if task in ['fatigue', 'both']:
            fatigue_features = self.fatigue_lase(shared_features)
            outputs['fatigue_logits'] = self.fatigue_head(fatigue_features)

        # Face recognition
        if task in ['identity', 'both']:
            identity_features = self.identity_lase(shared_features)
            outputs['identity_logits'] = self.identity_head(identity_features)

        return outputs


# Test
if __name__ == "__main__":
    model = TSCAFModel(
        backbone_name='resnet18',
        num_fatigue_classes=2,
        num_identity_classes=50
    )

    x = torch.randn(4, 3, 224, 224)
    outputs = model(x, task='both')

    print(f"Fatigue detection output: {outputs['fatigue_logits'].shape}")
    print(f"Face recognition output: {outputs['identity_logits'].shape}")
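The `task` argument does more than skip computation: a branch that is not executed receives no gradient, which is what makes training from single-task batches possible. The toy model below (a hypothetical `TinyTree` with linear layers standing in for the backbone and heads) is a small sketch of that behavior.

```python
import torch
import torch.nn as nn


class TinyTree(nn.Module):
    """Toy stand-in for the tree model: one shared trunk, two task heads."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 4)
        self.fatigue_head = nn.Linear(4, 2)
        self.identity_head = nn.Linear(4, 5)

    def forward(self, x: torch.Tensor, task: str = 'both') -> dict:
        shared = torch.relu(self.trunk(x))
        outputs = {}
        if task in ('fatigue', 'both'):
            outputs['fatigue_logits'] = self.fatigue_head(shared)
        if task in ('identity', 'both'):
            outputs['identity_logits'] = self.identity_head(shared)
        return outputs


torch.manual_seed(0)
model = TinyTree()
x = torch.randn(6, 8)
labels = torch.randint(0, 2, (6,))

# Run and back-propagate through the fatigue branch only
outputs = model(x, task='fatigue')
loss = nn.functional.cross_entropy(outputs['fatigue_logits'], labels)
loss.backward()

trunk_has_grad = model.trunk.weight.grad is not None
fatigue_has_grad = model.fatigue_head.weight.grad is not None
identity_has_grad = model.identity_head.weight.grad is not None
print(trunk_has_grad, fatigue_has_grad, identity_has_grad)  # True True False
```

The shared trunk is updated by both tasks, while each head is only ever updated by batches of its own task.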

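4. Training on Single-Task Datasets

Problem: most datasets are annotated for only one task, so they cannot be used directly for joint multi-task training.

Solution: alternate between the tasks, updating the model with one batch from one single-task dataset at a time.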

import torch
import torch.nn as nn


class AlternatingTrainer:
    """
    Alternating-update trainer.

    Trains a multi-task model from single-task datasets.
    """

    def __init__(
        self,
        model: nn.Module,
        fatigue_dataset,
        identity_dataset,
        device: str = 'cuda'
    ):
        self.model = model.to(device)
        self.device = device

        # Data loaders, one per single-task dataset
        self.fatigue_loader = torch.utils.data.DataLoader(
            fatigue_dataset, batch_size=32, shuffle=True
        )
        self.identity_loader = torch.utils.data.DataLoader(
            identity_dataset, batch_size=32, shuffle=True
        )

        # One optimizer over all parameters (shared trunk and both branches)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)

        # Loss functions
        self.fatigue_criterion = nn.CrossEntropyLoss()
        self.identity_criterion = nn.CrossEntropyLoss()

    def train_epoch(self, epoch: int):
        """Train for one epoch."""

        self.model.train()

        # Create the iterators
        fatigue_iter = iter(self.fatigue_loader)
        identity_iter = iter(self.identity_loader)

        # An epoch covers the longer loader; the shorter one is restarted
        max_batches = max(
            len(self.fatigue_loader),
            len(self.identity_loader)
        )

        for batch_idx in range(max_batches):
            # ===== Fatigue detection task =====
            try:
                fatigue_batch = next(fatigue_iter)
            except StopIteration:
                fatigue_iter = iter(self.fatigue_loader)
                fatigue_batch = next(fatigue_iter)

            images, fatigue_labels = fatigue_batch
            images = images.to(self.device)
            fatigue_labels = fatigue_labels.to(self.device)

            # Forward pass (fatigue branch only)
            outputs = self.model(images, task='fatigue')
            fatigue_loss = self.fatigue_criterion(
                outputs['fatigue_logits'], fatigue_labels
            )

            # Backward pass and update
            self.optimizer.zero_grad()
            fatigue_loss.backward()
            self.optimizer.step()

            # ===== Face recognition task =====
            try:
                identity_batch = next(identity_iter)
            except StopIteration:
                identity_iter = iter(self.identity_loader)
                identity_batch = next(identity_iter)

            images, identity_labels = identity_batch
            images = images.to(self.device)
            identity_labels = identity_labels.to(self.device)

            # Forward pass (identity branch only)
            outputs = self.model(images, task='identity')
            identity_loss = self.identity_criterion(
                outputs['identity_logits'], identity_labels
            )

            # Backward pass and update
            self.optimizer.zero_grad()
            identity_loss.backward()
            self.optimizer.step()

            if batch_idx % 10 == 0:
                print(
                    f"Epoch {epoch}, Batch {batch_idx}, "
                    f"Fatigue Loss: {fatigue_loss.item():.4f}, "
                    f"Identity Loss: {identity_loss.item():.4f}"
                )
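The order in which batches reach the model is easier to see without any tensors. The sketch below reproduces only the scheduling logic of the alternating trainer, using plain lists as stand-ins for the two data loaders; `alternating_schedule` is a hypothetical helper, not part of the original code.

```python
from itertools import cycle, islice


def alternating_schedule(fatigue_batches, identity_batches):
    """Yield (task, batch) pairs: one fatigue step, then one identity step.

    The shorter dataset is cycled so that every batch of the longer one is
    visited once per epoch, mirroring the StopIteration handling in the trainer.
    """
    steps = max(len(fatigue_batches), len(identity_batches))
    fatigue_iter = islice(cycle(fatigue_batches), steps)
    identity_iter = islice(cycle(identity_batches), steps)
    for fatigue_batch, identity_batch in zip(fatigue_iter, identity_iter):
        yield 'fatigue', fatigue_batch
        yield 'identity', identity_batch


schedule = list(alternating_schedule(['f0', 'f1'], ['i0', 'i1', 'i2']))
for task, batch in schedule:
    print(task, batch)
```

Here the two fatigue batches are reused (`f0` appears twice) while the three identity batches are each seen once. Note that `itertools.cycle` replays a cached copy of what it has seen, so with a real shuffling `DataLoader` it is better to re-create the iterator, as the trainer does.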


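class GradientAccumulationTrainer:
    """
    Gradient accumulation trainer.

    Works around limited GPU memory by accumulating gradients over several
    small batches before each optimizer step.
    """

    def __init__(
        self,
        model: nn.Module,
        accumulation_steps: int = 4
    ):
        self.model = model
        self.accumulation_steps = accumulation_steps

    def train_step(
        self,
        batch: dict,
        optimizer: torch.optim.Optimizer,
        step: int
    ):
        """
        One training step.

        Args:
            batch: data batch with 'image' and, optionally,
                'fatigue_label' and/or 'identity_label'
            optimizer: optimizer
            step: index of the current step (0-based)
        """
        images = batch['image']
        fatigue_labels = batch.get('fatigue_label')
        identity_labels = batch.get('identity_label')

        # Forward pass
        outputs = self.model(images, task='both')

        # Collect the loss of every task that is labeled in this batch
        losses = []
        if fatigue_labels is not None:
            losses.append(nn.functional.cross_entropy(
                outputs['fatigue_logits'], fatigue_labels
            ))

        if identity_labels is not None:
            losses.append(nn.functional.cross_entropy(
                outputs['identity_logits'], identity_labels
            ))

        if not losses:
            raise ValueError("batch contains no 'fatigue_label' or 'identity_label'")

        # Normalize so the accumulated gradient matches one large batch
        loss = sum(losses) / self.accumulation_steps

        # Backward pass (gradients accumulate in .grad)
        loss.backward()

        # Update once every accumulation_steps steps
        if (step + 1) % self.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return loss.item() * self.accumulation_steps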
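Dividing each loss by `accumulation_steps` is what makes the accumulated gradient equal to the gradient of one large batch, provided the small batches have equal size. This can be checked on a one-parameter regression problem in plain Python; the data and the model `y ≈ w * x` below are made up purely for illustration.

```python
def mse_grad(w: float, batch: list) -> float:
    """d/dw of the mean squared error of y ≈ w * x over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 8.0),
        (5.0, 9.0), (6.0, 12.5), (7.0, 13.0), (8.0, 17.0)]
w = 0.5
accumulation_steps = 4
micro = len(data) // accumulation_steps  # 2 samples per small batch

# Gradient of one large batch of 8 samples
full_grad = mse_grad(w, data)

# Gradient accumulated over 4 small batches of 2 samples each
accumulated = 0.0
for step in range(accumulation_steps):
    chunk = data[step * micro:(step + 1) * micro]
    # Scaling by 1/accumulation_steps mirrors `loss / self.accumulation_steps`
    accumulated += mse_grad(w, chunk) / accumulation_steps

print(abs(full_grad - accumulated) < 1e-9)  # True
```

Without the division, the accumulated gradient would be `accumulation_steps` times too large, which acts like an unintended increase of the learning rate.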

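Experimental Results

Datasets

Dataset   Purpose             Samples    Classes
CEW       Fatigue detection   2,400      2 (fatigue / normal)
YawDD     Fatigue detection   1,700      2
LFW       Face recognition    13,000     5,749
CASIA     Face recognition    494,414    10,575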

Performance Comparison

Method                        Fatigue Acc   Face Rec. Acc   Params (M)   Inference time (ms)
Parallel independent models   92.5%         97.3%           44.2         45
Shared backbone               91.8%         96.8%           23.1         28
T-SCAF (Ours)                 94.2%         98.1%           25.6         32
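The relative savings implied by this table are easy to misread, because a latency reduction and a throughput speedup are different percentages. A short calculation from the table's own numbers:

```python
# Values copied from the performance table above
parallel = {'params_m': 44.2, 'latency_ms': 45}
tscaf = {'params_m': 25.6, 'latency_ms': 32}

param_reduction = 1 - tscaf['params_m'] / parallel['params_m']
latency_reduction = 1 - tscaf['latency_ms'] / parallel['latency_ms']
speedup = parallel['latency_ms'] / tscaf['latency_ms']

print(f"parameter reduction: {param_reduction:.1%}")    # 42.1%
print(f"latency reduction:   {latency_reduction:.1%}")  # 28.9%
print(f"throughput speedup:  {speedup:.2f}x")           # 1.41x
```

So T-SCAF uses about 42% fewer parameters than the two parallel models; its per-frame latency is about 29% lower, which is the same thing as processing roughly 40% more frames per second.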

Ablation Study

Configuration         Fatigue Acc   Face Rec. Acc
Shared backbone only  91.8%         96.8%
+ LANet               93.1%         97.2%
+ SENet               92.8%         97.5%
+ LASE-Net (full)     94.2%         98.1%
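Reading the ablation as gains over the shared-backbone baseline makes the contribution of each module explicit. The snippet only re-expresses the numbers from the table above, in percentage points.

```python
baseline = {'fatigue': 91.8, 'face': 96.8}  # shared backbone only
configs = {
    '+ LANet': {'fatigue': 93.1, 'face': 97.2},
    '+ SENet': {'fatigue': 92.8, 'face': 97.5},
    '+ LASE-Net (full)': {'fatigue': 94.2, 'face': 98.1},
}

# Gain of each configuration over the baseline, in percentage points
gains = {
    name: {task: round(acc - baseline[task], 1) for task, acc in accs.items()}
    for name, accs in configs.items()
}
for name, gain in gains.items():
    print(f"{name:<18} fatigue +{gain['fatigue']:.1f}  face +{gain['face']:.1f}")
```

The full module gains 2.4 points on fatigue detection and 1.3 on face recognition, more than either attention branch alone (+1.3 / +0.4 for LANet, +1.0 / +0.7 for SENet), which suggests the two branches are complementary.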

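Implications for IMS Applications

1. Deployment Optimization

Model compression:

import torch
import torch.nn as nn
import torch.quantization as quant

# Dynamic quantization runs on CPU, with the model in eval mode
model.eval()

# Dynamic quantization covers nn.Linear (and recurrent layers) but not
# nn.Conv2d, so only the linear layers are converted here. Quantizing the
# convolutional backbone requires static quantization instead.
model_quantized = quant.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8
)

# Performance comparison
# FP32: 32 ms, 100 MB
# INT8: 12 ms, 28 MB (a size this small implies the conv layers are INT8 too)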
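What INT8 quantization does to the weights can be shown in a few lines. The sketch below uses a generic symmetric per-tensor scheme (real value ≈ scale × 8-bit integer); it illustrates the idea only and is not the exact procedure PyTorch applies internally.

```python
def quantize_int8(values: list) -> tuple:
    """Symmetric per-tensor quantization: real value ≈ scale * int8 code."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes: list, scale: float) -> list:
    """Map the integer codes back to approximate real values."""
    return [c * scale for c in codes]


weights = [0.82, -0.31, 0.05, -1.27, 0.44]  # made-up example weights
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(codes, f"scale={scale:.4f}", f"max error={max_error:.6f}")

# Storage: 4 bytes per FP32 weight vs. 1 byte per INT8 weight
params = 25.6e6  # T-SCAF parameter count from the performance table
fp32_mb = params * 4 / 1e6
int8_mb = params * 1 / 1e6
print(f"FP32 ≈ {fp32_mb:.1f} MB, INT8 ≈ {int8_mb:.1f} MB")
```

The rounding error per weight is at most half a quantization step (`scale / 2`). The storage estimate of roughly 102 MB versus 26 MB is in line with the 100 MB and 28 MB figures quoted above.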

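ONNX export:

# Export the multi-task model in eval mode. The model returns a dict, which
# the exporter flattens into two outputs: fatigue logits, then identity logits.
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "tscaf_multi_task.onnx",
    input_names=['image'],
    output_names=['fatigue', 'identity'],
    # Make the batch dimension dynamic for the input and both outputs
    dynamic_axes={
        'image': {0: 'batch'},
        'fatigue': {0: 'batch'},
        'identity': {0: 'batch'}
    }
)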

2. Real-Time Inference Pipeline

import cv2
import numpy as np
import onnxruntime as ort


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D logit vector."""
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()


class RealtimeDMSInference:
    """Real-time DMS inference"""

    # ImageNet normalization statistics (RGB order), kept in float32 so the
    # preprocessed tensor matches the float32 input expected by the model
    MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, model_path: str):
        self.session = ort.InferenceSession(model_path)

    def infer(self, image: np.ndarray) -> dict:
        """
        Run inference on one frame.

        Args:
            image: BGR image (as returned by OpenCV)

        Returns:
            result: {
                'is_fatigue': bool,
                'driver_id': int,
                'fatigue_confidence': float,
                'identity_confidence': float
            }
        """
        # Preprocessing
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_tensor = self._preprocess(image_rgb)

        # ONNX inference
        outputs = self.session.run(
            None,
            {'image': image_tensor}
        )

        fatigue_logits = outputs[0]
        identity_logits = outputs[1]

        # Post-processing: predicted class and its softmax probability
        fatigue_pred = int(np.argmax(fatigue_logits, axis=1)[0])
        identity_pred = int(np.argmax(identity_logits, axis=1)[0])

        fatigue_conf = softmax(fatigue_logits[0])[fatigue_pred]
        identity_conf = softmax(identity_logits[0])[identity_pred]

        return {
            'is_fatigue': fatigue_pred == 1,
            'driver_id': identity_pred,
            'fatigue_confidence': float(fatigue_conf),
            'identity_confidence': float(identity_conf)
        }

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Resize, normalize, and convert HWC RGB to a (1, 3, 224, 224) tensor."""
        image = cv2.resize(image, (224, 224))
        image = image.astype(np.float32) / 255.0
        image = (image - self.MEAN) / self.STD
        image = np.transpose(image, (2, 0, 1))
        return np.expand_dims(image, 0).astype(np.float32)
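The confidence values in the result dictionary come from a softmax over the logits, and they are only useful if something acts on them. The sketch below shows a numerically stable softmax in plain Python together with a simple rejection rule for uncertain predictions, for example to treat a low-confidence identity as an unknown driver. The `decide` helper and the 0.6 threshold are hypothetical and would need tuning on real data.

```python
import math


def softmax(logits: list) -> list:
    """Numerically stable softmax: shift by the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decide(logits: list, min_confidence: float = 0.6) -> tuple:
    """Return (class index, confidence), or (None, confidence) if too uncertain."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    confidence = probs[best]
    return (best if confidence >= min_confidence else None), confidence


confident = decide([2.0, 0.1])          # clear winner: class 0 is accepted
uncertain = decide([0.3, 0.2, 0.25])    # near-uniform: rejected as unknown
large = decide([1000.0, 999.0])         # large logits do not overflow
print(confident, uncertain, large)
```

Subtracting the maximum logit leaves the probabilities unchanged but keeps `math.exp` from overflowing, which a naive `exp(1000.0)` would.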

3. Deployment Architecture

    IR camera (30 fps)
             ↓
      Face detection
             ↓
       T-SCAF model
             ↓
┌───────────┬─────────────┐
│ Fatigue   │ Face        │
│ detection │ recognition │
└───────────┴─────────────┘
             ↓
    Multi-task decision
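The final decision stage is not specified further. One common way to turn noisy per-frame predictions into a stable alert is to require agreement over a sliding window of recent frames. The class below is a minimal sketch of such a policy; `FatigueDebouncer`, the window length, and the 0.8 ratio are assumptions made for illustration, not part of the paper.

```python
from collections import deque


class FatigueDebouncer:
    """Raise an alert only when most recent frames agree (hypothetical policy)."""

    def __init__(self, window: int = 30, min_ratio: float = 0.8):
        self.history = deque(maxlen=window)  # oldest frames drop out automatically
        self.min_ratio = min_ratio

    def update(self, is_fatigue: bool) -> bool:
        """Record one frame's prediction and return the current alert state."""
        self.history.append(bool(is_fatigue))
        if len(self.history) < self.history.maxlen:
            return False  # not enough evidence yet
        return sum(self.history) / len(self.history) >= self.min_ratio


debouncer = FatigueDebouncer(window=5, min_ratio=0.8)
frames = [True, False, True, True, True, True, True]
alerts = [debouncer.update(frame) for frame in frames]
print(alerts)  # [False, False, False, False, True, True, True]
```

At 30 fps, the default window of 30 frames covers about one second of video, so a single misclassified frame cannot trigger or clear an alert on its own.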

References

  1. Gao Z., et al., "Multi-Task Learning for Fatigue Detection and Face Recognition of Drivers via Tree-Style Space-Channel Attention Fusion Network", arXiv, 2024
  2. Hu J., et al., "Squeeze-and-Excitation Networks", CVPR 2018
  3. Woo S., et al., "CBAM: Convolutional Block Attention Module", ECCV 2018

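Summary: With a tree-style multi-task architecture and space-channel attention fusion, T-SCAF unifies fatigue detection and face recognition in one efficient model. Compared with parallel independent models, it uses about 42% fewer parameters (44.2M vs. 25.6M) and cuts inference time from 45 ms to 32 ms (about 1.4× the throughput), while reaching higher accuracy on both tasks. For in-vehicle deployment, INT8 quantization is recommended to achieve real-time multi-task DMS.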


https://dapalm.com/2026/06/04/2026-06-04-T-SCAF多任务学习DMS疲劳检测与人脸识别统一框架论文解读与代码复现/
Author: Mars
Published: June 4, 2026