1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
| import torch import torch.nn as nn
class Simple3DLifter(nn.Module):
    """Simple MLP that lifts 2D keypoints to 3D keypoints.

    The ``J x 2`` input coordinates of each pose are flattened and passed
    through an encoder (Linear -> ReLU -> Dropout, twice). A decoder
    (Linear -> ReLU -> Linear) follows, and the result is reshaped to
    ``J x 3``.

    Args:
        num_joints: Number of joints ``J`` per pose.
        hidden_dim: Width of the hidden layers.
        dropout: Dropout probability applied after each encoder layer.
    """

    def __init__(self, num_joints=17, hidden_dim=256, dropout=0.3):
        super().__init__()
        # Layer order (and therefore state_dict keys such as ``encoder.0``,
        # ``encoder.3``, ``decoder.0``, ``decoder.2``) is kept stable so that
        # existing checkpoints still load.
        self.encoder = nn.Sequential(
            nn.Linear(num_joints * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_joints * 3),
        )

    def forward(self, keypoints_2d):
        """Lift 2D keypoints to 3D.

        Args:
            keypoints_2d: Tensor of shape ``[..., J, 2]`` (typically
                ``[B, J, 2]``) with ``J`` equal to the ``num_joints`` the
                module was built with. Non-contiguous tensors and empty
                batches are accepted.

        Returns:
            Tensor of shape ``[..., J, 3]`` with the same leading dims as
            the input.

        Raises:
            ValueError: If the input has fewer than two dims, its last dim
                is not 2, or its joint dim does not match the model.
        """
        if keypoints_2d.dim() < 2 or keypoints_2d.shape[-1] != 2:
            raise ValueError(
                "expected keypoints_2d of shape [..., J, 2], got "
                f"{tuple(keypoints_2d.shape)}"
            )
        # Derive the expected joint count from the first layer itself so it
        # cannot drift from the weights (also works for modules unpickled
        # from older versions of this class).
        expected_joints = self.encoder[0].in_features // 2
        *lead, num_j, _ = keypoints_2d.shape
        if num_j != expected_joints:
            raise ValueError(
                f"expected {expected_joints} joints, got {num_j} "
                f"(input shape {tuple(keypoints_2d.shape)})"
            )
        # reshape (not view) so non-contiguous inputs (e.g. after transpose
        # or permute) work. The explicit size instead of -1 keeps empty
        # batches valid.
        x = keypoints_2d.reshape(*lead, num_j * 2)
        h = self.encoder(x)
        keypoints_3d = self.decoder(h)
        return keypoints_3d.reshape(*lead, num_j, 3)
|