1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
| import torch import torch.nn as nn import torch.nn.functional as F import math
class FinegrainedInterFrameAttention(nn.Module):
    """Fine-grained inter-frame attention that explicitly models pupil displacement.

    Frame-to-frame pupil displacement is encoded by a small MLP and added to
    the per-frame features before multi-head self-attention is applied over
    the time axis.
    """

    def __init__(
        self,
        feature_dim: int = 256,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        """Build the projection layers and the displacement encoder.

        Args:
            feature_dim: Size D of each frame feature vector.
            num_heads: Number of attention heads; must evenly divide
                ``feature_dim``.
            dropout: Dropout probability applied to the attention weights.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not evenly
                divide ``feature_dim``.
        """
        super().__init__()

        # Fail at construction time instead of with an opaque shape error
        # inside forward() when the heads cannot tile the feature dimension.
        if num_heads <= 0 or feature_dim % num_heads != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads

        self.q_proj = nn.Linear(feature_dim, feature_dim)
        self.k_proj = nn.Linear(feature_dim, feature_dim)
        self.v_proj = nn.Linear(feature_dim, feature_dim)
        self.out_proj = nn.Linear(feature_dim, feature_dim)

        # Maps a 2-component displacement to a feature_dim-sized embedding.
        self.displacement_encoder = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, feature_dim)
        )

        self.dropout = nn.Dropout(dropout)
        # Attention scores are divided by sqrt(head_dim).
        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        frame_features: torch.Tensor,
        pupil_positions: torch.Tensor
    ) -> torch.Tensor:
        """Apply displacement-aware self-attention across frames.

        Args:
            frame_features: Frame features, shape=(B, T, D).
            pupil_positions: Pupil positions, shape=(B, T, 2).

        Returns:
            Attention-weighted features, shape=(B, T, D).
        """
        B, T, D = frame_features.shape

        # Displacement between consecutive frames; a zero row is prepended on
        # the time axis so the first frame has zero displacement and the
        # sequence length stays T.
        displacement = pupil_positions[:, 1:] - pupil_positions[:, :-1]
        displacement = F.pad(displacement, (0, 0, 1, 0), value=0)

        disp_encoding = self.displacement_encoder(displacement)
        enhanced_features = frame_features + disp_encoding

        # Project and split into heads: (B, T, D) -> (B, num_heads, T, head_dim).
        head_shape = (B, T, self.num_heads, self.head_dim)
        Q = self.q_proj(enhanced_features).view(head_shape).transpose(1, 2)
        K = self.k_proj(enhanced_features).view(head_shape).transpose(1, 2)
        V = self.v_proj(enhanced_features).view(head_shape).transpose(1, 2)

        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V)
        # Merge heads back: (B, num_heads, T, head_dim) -> (B, T, D).
        attn_output = attn_output.transpose(1, 2).reshape(B, T, D)

        output = self.out_proj(attn_output)
        return output
class PupilDisplacementExtractor(nn.Module):
    """Pupil position extractor for frame sequences.

    Each frame is passed through a small convolutional stack followed by an
    MLP that regresses two values per frame.
    """

    def __init__(self):
        """Build the convolutional feature stack and the 2-value regressor."""
        super().__init__()

        # Three conv stages on 3-channel input; the adaptive pool fixes the
        # spatial output to 4x4 regardless of the input resolution.
        self.eye_detector = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4))
        )

        self.pupil_regressor = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2)
        )

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        """Extract the pupil position sequence.

        Args:
            frames: Video frames, shape=(B, T, C, H, W). May be
                non-contiguous (e.g. the result of a ``permute``).

        Returns:
            Pupil positions, shape=(B, T, 2).
        """
        B, T = frames.shape[:2]

        # Fold time into the batch axis so the 2D convs see (B*T, C, H, W).
        # reshape (not view) so permuted / non-contiguous clips are accepted;
        # it only copies when a view is impossible.
        frames_flat = frames.reshape(B * T, *frames.shape[2:])

        eye_features = self.eye_detector(frames_flat)
        eye_features = eye_features.flatten(1)

        pupil_positions = self.pupil_regressor(eye_features)
        pupil_positions = pupil_positions.reshape(B, T, 2)

        return pupil_positions
|