1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
| import torchvision.models as models
class CNNLSTM_FatigueDetector(nn.Module):
    """CNN + LSTM classifier over a sequence of video frames.

    Every frame goes through the MobileNetV2 convolutional feature extractor,
    is spatially averaged into a 1280-dim vector and projected to 256 dims.
    The per-frame vectors are then fed to a 2-layer LSTM, and the LSTM output
    at the last time step is mapped to ``num_classes`` logits.

    Args:
        num_classes: Number of output classes (size of the logits vector).
        pretrained: Forwarded to ``models.mobilenet_v2``; pass ``False`` to
            skip loading the pretrained backbone weights (e.g. when restoring
            a full checkpoint afterwards).
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()

        # Only the convolutional trunk of MobileNetV2 is reused; its own
        # pooling + classifier head is replaced by the layers below.
        mobilenet = models.mobilenet_v2(pretrained=pretrained)
        self.cnn = nn.Sequential(*list(mobilenet.features.children()))

        # Projects the pooled 1280-channel frame descriptor down to the
        # LSTM input size.
        self.fc1 = nn.Linear(1280, 256)

        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
        )

        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Classify a batch of video clips.

        Args:
            x: Video sequence tensor of shape ``[B, T, C, H, W]``
                (``C`` is 3 for the MobileNetV2 backbone).

        Returns:
            Logits tensor of shape ``[B, num_classes]``.

        Raises:
            ValueError: If ``x`` is not a 5-dimensional tensor.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape [B, T, C, H, W], got {tuple(x.shape)}"
            )
        batch, steps, channels, height, width = x.shape

        # Fold time into the batch axis so the 2-D CNN sees individual frames.
        # reshape (not view) so non-contiguous inputs are accepted as well.
        frames = x.reshape(batch * steps, channels, height, width)
        feature_maps = self.cnn(frames)  # [B*T, 1280, h, w]

        # Global average pooling over the spatial dims. Flattening the raw
        # feature map instead would yield 1280*h*w features and break fc1
        # whenever the map is larger than 1x1.
        pooled = feature_maps.mean(dim=(2, 3))  # [B*T, 1280]

        frame_embeddings = self.fc1(pooled)  # [B*T, 256]
        sequence = frame_embeddings.reshape(batch, steps, -1)  # [B, T, 256]

        lstm_out, _ = self.lstm(sequence)  # [B, T, 128]
        # Summarise the clip by the LSTM output at the final time step.
        last_hidden = lstm_out[:, -1, :]

        return self.classifier(last_hidden)
# Build the detector and switch it to inference mode: without eval() the
# dropout layers stay active and predictions become non-deterministic.
model = CNNLSTM_FatigueDetector(num_classes=4)
model.eval()

# NOTE(review): load_video_clip() is defined elsewhere; it presumably returns
# a [B, T, 3, H, W] tensor as expected by the model's forward() - confirm.
video_clip = load_video_clip()

# No gradients are needed for prediction, so skip building the autograd graph.
with torch.no_grad():
    logits = model(video_clip)

# One predicted class index per clip in the batch.
fatigue_level = torch.argmax(logits, dim=1)

# Print one line per clip; indexing the label list with the whole tensor
# only works when the batch contains exactly one clip.
fatigue_labels = ['清醒', '轻度', '中度', '重度']
for level in fatigue_level.tolist():
    print(f"疲劳等级: {fatigue_labels[level]}")
|