import torch
import torch.nn as nn
class DistractedDrivingDetector(nn.Module):
    """CNN + BiLSTM + attention classifier for distracted-driving frame sequences.

    Each frame is encoded by a small convolutional stack, the per-frame
    features are run through a bidirectional LSTM, and an additive attention
    layer pools the sequence into one context vector that is classified.

    Args:
        num_classes: Number of output classes. Defaults to 10.
        input_hw: (height, width) of the input frames. Each spatial dim must
            be divisible by 4 because the CNN contains two MaxPool2d(2)
            layers. Defaults to (224, 224), which reproduces the original
            hard-coded LSTM input size of 128 * 56 * 56.
    """

    def __init__(self, num_classes=10, input_hw=(224, 224)):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        height, width = input_hw
        # Two MaxPool2d(2) stages shrink each spatial dimension by a factor
        # of 4, so the flattened per-frame feature length is:
        feature_dim = 128 * (height // 4) * (width // 4)
        self.bilstm = nn.LSTM(
            feature_dim, 256, bidirectional=True, batch_first=True
        )
        # 512 = 2 directions * 256 hidden units (bidirectional concat).
        self.attention = nn.Linear(512, 1)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape (batch, seq_len, 3, H, W), where (H, W)
               matches the ``input_hw`` the model was constructed with.

        Returns:
            Logits of shape (batch, num_classes).
        """
        batch, seq_len = x.size(0), x.size(1)
        # Fold the time axis into the batch so the CNN runs once over every
        # frame instead of a Python loop per timestep. The CNN has no
        # batch-dependent layers (Conv/ReLU/MaxPool only), so the results
        # are identical to the per-frame loop — just one batched call.
        frames = x.reshape(batch * seq_len, *x.shape[2:])
        cnn_features = self.cnn(frames).reshape(batch, seq_len, -1)
        lstm_out, _ = self.bilstm(cnn_features)  # (batch, seq_len, 512)
        # Softmax over the sequence dimension yields per-frame weights that
        # sum to 1 across time for each batch element.
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)
        context = torch.sum(attn_weights * lstm_out, dim=1)
        return self.classifier(context)