""" CNN-LSTM 混合架构疲劳检测 """
import torch
import torch.nn as nn
import torchvision.models as models
class CNNLSTMFatigueDetector(nn.Module):
    """CNN-LSTM hybrid fatigue detector.

    Architecture:
        1. A CNN backbone extracts spatial features from each frame.
        2. A bidirectional LSTM models temporal dependencies across frames.
        3. An additive attention head pools the sequence, focusing on the
           most informative time steps.
    """

    def __init__(self,
                 cnn_backbone: str = 'resnet18',
                 lstm_hidden: int = 256,
                 lstm_layers: int = 2,
                 num_classes: int = 2,
                 pretrained: bool = True):
        """
        Args:
            cnn_backbone: Backbone name, 'resnet18' or 'mobilenetv2'.
            lstm_hidden: Hidden size of each LSTM direction.
            lstm_layers: Number of stacked LSTM layers.
            num_classes: Number of output classes.
            pretrained: Whether to load ImageNet weights for the backbone
                (requires network access on first use).

        Raises:
            ValueError: If `cnn_backbone` is not a supported name.
        """
        super().__init__()
        if cnn_backbone == 'resnet18':
            self.cnn = models.resnet18(pretrained=pretrained)
            feature_dim = 512
            # Strip the ImageNet classification head so the backbone emits
            # the 512-d globally-pooled feature vector.
            self.cnn.fc = nn.Identity()
        elif cnn_backbone == 'mobilenetv2':
            self.cnn = models.mobilenet_v2(pretrained=pretrained)
            feature_dim = 1280
            # BUG FIX: MobileNetV2 exposes its head as `.classifier`, not
            # `.fc`. The original assigned `self.cnn.fc = nn.Identity()`,
            # which left the 1000-way ImageNet classifier in place and
            # contradicted the declared 1280-d feature_dim, breaking the
            # LSTM input size.
            self.cnn.classifier = nn.Identity()
        else:
            raise ValueError(f"不支持的骨干网络: {cnn_backbone}")

        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        # Additive attention: one scalar score per time step, softmaxed
        # over the temporal axis.
        self.attention = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 64),  # *2: bidirectional LSTM
            nn.Tanh(),
            nn.Linear(64, 1),
        )
        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def _attend(self, x: torch.Tensor) -> tuple:
        """Shared CNN -> LSTM -> attention pipeline.

        Extracted so `forward` and `get_temporal_attention` no longer
        duplicate the whole encoding path.

        Args:
            x: Video frame sequence of shape (B, T, 3, H, W).

        Returns:
            (lstm_out, attn_weights): per-step LSTM outputs of shape
            (B, T, 2 * lstm_hidden) and softmax attention weights of
            shape (B, T, 1).
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # Fold time into the batch so the CNN sees frames independently.
        # reshape (not view) tolerates non-contiguous inputs.
        x_flat = x.reshape(batch_size * seq_len, *x.shape[2:])
        cnn_features = self.cnn(x_flat).view(batch_size, seq_len, -1)
        lstm_out, _ = self.lstm(cnn_features)
        attn_scores = self.attention(lstm_out)            # (B, T, 1)
        attn_weights = torch.softmax(attn_scores, dim=1)  # normalize over T
        return lstm_out, attn_weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Video frame sequence of shape (B, T, 3, H, W).

        Returns:
            Classification logits of shape (B, num_classes).
        """
        lstm_out, attn_weights = self._attend(x)
        # Attention-weighted sum over time -> single context vector.
        context = torch.sum(lstm_out * attn_weights, dim=1)
        return self.classifier(context)

    def get_temporal_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-frame attention weights (for visualization).

        Args:
            x: Video frame sequence of shape (B, T, 3, H, W).

        Returns:
            Attention weights of shape (B, T); each row sums to 1.
        """
        _, attn_weights = self._attend(x)
        return attn_weights.squeeze(-1)
if __name__ == "__main__":
    # Smoke test: verify tensor shapes and parameter count.
    # pretrained=False: a pure shape check does not need ImageNet weights,
    # and the original default (pretrained=True) forced a network download.
    model = CNNLSTMFatigueDetector(
        cnn_backbone='resnet18',
        lstm_hidden=256,
        lstm_layers=2,
        pretrained=False,
    )
    model.eval()  # deterministic demo: disable dropout

    x = torch.randn(4, 10, 3, 224, 224)  # (B=4, T=10) RGB frames
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(x)
        attn = model.get_temporal_attention(x)

    print(f"输入形状: {x.shape}")
    print(f"输出形状: {logits.shape}")
    print(f"注意力权重形状: {attn.shape}")
    print(f"注意力权重样本: {attn[0].detach().numpy()}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"总参数量: {total_params/1e6:.2f}M")