1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
| import torch import torch.nn as nn import torchvision.models as models
class DriverDistractionDetector(nn.Module):
    """Driver distraction classifier built on a ResNet-50 backbone.

    Transfer-learning model: a torchvision ResNet-50 is (optionally) loaded
    with ImageNet weights, most of it is frozen, and its final fully
    connected layer is replaced by a small dropout/MLP classification head.

    Class indices:
        0: normal driving
        1: visual distraction
        2: manual distraction
        3: cognitive distraction
    """

    # Backbone sub-modules applied, in order, to produce the pooled embedding.
    _FEATURE_STAGES = (
        "conv1",
        "bn1",
        "relu",
        "maxpool",
        "layer1",
        "layer2",
        "layer3",
        "layer4",
        "avgpool",
    )

    def __init__(self, num_classes: int = 4, pretrained: bool = True):
        """Build the backbone and attach the classification head.

        Args:
            num_classes: Number of output classes (see the class docstring
                for the meaning of the default four indices).
            pretrained: Load ImageNet weights into the backbone. Pass
                ``False`` to skip the weight download, e.g. when a trained
                checkpoint is restored right afterwards. The freezing
                policy below is applied either way.
        """
        super().__init__()
        self.backbone = self._build_backbone(pretrained)

        # Freeze everything except the last four parameter tensors of the
        # stock network. NOTE(review): in torchvision's ResNet-50 those four
        # appear to be the final bottleneck's bn3 weight/bias plus fc
        # weight/bias, and fc is replaced just below - so only that one
        # BatchNorm stays trainable in the backbone. Confirm this is the
        # intended fine-tuning depth.
        for param in list(self.backbone.parameters())[:-4]:
            param.requires_grad = False

        # The replacement head is created after the freeze loop, so all of
        # its parameters keep the default requires_grad=True.
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Create the ResNet-50 without using the deprecated keyword when possible."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True``
        # resolved to, so existing behaviour is preserved.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet50(weights=weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full network.

        Args:
            x: Batch of driver images, documented by the original author as
                shape (B, 3, 224, 224).

        Returns:
            Unnormalised class logits of shape (B, num_classes).
        """
        return self.backbone(x)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled backbone embedding, bypassing the classifier head.

        The input is passed through the convolutional stem, the four
        residual stages and the global average pool, then flattened to
        (B, C). The original author intended this embedding for cognitive
        distraction analysis, expecting it to capture cues such as head
        pose, gaze direction and hand position; these are not computed
        explicitly anywhere in this method.

        Args:
            x: Batch of driver images (same input as ``forward``).

        Returns:
            Flattened feature tensor with one row per input image.
        """
        for stage_name in self._FEATURE_STAGES:
            x = getattr(self.backbone, stage_name)(x)
        return torch.flatten(x, 1)
class DistractionLoss(nn.Module):
    """Loss for distraction classification.

    Cross-entropy over the class logits, with optional per-class weights
    to counter class imbalance.
    """

    def __init__(self, class_weights=None):
        """Store the optional per-class weights.

        Args:
            class_weights: ``None`` for unweighted cross-entropy, or a
                tensor / sequence / array holding one weight per class.
        """
        super().__init__()
        weights = None if class_weights is None else torch.as_tensor(class_weights)
        # Registered as a buffer so ``.to(device)`` / ``.cuda()`` on the
        # module moves the weights together with it; a plain attribute would
        # stay behind and cause a device mismatch in ``forward``.
        # ``persistent=False`` keeps ``state_dict()`` identical to what it
        # was when this was a plain attribute.
        self.register_buffer("class_weights", weights, persistent=False)

    def forward(self, predictions, targets):
        """Compute the (optionally class-weighted) cross-entropy loss.

        Args:
            predictions: Predicted logits, shape (B, num_classes).
            targets: Ground-truth class indices, shape (B,).

        Returns:
            The cross-entropy loss as returned by
            ``torch.nn.functional.cross_entropy`` with its default reduction.
        """
        weight = self.class_weights
        if weight is not None:
            # cross_entropy requires the weight on the logits' device and in
            # their dtype; this is a no-op when they already match.
            weight = weight.to(device=predictions.device, dtype=predictions.dtype)
        return nn.functional.cross_entropy(predictions, targets, weight=weight)
|