1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
class DriverBehaviorSwinTransformer(nn.Module):
    """Swin Transformer classifier for driver-behavior detection.

    Architecture: Patch Partition -> Stage 1..N -> Head.

    Behavior categories listed by the original author (``num_classes``
    defaults to 10):
        - normal driving
        - talking on the phone (left / right hand)
        - texting (left / right hand)
        - adjusting the radio
        - drinking
        - reaching to the back seat
        - fixing hair
        - talking to a passenger
    """

    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 4,
                 in_channels: int = 3,
                 num_classes: int = 10,
                 embed_dim: int = 96,
                 depths: tuple = (2, 2, 6, 2),
                 num_heads: tuple = (3, 6, 12, 24)):
        """Build the patch embedding, the stages and the classifier head.

        Args:
            img_size: Nominal input resolution. Kept for interface
                compatibility; it is not used to size any layer.
            patch_size: Side of the square patches; used as both kernel size
                and stride of the embedding convolution.
            in_channels: Number of channels of the input image.
            num_classes: Size of the output logit vector.
            embed_dim: Channel width after patch embedding; stage ``i`` runs
                at ``embed_dim * 2**i`` channels.
            depths: Number of blocks per stage (any indexable sequence).
            num_heads: Attention heads per stage, indexed like ``depths``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim

        # Kept as a two-element Sequential so parameter names
        # (patch_embed.0.*, patch_embed.1.*) stay unchanged. forward() applies
        # the two parts separately because the LayerNorm needs channels-last
        # input while Conv2d produces channels-first output.
        self.patch_embed = nn.Sequential(
            nn.Conv2d(in_channels, embed_dim,
                      kernel_size=patch_size, stride=patch_size),
            nn.LayerNorm(embed_dim),
        )

        last_layer = self.num_layers - 1
        self.stages = nn.ModuleList()
        for i_layer in range(self.num_layers):
            stage = self._make_stage(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                # The final stage must not downsample: norm/head below are
                # sized for embed_dim * 2**(num_layers - 1) channels.
                downsample=i_layer < last_layer,
            )
            self.stages.append(stage)

        final_dim = int(embed_dim * 2 ** last_layer)
        self.norm = nn.LayerNorm(final_dim)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(final_dim, num_classes)

    def _make_stage(self,
                    dim: int,
                    depth: int,
                    num_heads: int,
                    downsample: bool = True) -> nn.Sequential:
        """Build one stage.

        A stage is ``depth`` ImprovedSwinBlock modules, a LayerNorm, and
        (optionally) a PatchMerging layer that doubles the channel count.

        Args:
            dim: Channel width of the tokens entering the stage.
            depth: Number of ImprovedSwinBlock modules.
            num_heads: Attention heads passed to each block.
            downsample: Append a PatchMerging layer when True (the default,
                matching the previous behavior of this helper).

        Returns:
            The stage as an ``nn.Sequential``.
        """
        blocks = [ImprovedSwinBlock(dim, num_heads) for _ in range(depth)]
        blocks.append(nn.LayerNorm(dim))
        if downsample:
            blocks.append(PatchMerging(dim))
        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: Input images, shape (B, in_channels, H, W); e.g.
                (B, 3, 224, 224) with the default configuration.

        Returns:
            Classification logits, shape (B, num_classes).
        """
        x = self.patch_embed[0](x)            # (B, C, H', W')
        # Tokens must be channels-last before LayerNorm(embed_dim), which
        # normalizes over the trailing dimension.
        x = x.flatten(2).transpose(1, 2)      # (B, H'*W', C)
        x = self.patch_embed[1](x)

        for stage in self.stages:
            x = stage(x)

        x = self.norm(x)
        x = x.transpose(1, 2)                 # (B, C, L) for 1-D pooling
        x = self.avgpool(x).flatten(1)        # (B, C)
        return self.head(x)
class PatchMerging(nn.Module):
    """Patch merging (2x spatial downsampling).

    Each 2x2 neighbourhood of tokens is concatenated along the channel axis
    (C -> 4C), normalized with LayerNorm and linearly projected to 2C.
    """

    def __init__(self, dim: int):
        """Create the layer.

        Args:
            dim: Channel count C of the incoming tokens.
        """
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge 2x2 token neighbourhoods.

        Args:
            x: Tokens of shape (B, H*W, C) laid out on a square grid
                (H == W) with an even side length.

        Returns:
            Tensor of shape (B, H*W/4, 2*C).

        Raises:
            ValueError: If the token count is not a perfect square, or the
                grid side is odd (the 2x2 slices would not line up).
        """
        B, L, C = x.shape

        # The grid side is inferred from the token count; verify it exactly
        # rather than letting reshape/cat fail with an opaque size error.
        side = round(L ** 0.5)
        if side * side != L:
            raise ValueError(
                f"PatchMerging expects a square token grid, got {L} tokens"
            )
        if side % 2:
            raise ValueError(
                f"PatchMerging expects an even grid side, got {side}x{side}"
            )

        x = x.reshape(B, side, side, C)

        # The four members of every 2x2 window, in (row, col) offset order
        # (0,0), (1,0), (0,1), (1,1).
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]

        x = torch.cat([x0, x1, x2, x3], dim=-1)   # (B, H/2, W/2, 4C)
        x = x.flatten(1, 2)                       # (B, H*W/4, 4C)

        x = self.norm(x)
        x = self.reduction(x)
        return x
|