1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
| import torchvision.models as models
class DBMNet(nn.Module):
    """DBMNet: dual-branch multi-scale network.

    Intended for cross-camera driver distraction classification.

    The model is assembled from a truncated ResNet-50 backbone, a
    multi-scale feature extractor, a feature decoupler and a contrastive
    learning module.  Note that ``self.multi_scale`` is constructed here
    but is not invoked by :meth:`forward`.
    """

    def __init__(
        self,
        num_distraction_classes=10,
        num_cameras=5,
        feature_dim=2048,
        hidden_dim=512
    ):
        super().__init__()

        # ImageNet-pretrained ResNet-50 with its last two child modules
        # removed, so the backbone emits a spatial feature map.
        resnet = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        self.multi_scale = MultiScaleFeatureExtractor(
            in_channels=2048,
            out_channels=256
        )

        self.decoupler = FeatureDecoupler(
            feature_dim=feature_dim,
            hidden_dim=hidden_dim,
            num_cameras=num_cameras,
            num_distraction_classes=num_distraction_classes
        )

        self.contrastive = ContrastiveLearningModule(
            feature_dim=hidden_dim,
            num_classes=num_distraction_classes,
            temperature=0.07
        )

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, images, camera_ids=None, labels=None, lambda_=1.0):
        """Run the network and collect predictions and auxiliary losses.

        Args:
            images: (B, C, H, W) input images.
            camera_ids: (B,) camera IDs (used during training).  This
                method accepts the argument but does not read it.
            labels: (B,) distraction labels (used during training).  When
                given, the contrastive module is evaluated; otherwise both
                contrastive losses are zero scalars on ``images.device``.
            lambda_: gradient-reversal strength, forwarded to the decoupler.

        Returns:
            dict with keys ``distraction_logits``, ``camera_logits``,
            ``discriminative_features``, ``domain_features``,
            ``prototype_loss`` and ``instance_loss``.
        """
        feature_map = self.backbone(images)

        # Global average pool to 1x1, then collapse to one vector per sample.
        pooled = self.global_pool(feature_map)
        batch_size = pooled.size(0)
        pooled = pooled.view(batch_size, -1)

        decoupled = self.decoupler(pooled, lambda_)

        if labels is None:
            # No supervision available: contrastive terms contribute nothing.
            proto_loss = torch.tensor(0.0, device=images.device)
            inst_loss = torch.tensor(0.0, device=images.device)
        else:
            proto_loss, inst_loss = self.contrastive(
                decoupled['discriminative_features'],
                labels
            )

        return {
            'distraction_logits': decoupled['distraction_prediction'],
            'camera_logits': decoupled['camera_prediction'],
            'discriminative_features': decoupled['discriminative_features'],
            'domain_features': decoupled['domain_features'],
            'prototype_loss': proto_loss,
            'instance_loss': inst_loss
        }
class DBMNetLoss(nn.Module):
    """Total training loss for DBMNet.

    Combines, as a weighted sum, the distraction cross-entropy, the camera
    cross-entropy, the prototype and instance losses already present in the
    model outputs, and an orthogonality penalty between the discriminative
    and domain features.
    """

    def __init__(
        self,
        distraction_weight=1.0,
        camera_weight=0.1,
        prototype_weight=0.5,
        instance_weight=0.3,
        orthogonal_weight=0.1
    ):
        super().__init__()

        # Scalar coefficients applied to each loss term.
        self.distraction_weight = distraction_weight
        self.camera_weight = camera_weight
        self.prototype_weight = prototype_weight
        self.instance_weight = instance_weight
        self.orthogonal_weight = orthogonal_weight

        self.distraction_criterion = nn.CrossEntropyLoss()
        self.camera_criterion = nn.CrossEntropyLoss()

    def forward(self, outputs, labels, camera_ids):
        """Compute the total loss.

        Args:
            outputs: mapping providing ``distraction_logits``,
                ``camera_logits``, ``prototype_loss``, ``instance_loss``,
                ``discriminative_features`` and ``domain_features``.
            labels: targets for the distraction cross-entropy.
            camera_ids: targets for the camera cross-entropy.

        Returns:
            ``(total_loss, loss_dict)`` where ``total_loss`` is the weighted
            sum as a tensor and ``loss_dict`` maps each term name (plus
            ``'total'``) to its Python scalar value via ``.item()``.
        """
        # Unweighted terms, evaluated in a fixed order.
        terms = {
            'distraction': self.distraction_criterion(
                outputs['distraction_logits'], labels
            ),
            'camera': self.camera_criterion(
                outputs['camera_logits'], camera_ids
            ),
            'prototype': outputs['prototype_loss'],
            'instance': outputs['instance_loss'],
            'orthogonal': self.orthogonal_loss(
                outputs['discriminative_features'],
                outputs['domain_features']
            ),
        }

        weights = {
            'distraction': self.distraction_weight,
            'camera': self.camera_weight,
            'prototype': self.prototype_weight,
            'instance': self.instance_weight,
            'orthogonal': self.orthogonal_weight,
        }

        # Accumulate left to right so the summation order is fixed.
        total_loss = None
        for name, value in terms.items():
            weighted = weights[name] * value
            total_loss = weighted if total_loss is None else total_loss + weighted

        loss_dict = {name: value.item() for name, value in terms.items()}
        loss_dict['total'] = total_loss.item()

        return total_loss, loss_dict

    def orthogonal_loss(self, feat1, feat2):
        """Feature orthogonality loss.

        Each row of both inputs is L2-normalised along dim 1; the result is
        the Frobenius norm of ``feat1_norm^T @ feat2_norm``.
        """
        unit1 = F.normalize(feat1, dim=1)
        unit2 = F.normalize(feat2, dim=1)

        cross = torch.mm(unit1.t(), unit2)
        return torch.norm(cross, p='fro')
|