Introduction: Why Lightweight Gaze Estimation?
Gaze estimation for in-vehicle scenarios faces three core challenges:
| Challenge | Problem with traditional methods | GazeCapsNet solution |
| --- | --- | --- |
| Limited compute | Large models (>50M parameters) | Only 11.7M parameters |
| Real-time constraints | High inference latency (>50ms) | 20ms latency |
| Robustness | Sensitive to head pose and lighting | Capsules preserve spatial relationships |
GazeCapsNet (published in Sensors, 2025) balances accuracy and efficiency by combining capsule networks with Self-Attention Routing.
1. Capsule Network Fundamentals
1.1 Limitations of Traditional CNNs
Problem: pooling in CNNs discards spatial information
```
Traditional CNN:
input image → Conv → Pooling → Conv → Pooling → FC → output
                        ↓
         spatial information lost (translation invariance)
```
Impact on gaze estimation:
- Eye position information is lost (up/down/left/right)
- The network can only rely on feature co-occurrence and cannot model spatial relationships
- Accuracy drops noticeably under head-pose changes
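To make the first point concrete, here is a tiny dependency-free sketch (illustrative, not from the paper) showing that max pooling maps two patches whose bright pixel sits in opposite corners to the same output:

```python
# Two 2×2 patches: the "feature" (bright pixel) sits in opposite corners
patch_a = [[1.0, 0.0],
           [0.0, 0.0]]
patch_b = [[0.0, 0.0],
           [0.0, 1.0]]

def max_pool(patch):
    """2×2 max pooling collapses the whole patch to a single scalar."""
    return max(max(row) for row in patch)

# Both pool to 1.0 - the pixel's position within the window is unrecoverable
assert max_pool(patch_a) == max_pool(patch_b) == 1.0
```

After pooling, a downstream layer can no longer tell where inside the window the eye corner actually was — exactly the spatial cue gaze estimation depends on.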
1.2 The Core Idea of Capsule Networks
A capsule is a group of neurons:
- Its output is a vector rather than a scalar
- The vector's length encodes the probability that a feature is present
- The vector's direction encodes the feature's pose parameters
```
Traditional neuron: y = f(Σ wx + b)      → scalar
Capsule neuron:     v = squash(Σ c · û)  → vector
```
The squash function:
```python
import torch

def squash(s):
    """Shrink a vector so its length lies in [0, 1); the length encodes probability."""
    norm = torch.norm(s, dim=-1, keepdim=True)
    return (norm ** 2 / (1 + norm ** 2)) * (s / (norm + 1e-8))  # eps avoids division by zero
```
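A quick dependency-free check of the squash behaviour (a plain-Python reimplementation for illustration): long vectors map to lengths near 1, short vectors to lengths near 0, and direction is preserved:

```python
import math

def squash_list(v, eps=1e-8):
    """Plain-Python squash: scale = ||v||^2 / (1 + ||v||^2), applied to the unit vector."""
    norm = math.sqrt(sum(x * x for x in v))
    scale = norm ** 2 / (1 + norm ** 2) / (norm + eps)
    return [scale * x for x in v]

def length(v):
    return math.sqrt(sum(x * x for x in v))

long_caps = squash_list([30.0, 40.0])   # ||v|| = 50   → squashed length ≈ 0.9996
short_caps = squash_list([0.03, 0.04])  # ||v|| = 0.05 → squashed length ≈ 0.0025
```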
1.3 Advantages of Capsule Networks
| Property | CNN | Capsule Network |
| --- | --- | --- |
| Output | Scalar (activation) | Vector (pose + probability) |
| Spatial relationships | Learned implicitly | Modeled explicitly |
| Viewpoint changes | Requires data augmentation | Equivariant representation |
| Parameter count | Large | Small (dynamic routing) |
2. Self-Attention Routing (SAR)
2.1 Problems with Traditional Routing
Dynamic routing:
```
Iterative procedure (typically 3 iterations):
1. Predict:   û_j|i = W_ij · u_i
2. Couple:    c_ij = softmax(b_ij)
3. Aggregate: s_j = Σ_i c_ij · û_j|i
4. Update:    b_ij += û_j|i · v_j
```
Problems:
- The iterations are slow
- Sensitive to the iteration-count hyperparameter
- Hard to parallelize
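For reference, the iterative procedure above can be sketched in a few lines of NumPy (an illustrative single-sample version, not the paper's code). Note how the coupling coefficients must be recomputed in every iteration — exactly the sequential dependency SAR removes:

```python
import numpy as np

def squash(s, eps=1e-8):
    norm = np.linalg.norm(s, axis=-1, keepdims=True)
    return (norm ** 2 / (1 + norm ** 2)) * s / (norm + eps)

def dynamic_routing(u_hat, iterations=3):
    """u_hat: [in_caps, out_caps, out_dim] prediction vectors for one sample.
    Returns output capsules v: [out_caps, out_dim]."""
    in_caps, out_caps, _ = u_hat.shape
    b = np.zeros((in_caps, out_caps))           # routing logits
    for _ in range(iterations):
        e = np.exp(b - b.max(axis=1, keepdims=True))
        c = e / e.sum(axis=1, keepdims=True)    # coupling coeffs: softmax over out_caps
        s = (c[..., None] * u_hat).sum(axis=0)  # weighted sum over input capsules
        v = squash(s)
        b = b + (u_hat * v[None]).sum(axis=-1)  # agreement update: û_j|i · v_j
    return v

rng = np.random.default_rng(0)
v = dynamic_routing(rng.standard_normal((32, 4, 16)))  # 32 input caps → 4 output caps
```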
2.2 How Self-Attention Routing Works
Core idea: replace iterative routing with an attention mechanism
```
SAR (single forward pass):
1. Predict:   A_ij's inputs û_j|i = W_ij · u_i
2. Attention: A_ij = softmax(û_j|i · û_j|i^T)
3. Aggregate: v_j = squash(Σ_i A_ij · û_j|i)
```
Mathematically (with the softmax normalizing over the input capsules, matching the implementation below):
$$A_{ij} = \frac{\exp(\hat{u}_{j|i} \cdot \hat{u}_{j|i})}{\sum_k \exp(\hat{u}_{j|k} \cdot \hat{u}_{j|k})}$$
$$v_j = \text{squash}\left(\sum_i A_{ij} \cdot \hat{u}_{j|i}\right)$$
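A small NumPy sketch of a single SAR pass (illustrative, single sample, not the paper's code): the attention logits are the squared norms of the prediction vectors, normalized over the input capsules:

```python
import numpy as np

def squash(s, eps=1e-8):
    norm = np.linalg.norm(s, axis=-1, keepdims=True)
    return (norm ** 2 / (1 + norm ** 2)) * s / (norm + eps)

def sar_route(u_hat):
    """Single-pass self-attention routing for one sample.
    u_hat: [in_caps, out_caps, out_dim] prediction vectors.
    Returns output capsules v: [out_caps, out_dim]."""
    logits = (u_hat ** 2).sum(axis=-1)                      # û_j|i · û_j|i → [in_caps, out_caps]
    e = np.exp(logits - logits.max(axis=0, keepdims=True))
    a = e / e.sum(axis=0, keepdims=True)                    # softmax over input capsules
    s = (a[..., None] * u_hat).sum(axis=0)                  # weighted aggregation
    return squash(s)

rng = np.random.default_rng(0)
v = sar_route(rng.standard_normal((32, 4, 16)))  # one pass, no iteration
```

Unlike the dynamic-routing loop, everything here is a fixed sequence of tensor ops, so it batches and parallelizes trivially.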
2.3 SAR Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionRouting(nn.Module):
    def __init__(self, in_caps, out_caps, in_dim, out_dim):
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        # Transformation matrices W_ij: one [in_dim, out_dim] matrix per capsule pair
        self.W = nn.Parameter(torch.randn(in_caps, out_caps, in_dim, out_dim))

    def forward(self, u):
        """
        u: [batch, in_caps, in_dim] - input capsules
        returns: [batch, out_caps, out_dim] - output capsules
        """
        # Prediction vectors û_j|i = W_ij · u_i → [batch, in_caps, out_caps, out_dim]
        u_hat = torch.einsum('bci,cjio->bcjo', u, self.W)
        # Attention logits: squared norm of each prediction; softmax over input capsules
        attention = F.softmax(torch.sum(u_hat ** 2, dim=-1), dim=1)  # [batch, in_caps, out_caps]
        # Attention-weighted aggregation over input capsules
        s = torch.einsum('bcj,bcjo->bjo', attention, u_hat)
        return self.squash(s)

    @staticmethod
    def squash(s):
        norm = torch.norm(s, dim=-1, keepdim=True)
        return (norm ** 2 / (1 + norm ** 2)) * (s / (norm + 1e-8))
```
2.4 SAR vs. Dynamic Routing
| Dimension | Dynamic routing | SAR |
| --- | --- | --- |
| Computation | 3 iterations | Single forward pass |
| Added latency | +15ms | ~0ms |
| Parallelism | Poor | Good |
| Accuracy | Baseline | Comparable |
| Trainability | Hard | Easy |
3. The GazeCapsNet Architecture
3.1 Overall Architecture
```
input image (224×224)
        ↓
┌─────────────────────────────────┐
│ Face detection (SCRFD)          │
│ - detect the face region        │
│ - crop and resize to 224×224    │
└─────────────────────────────────┘
        ↓
┌─────────────────────────────────┐
│ Dual-stream feature extraction  │
│ ├── MobileNet v2 → low-level    │
│ └── ResNet-18   → high-level    │
│ concat → combined_features      │
└─────────────────────────────────┘
        ↓
┌─────────────────────────────────┐
│ Primary Capsules                │
│ - feature maps → primary caps   │
│ - 8-dim vector per capsule      │
└─────────────────────────────────┘
        ↓
┌─────────────────────────────────┐
│ Self-Attention Routing          │
│ - attention-weighted routing    │
│ - produces 16-dim Gaze Capsules │
└─────────────────────────────────┘
        ↓
┌─────────────────────────────────┐
│ Gaze Regression Head            │
│ - outputs a 3D gaze vector      │
│   (pitch, yaw, roll)            │
└─────────────────────────────────┘
```
3.2 Dual-Stream Feature Extraction
Why MobileNet v2 + ResNet-18?
| Network | Traits | Contribution |
| --- | --- | --- |
| MobileNet v2 | Lightweight, fast | Low-level texture features |
| ResNet-18 | Residual connections | High-level semantic features |
| Fusion | Complementary | Balances accuracy and speed |
Implementation:
```python
import torch
import torch.nn as nn
import torchvision.models as models

class DualStreamExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        mobilenet = models.mobilenet_v2(pretrained=True)
        resnet = models.resnet18(pretrained=True)
        self.mobile_features = mobilenet.features                            # → [B, 1280, 7, 7]
        self.resnet_features = nn.Sequential(*list(resnet.children())[:-2])  # → [B, 512, 7, 7]
        # 1×1 conv fuses the concatenated streams back down to 512 channels
        self.fusion = nn.Conv2d(1280 + 512, 512, kernel_size=1)

    def forward(self, x):
        f1 = self.mobile_features(x)
        f2 = self.resnet_features(x)
        f = torch.cat([f1, f2], dim=1)
        return self.fusion(f)
```
3.3 Primary Capsules
From feature maps to capsules:
```python
class PrimaryCapsules(nn.Module):
    def __init__(self, in_channels=512, out_channels=32, dim=8):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * dim,
                              kernel_size=3, stride=2, padding=1)
        self.out_channels = out_channels
        self.dim = dim

    def forward(self, x):
        out = self.conv(x)
        batch_size = out.size(0)
        # [B, out_channels*dim, H, W] → [B, out_channels, H*W, dim]
        out = out.view(batch_size, self.out_channels, -1, self.dim)
        out = self.squash(out)
        # flatten to [B, num_capsules, dim]
        return out.view(batch_size, -1, self.dim)

    @staticmethod
    def squash(s):
        norm = torch.norm(s, dim=-1, keepdim=True)
        return (norm ** 2 / (1 + norm ** 2)) * (s / (norm + 1e-8))
```
3.4 Gaze Regression Head
```python
class GazeRegression(nn.Module):
    def __init__(self, in_dim=16):
        super().__init__()
        self.fc = nn.Linear(in_dim, 3)

    def forward(self, v):
        """
        v: [B, out_caps, dim] - gaze capsules
        returns: [B, 3] - 3D gaze vector (pitch, yaw, roll)
        """
        # pick the capsule with the largest norm (= highest confidence)
        norms = torch.norm(v, dim=-1)
        idx = torch.argmax(norms, dim=1)
        batch_size = v.size(0)
        gaze_capsule = v[torch.arange(batch_size), idx]
        gaze = self.fc(gaze_capsule)
        return F.normalize(gaze, dim=-1)
```
4. Loss Function Design
4.1 Angular Loss
Problem: MSE is a poor proxy for angular error
Solution: optimize the angular error directly
$$L_{angular} = \arccos\left(\frac{g_{pred} \cdot g_{true}}{\|g_{pred}\| \cdot \|g_{true}\|}\right)$$
```python
def angular_loss(g_pred, g_true):
    """
    g_pred: [B, 3] - predicted 3D gaze vectors
    g_true: [B, 3] - ground-truth 3D gaze vectors
    returns: mean angular error (degrees)
    """
    g_pred = F.normalize(g_pred, dim=-1)
    g_true = F.normalize(g_true, dim=-1)
    cos_sim = torch.sum(g_pred * g_true, dim=-1)
    cos_sim = torch.clamp(cos_sim, -1.0, 1.0)  # guard acos against numerical overshoot
    angle = torch.acos(cos_sim) * 180 / torch.pi
    return angle.mean()
```
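A quick dependency-free sanity check of the angular-error formula (a plain-Python reimplementation for illustration): identical directions give 0°, orthogonal directions give 90°:

```python
import math

def angular_error_deg(g_pred, g_true):
    """Angle between two 3D vectors in degrees (normalization included)."""
    dot = sum(p * t for p, t in zip(g_pred, g_true))
    norm_p = math.sqrt(sum(p * p for p in g_pred))
    norm_t = math.sqrt(sum(t * t for t in g_true))
    cos = max(-1.0, min(1.0, dot / (norm_p * norm_t)))  # clamp against rounding
    return math.degrees(math.acos(cos))

same = angular_error_deg([0.0, 0.0, 1.0], [0.0, 0.0, 2.0])   # → 0.0 (scale-invariant)
ortho = angular_error_deg([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])  # → 90.0
```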
4.2 Multi-Task Loss
```python
def total_loss(gaze_pred, gaze_true, zone_pred, zone_true, capsule_output):
    L_angular = angular_loss(gaze_pred, gaze_true)
    L_class = F.cross_entropy(zone_pred, zone_true)             # gaze-zone classification
    L_routing = torch.mean(torch.norm(capsule_output, dim=-1))  # capsule-norm regularizer
    return L_angular + 0.1 * L_class + 0.01 * L_routing
```
5. Experimental Results and Analysis
5.1 Benchmark Performance
Gaze360:
| Method | MAE (angular error) | Parameters | Inference time |
| --- | --- | --- | --- |
| FullFace | 6.53° | 196.6M | 50ms |
| RT-GENE | 6.02° | 45.2M | 40ms |
| GazeTR-Pure | 5.33° | 23.1M | 45ms |
| GazeCaps | 5.10° | 14.2M | 25ms |
| GazeCapsNet | 5.10° | 11.7M | 20ms |
MPIIFaceGaze:
| Method | MAE |
| --- | --- |
| FullFace | 4.95° |
| Dilated-Net | 4.78° |
| GazeTR-Hybrid | 4.06° |
| GazeCapsNet | 4.06° |
5.2 Ablation Study
| Configuration | MAE | Notes |
| --- | --- | --- |
| Full model | 5.10° | Baseline |
| − ResNet-18 | 5.68° | +0.58° |
| − MobileNet v2 | 5.45° | +0.35° |
| − SAR (dynamic routing instead) | 5.23° | +0.13°, but +15ms latency |
| − Capsules (FC layers instead) | 5.89° | +0.79° |
Takeaways:
- ResNet-18 contributes most: high-level features matter greatly for accuracy
- SAR's contribution is moderate: its main advantage is speed
- The capsule mechanism is essential: its spatial modeling cannot be replaced
5.3 Cross-Domain Generalization
Train on one dataset, test on another:
| Training data | Test data | MAE |
| --- | --- | --- |
| ETH-XGaze | Gaze360 | 6.82° |
| Gaze360 | ETH-XGaze | 5.94° |
| ETH-XGaze | MPIIFaceGaze | 5.23° |
| Gaze360 | MPIIFaceGaze | 4.87° |
Conclusion: models trained on Gaze360 generalize better, since its in-the-wild data is more diverse.
6. Embedded Deployment
6.1 Model Quantization
FP16 quantization:
```python
import torch

model = GazeCapsNet.load_from_checkpoint('checkpoint.pth')
model.eval()

# convert weights and inputs to half precision
model_half = model.half()
dummy_input = torch.randn(1, 3, 224, 224).half()

with torch.no_grad():
    output = model_half(dummy_input)
```
INT8 quantization (requires calibration):
```python
import torch.quantization as quant

model.qconfig = quant.get_default_qconfig('qnnpack')
quant.prepare(model, inplace=True)

# calibration pass: run representative images through the model
for image in calibration_images:
    model(image)

quant.convert(model, inplace=True)
```
6.2 TensorRT Acceleration
```python
import torch
import tensorrt as trt

# export to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, 'gazecapsnet.onnx')

# build a TensorRT engine from the ONNX file
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open('gazecapsnet.onnx', 'rb') as f:
    parser.parse(f.read())

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB workspace
config.set_flag(trt.BuilderFlag.FP16)

engine = builder.build_serialized_network(network, config)
```
6.3 Performance Comparison
NVIDIA Jetson Orin Nano:
| Precision | Latency | Power | Accuracy loss |
| --- | --- | --- | --- |
| FP32 | 35ms | 8W | 0° |
| FP16 | 20ms | 6W | 0.1° |
| INT8 | 12ms | 5W | 0.5° |
Qualcomm 8255 (QNN):
| Precision | Latency | Power |
| --- | --- | --- |
| FP32 | 28ms | 4W |
| FP16 | 18ms | 3W |
| INT8 | 10ms | 2.5W |
7. Open-Source Code
Full code: https://github.com/yakhyo/gaze-estimation
Quick start:
```python
import cv2
from gazecapsnet import GazeCapsNet

model = GazeCapsNet.from_pretrained('gazecapsnet_ethxgaze.pth')

image = cv2.imread('driver.jpg')
gaze_vector = model.predict(image)

pitch, yaw, roll = gaze_vector
print(f"Gaze direction: pitch={pitch:.1f}°, yaw={yaw:.1f}°, roll={roll:.1f}°")
```
8. Summary
8.1 Core Contributions of GazeCapsNet
- Self-Attention Routing: replaces iterative routing, cutting latency by 20%
- Dual-stream feature extraction: MobileNet v2 + ResNet-18 balance accuracy and speed
- End-to-end design: predicts gaze directly from images, no facial landmarks required
- Lightweight: 11.7M parameters, 20ms latency, suitable for edge deployment
8.2 When to Use It
| Scenario | Recommendation | Notes |
| --- | --- | --- |
| In-vehicle DMS | ✅ Strongly recommended | Good balance of accuracy, speed, and deployment cost |
| VR/AR | ✅ Recommended | Good real-time performance |
| Research experiments | ⚠️ Optional | Less accurate than dedicated hardware |
| Mobile devices | ✅ Recommended | Lightweight |
8.3 Limitations
- Sunglasses: the pupil is invisible in IR images
- Extreme head poses: accuracy drops beyond 45°
- Ethnic diversity: the training data is mostly Western, so accuracy on Asian faces is slightly lower
References
- Muksimova, S., et al. "GazeCapsNet: A Lightweight Gaze Estimation Framework." Sensors, 2025.
- Sabour, S., et al. "Dynamic Routing Between Capsules." NeurIPS, 2017.
- Zhang, X., et al. "ETH-XGaze: A Large Scale Dataset for Gaze Estimation Under Extreme Head Pose and Gaze Variation." ECCV, 2020.
This article is part of the IMS gaze-estimation algorithm series. Next up: a deep dive into the GazeTR Transformer.