1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
| """ DMS 模型训练与导出
以眼动追踪模型为例 """
import torch import torch.nn as nn import torch.onnx from typing import Tuple
class EyeGazeModel(nn.Module):
    """Eye-gaze regression network.

    Input:  eye-region image batch of shape (B, 3, 64, 64).
    Output: gaze direction of shape (B, 2); each component squashed
            into (-1, 1) by a final tanh.
    """

    def __init__(self):
        super().__init__()

        # Four conv stages (channels 32 -> 64 -> 128 -> 256); the first
        # three downsample 2x via max-pooling, the last keeps resolution
        # and is followed by a global average pool to (B, 256, 1, 1).
        # Layer order is identical to an explicit Sequential listing, so
        # state_dict keys are unchanged.
        stages = []
        in_ch = 3
        for out_ch, pooled in ((32, True), (64, True), (128, True), (256, False)):
            stages += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
            if pooled:
                stages.append(nn.MaxPool2d(2))
            in_ch = out_ch
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*stages)

        # MLP head: 256 -> 128 -> 64 -> 2 with dropout regularization.
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
            nn.Tanh(),  # bound gaze coordinates to (-1, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch (B, 3, 64, 64) to gaze coordinates (B, 2)."""
        return self.regressor(self.features(x))
def export_to_onnx( model: nn.Module, output_path: str, input_size: Tuple[int, int, int] = (3, 64, 64) ): """ 导出模型为 ONNX 格式 Args: model: PyTorch 模型 output_path: 输出路径 input_size: 输入尺寸 (C, H, W) """ model.eval() dummy_input = torch.randn(1, *input_size) torch.onnx.export( model, dummy_input, output_path, export_params=True, opset_version=11, do_constant_folding=True, input_names=['input'], output_names=['gaze'], dynamic_axes={ 'input': {0: 'batch_size'}, 'gaze': {0: 'batch_size'} } ) print(f"模型已导出到: {output_path}")
def train_model(epochs: int = 100, lr: float = 0.001, batch_size: int = 32):
    """Train the gaze model on synthetic data and export it to ONNX.

    NOTE(review): images and targets are fresh random tensors every step,
    so this is a pipeline smoke-test, not a real training run — swap in a
    DataLoader for actual data.

    Args:
        epochs: number of optimization steps (one random batch per step).
        lr: Adam learning rate.
        batch_size: size of each synthetic batch.

    Returns:
        The trained EyeGazeModel (left in eval mode by the ONNX export).
    """
    model = EyeGazeModel()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # Synthetic batch: stand-in for real eye crops + gaze labels.
        images = torch.randn(batch_size, 3, 64, 64)
        gaze_gt = torch.randn(batch_size, 2)

        gaze_pred = model(images)
        loss = criterion(gaze_pred, gaze_gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    export_to_onnx(model, "eye_gaze_model.onnx")
    return model
# Script entry point: run the demo training + ONNX export pipeline.
if __name__ == "__main__":
    model = train_model()
|