1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| import torch import torch.onnx
class DMSModel(torch.nn.Module):
    """Example DMS model: a small CNN backbone followed by a linear head.

    The backbone stacks three 3x3 convolutions (32, 64 and 128 output
    channels), each followed by ReLU. The first two stages end with 2x2 max
    pooling; the last ends with adaptive average pooling to a 1x1 map, so
    the head always receives 128 features per sample regardless of the
    input's spatial size.

    Args:
        in_channels: Number of channels of the input tensor. Defaults to 3.
        num_outputs: Number of values produced by the linear head.
            Defaults to 5.
    """

    def __init__(self, in_channels: int = 3, num_outputs: int = 5) -> None:
        super().__init__()
        # Layer order (and therefore the Sequential indices 0..8) is kept
        # fixed so state_dict keys such as "backbone.3.weight" stay stable.
        self.backbone = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1),
        )
        self.head = torch.nn.Linear(128, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and head on a batch.

        Args:
            x: Input of shape (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, num_outputs).
        """
        features = self.backbone(x)
        # (batch, 128, 1, 1) -> (batch, 128). flatten keeps the batch
        # dimension and does not depend on the tensor's memory layout.
        features = torch.flatten(features, 1)
        return self.head(features)
# Build the network and switch it to inference mode before tracing;
# Module.eval() returns the module itself, so the call can be chained.
model = DMSModel().eval()

# Example input used only to trace the graph: one 3-channel 224x224 image.
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX (opset 11). Dimension 0 of both the input and the output
# is declared dynamic under the name "batch".
torch.onnx.export(
    model,
    dummy_input,
    "dms_model.onnx",
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={tensor_name: {0: "batch"} for tensor_name in ("input", "output")},
)

# Message reads "ONNX model exported".
print("ONNX模型已导出")
|