1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
| import torch import torch.nn as nn import torch.nn.functional as F
class BEVFormer(nn.Module):
    """Bird's-eye-view (BEV) Transformer for driver attention prediction.

    A shared backbone extracts features from every camera view. A learned
    grid of ``bev_h * bev_w`` BEV queries is then refined by a stack of
    encoder layers. Each layer applies, in order: temporal self-attention
    (only when a previous BEV is supplied), spatial cross-attention over the
    image features, and a feed-forward network.

    Args:
        num_cams: Number of camera views read from ``imgs`` in ``forward``.
        embed_dim: Channel width of the BEV queries and of the image features.
        num_heads: Number of heads in every attention module.
        num_encoder_layers: Number of encoder layers.
        bev_h: Height of the BEV grid.
        bev_w: Width of the BEV grid.
    """

    def __init__(self, num_cams=6, embed_dim=256, num_heads=8,
                 num_encoder_layers=6, bev_h=200, bev_w=200):
        super().__init__()
        self.num_cams = num_cams
        self.embed_dim = embed_dim
        self.bev_h = bev_h
        self.bev_w = bev_w

        # The image features are the keys/values of the cross-attention, so
        # the backbone has to emit ``embed_dim`` channels rather than its own
        # default width.
        self.backbone = ResNetBackbone(output_dim=embed_dim)

        # One learned query per BEV grid cell, shared across the batch.
        self.bev_queries = nn.Parameter(torch.randn(1, bev_h * bev_w, embed_dim))

        self.spatial_cross_attention = nn.ModuleList([
            SpatialCrossAttention(embed_dim, num_heads)
            for _ in range(num_encoder_layers)
        ])
        self.temporal_self_attention = nn.ModuleList([
            TemporalSelfAttention(embed_dim, num_heads)
            for _ in range(num_encoder_layers)
        ])
        self.ffn = nn.ModuleList([
            FFN(embed_dim) for _ in range(num_encoder_layers)
        ])

    def forward(self, imgs, prev_bev=None):
        """Compute BEV features from multi-camera images.

        Args:
            imgs: (B, num_cams, C, H, W) multi-camera images.
            prev_bev: Optional (B, bev_h * bev_w, C) BEV features of the
                previous frame. When ``None`` the temporal self-attention
                step is skipped in every layer.

        Returns:
            (B, bev_h * bev_w, C) BEV features.
        """
        batch_size = imgs.shape[0]

        # One backbone pass per camera; only the first ``num_cams`` views of
        # ``imgs`` are read. The per-camera maps are stacked on a new dim 1.
        img_features = torch.stack(
            [self.backbone(imgs[:, cam_id]) for cam_id in range(self.num_cams)],
            dim=1,
        )

        # Broadcast the learned queries over the batch without copying.
        bev_queries = self.bev_queries.expand(batch_size, -1, -1)

        layers = zip(
            self.temporal_self_attention,
            self.spatial_cross_attention,
            self.ffn,
        )
        for temporal_attn, spatial_attn, ffn in layers:
            if prev_bev is not None:
                bev_queries = temporal_attn(bev_queries, prev_bev)
            bev_queries = spatial_attn(bev_queries, img_features)
            bev_queries = ffn(bev_queries)

        return bev_queries
class SpatialCrossAttention(nn.Module):
    """Spatial cross-attention: BEV queries attend to the image features."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, bev_queries, img_features):
        """Update the BEV queries with information from all camera features.

        Args:
            bev_queries: (B, N, C) BEV queries.
            img_features: (B, num_cams, C, H, W) image features.

        Returns:
            (B, N, C) queries after residual cross-attention and LayerNorm.
        """
        batch, cams, channels, height, width = img_features.shape

        # Move channels last, then merge cameras and pixels into one token
        # axis: (B, cams, C, H, W) -> (B, cams, H, W, C) -> (B, cams*H*W, C).
        tokens = img_features.permute(0, 1, 3, 4, 2).reshape(
            batch, cams * height * width, channels
        )

        # The attention module is sequence-first: (L, B, C).
        query = bev_queries.transpose(0, 1)
        memory = tokens.transpose(0, 1)
        attended = self.cross_attn(query, memory, memory)[0]

        # Back to batch-first for the residual connection.
        return self.norm(bev_queries + attended.transpose(0, 1))
class TemporalSelfAttention(nn.Module):
    """Temporal self-attention: fuse the previous BEV into the current one."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, bev_queries, prev_bev):
        """Let the current BEV queries attend to current and previous BEV.

        Args:
            bev_queries: (B, N, C) current BEV queries.
            prev_bev: (B, N, C) BEV features of the previous frame.

        Returns:
            (B, N, C) updated queries after residual attention and LayerNorm.
        """
        # Keys/values cover current + previous tokens; the attention module
        # is sequence-first, hence the transpose to (L, B, C).
        memory = torch.cat([bev_queries, prev_bev], dim=1).transpose(0, 1)

        # Only the current tokens are returned, so only they act as queries.
        # Each attention output row depends on its own query and on all
        # keys/values, so this matches the first N rows of attending over the
        # full concatenation while skipping the rows that would be dropped.
        query = bev_queries.transpose(0, 1)
        attn_output, _ = self.self_attn(query, memory, memory, need_weights=False)

        # LayerNorm is per token, so normalising only the kept rows is
        # equivalent to normalising everything and slicing afterwards.
        return self.norm(bev_queries + attn_output.transpose(0, 1))
class ResNetBackbone(nn.Module):
    """Simplified ResNet-50 backbone followed by a 1x1 channel projection.

    Args:
        output_dim: Number of channels produced by the 1x1 projection.
        pretrained: When true (the default) load the ImageNet weights that
            the legacy ``pretrained=True`` flag selected; when false build
            the network with random initialisation and no download.
    """

    def __init__(self, output_dim=256, pretrained=True):
        super().__init__()
        import torchvision.models as models

        if hasattr(models, "ResNet50_Weights"):
            # torchvision >= 0.13: ``pretrained`` is deprecated in favour of
            # ``weights``; IMAGENET1K_V1 is what ``pretrained=True`` maps to.
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
        else:
            # Older torchvision releases only know the boolean flag.
            resnet = models.resnet50(pretrained=pretrained)

        # Drop the last two child modules (the average pool and the
        # classifier in torchvision's ResNet) to keep a spatial feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        # Read the trunk's channel count from the discarded classifier
        # instead of hard-coding it.
        self.proj = nn.Conv2d(resnet.fc.in_features, output_dim, 1)

    def forward(self, x):
        """Return the projected feature map of image batch ``x``."""
        x = self.backbone(x)
        x = self.proj(x)
        return x
class FFN(nn.Module):
    """Position-wise feed-forward block with a residual link and LayerNorm.

    Args:
        embed_dim: Size of the input and output features.
        hidden_dim: Size of the intermediate layer.
    """

    def __init__(self, embed_dim, hidden_dim=1024):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply Linear -> ReLU -> Linear, add the input, then normalise."""
        hidden = F.relu(self.fc1(x))
        residual = x + self.fc2(hidden)
        return self.norm(residual)
|