1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
| import torch import torch.nn as nn
class PrimaryCapsule(nn.Module):
    """Primary capsule layer.

    Applies one strided convolution and regroups its output into
    ``num_routes`` capsule vectors per sample, each passed through ``squash``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size (stride is fixed at 2, padding 0).
        num_routes: Number of capsules produced per sample.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_routes):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=2)
        self.num_routes = num_routes

    def forward(self, x):
        """Return squashed capsules of shape ``(batch, num_routes, D)``.

        ``D == out_channels * H_out * W_out // num_routes``, where ``H_out`` and
        ``W_out`` are the spatial sizes of the convolution output. ``view``
        raises if that element count is not divisible by ``num_routes``.
        """
        features = self.conv(x)
        batch_size = features.size(0)
        # Flatten (C, H, W) in memory order and cut it into num_routes equal
        # chunks; each chunk becomes one capsule vector.
        capsules = features.view(batch_size, self.num_routes, -1)
        return self.squash(capsules)

    def squash(self, x, eps=1e-8):
        """Capsule activation: rescale each last-dim vector to length < 1.

        Computes ``|x|^2 / (1 + |x|^2) * x / |x|`` along the last dimension.
        ``eps`` keeps the division finite for all-zero vectors, which would
        otherwise produce NaN (0 / 0).
        """
        sq_norm = (x ** 2).sum(dim=-1, keepdim=True)
        return (sq_norm / (1 + sq_norm)) * (x / torch.sqrt(sq_norm + eps))
class GazeCapsule(nn.Module):
    """Gaze capsule layer.

    Maps ``num_routes`` input capsules of size ``in_channels`` to a single
    output capsule of size ``out_channels`` by iterative routing-by-agreement.

    Args:
        num_routes: Number of input capsules per sample.
        in_channels: Dimension of each input capsule vector.
        out_channels: Dimension of the output capsule vector.
    """

    def __init__(self, num_routes, in_channels, out_channels):
        super().__init__()
        # forward() sizes the routing logits by num_routes, so it must be stored.
        self.num_routes = num_routes
        self.in_channels = in_channels
        self.out_channels = out_channels
        # One (out_channels x in_channels) transformation matrix per input route.
        self.weight = nn.Parameter(
            torch.randn(num_routes, out_channels, in_channels)
        )

    def forward(self, x, num_routing=3):
        """Route the input capsules to one output capsule.

        Args:
            x: Tensor of shape ``(batch, num_routes, in_channels)``.
            num_routing: Number of routing iterations; must be >= 1.

        Returns:
            Tensor of shape ``(batch, out_channels)``: the squashed output capsule.

        Raises:
            ValueError: If ``num_routing`` is less than 1.
        """
        if num_routing < 1:
            raise ValueError(f"num_routing must be >= 1, got {num_routing}")
        batch_size = x.size(0)
        # Prediction vectors u_hat[b, r] = W[r] @ x[b, r]:
        # (R, out, in) @ (B, R, in, 1) -> (B, R, out, 1) -> (B, R, out).
        u_hat = torch.matmul(self.weight, x.unsqueeze(-1)).squeeze(-1)
        # Routing logits, allocated on x's device and dtype.
        b = x.new_zeros(batch_size, self.num_routes, 1)
        for i in range(num_routing):
            # Coupling coefficients are normalised across the input routes.
            c = torch.softmax(b, dim=1)
            s = (c * u_hat).sum(dim=1, keepdim=True)  # (batch, 1, out_channels)
            v = self.squash(s)
            if i < num_routing - 1:
                # Raise logits of routes whose predictions agree with v; the
                # update after the final iteration would be unused.
                b = b + (u_hat * v).sum(dim=-1, keepdim=True)
        return v.squeeze(1)

    def squash(self, x, eps=1e-8):
        """Capsule activation: rescale each last-dim vector to length < 1.

        Computes ``|x|^2 / (1 + |x|^2) * x / |x|``; ``eps`` avoids 0 / 0 for
        all-zero vectors.
        """
        sq_norm = (x ** 2).sum(dim=-1, keepdim=True)
        return (sq_norm / (1 + sq_norm)) * (x / torch.sqrt(sq_norm + eps))
class GazeCapsNet(nn.Module):
    """Complete GazeCapsNet: a primary capsule layer followed by a gaze capsule.

    Every size is a constructor argument whose default equals the value that
    was previously hard-coded, so ``GazeCapsNet()`` builds the same network.

    NOTE: ``PrimaryCapsule`` (stride 2, no padding) splits its conv output into
    ``num_routes`` chunks of ``primary_channels * H_out * W_out / num_routes``
    elements. ``GazeCapsule`` is built with ``in_channels=capsule_dim``, so
    those chunk sizes must match. With the defaults (256 channels, 32 routes,
    8-dim capsules, 9x9 kernel) that requires a 1x1 conv output, i.e. 9x9 or
    10x10 input images.

    Args:
        in_channels: Channels of the input images.
        primary_channels: Output channels of the primary-capsule convolution.
        kernel_size: Kernel size of the primary-capsule convolution.
        num_routes: Number of primary capsules.
        capsule_dim: Dimension of each primary capsule fed to the gaze capsule.
        gaze_dim: Dimension of the output gaze vector.
    """

    def __init__(self, in_channels=3, primary_channels=256, kernel_size=9,
                 num_routes=32, capsule_dim=8, gaze_dim=16):
        super().__init__()
        self.primary_caps = PrimaryCapsule(
            in_channels, primary_channels, kernel_size, num_routes
        )
        self.gaze_caps = GazeCapsule(num_routes, capsule_dim, gaze_dim)

    def forward(self, x):
        """Run the primary capsules then the gaze capsule; return the gaze vector."""
        x = self.primary_caps(x)
        gaze_vector = self.gaze_caps(x)
        return gaze_vector
|