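"""
论文复现:低光照疲劳检测完整管道
(Paper reproduction: complete low-light driver drowsiness detection pipeline)

论文 (Paper):   Low-light driver drowsiness detection for real-time safety assistance
作者 (Authors): Javed et al.
期刊 (Journal): Scientific Reports, 2026
"""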
import warnings

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
class LowLightEnhancer:
    """Low-light image enhancer that blends CLAHE with multi-scale Retinex (MSR).

    Expects 3-channel BGR images (the CLAHE step converts BGR <-> LAB) and
    returns a ``uint8`` BGR image of the same size.

    Args:
        clip_limit: Contrast clip limit passed to ``cv2.createCLAHE``.
        tile_size: Tile grid size passed to ``cv2.createCLAHE``.
        scales: Gaussian sigmas used by MSR. Defaults to ``(15, 80, 250)``.
        clahe_weight: Weight of the CLAHE result in the final blend; the MSR
            result gets ``1 - clahe_weight``. Defaults to ``0.6``.

    Raises:
        ValueError: If ``scales`` is empty or ``clahe_weight`` is not in
            ``[0, 1]``.
    """

    def __init__(self, clip_limit: float = 3.0, tile_size: tuple = (8, 8),
                 scales: tuple = (15, 80, 250), clahe_weight: float = 0.6):
        scales = list(scales)
        if not scales:
            raise ValueError("scales must contain at least one sigma")
        if not 0.0 <= clahe_weight <= 1.0:
            raise ValueError("clahe_weight must be in [0, 1]")
        self.clip_limit = clip_limit
        self.tile_size = tile_size
        # Kept as a list, matching the attribute type callers may rely on.
        self.scales = scales
        self.clahe_weight = clahe_weight

    def clahe(self, image: np.ndarray) -> np.ndarray:
        """Apply CLAHE to the L (lightness) channel of the LAB representation.

        The a/b colour channels are passed through unchanged, so only
        local contrast is equalised.
        """
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(
            clipLimit=self.clip_limit, tileGridSize=self.tile_size
        )
        l_enhanced = clahe.apply(l)
        lab_enhanced = cv2.merge([l_enhanced, a, b])
        return cv2.cvtColor(lab_enhanced, cv2.COLOR_LAB2BGR)

    def msr(self, image: np.ndarray) -> np.ndarray:
        """Multi-scale Retinex.

        For ``I = image + 1``, averages ``log10(I) - log10(GaussianBlur(I, s))``
        over every sigma ``s`` in ``self.scales``. It then min-max stretches
        each channel independently to ``[0, 255]`` and casts to ``uint8``.
        The cast truncates.
        """
        # +1 keeps log10 finite for zero-valued pixels.
        image = image.astype(np.float64) + 1.0
        # log10(I) does not depend on the scale, so compute it once.
        log_image = np.log10(image)
        msr_result = np.zeros_like(image)
        for scale in self.scales:
            # ksize=(0, 0) lets OpenCV derive the kernel size from sigma.
            gaussian = cv2.GaussianBlur(image, (0, 0), scale)
            msr_result += log_image - np.log10(gaussian)
        msr_result /= len(self.scales)
        for channel in range(msr_result.shape[2]):
            msr_result[:, :, channel] = cv2.normalize(
                msr_result[:, :, channel], None, 0, 255, cv2.NORM_MINMAX
            )
        return msr_result.astype(np.uint8)

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return a weighted blend of the CLAHE and MSR enhancements."""
        clahe_result = self.clahe(image)
        msr_result = self.msr(image)
        return cv2.addWeighted(
            clahe_result, self.clahe_weight,
            msr_result, 1.0 - self.clahe_weight,
            0,
        )
class DualAttentionDrowsinessNet(nn.Module):
    """Drowsiness classifier: MobileNetV3-Small features plus dual attention.

    Forward path: ``mobilenet_v3_small().features`` (assumed here to output
    576 channels) -> ``ChannelAttention`` -> ``SpatialAttention`` -> global
    average pooling -> MLP head (576 -> 256 -> ``num_classes``).

    NOTE(review): ``ChannelAttention`` and ``SpatialAttention`` are not defined
    in this file as shown; they must be provided elsewhere before this class
    is instantiated.

    Args:
        num_classes: Number of output logits.
        pretrained: If True (default, same as before), initialise the backbone
            with ImageNet weights. Pass False to skip the download, e.g. when
            a full checkpoint is loaded immediately afterwards.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        super().__init__()
        from torchvision.models import (
            MobileNet_V3_Small_Weights,
            mobilenet_v3_small,
        )

        # The `weights=` API replaces the deprecated `pretrained=` flag;
        # IMAGENET1K_V1 is the checkpoint `pretrained=True` used to select.
        weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1 if pretrained else None
        mobilenet = mobilenet_v3_small(weights=weights)
        self.features = mobilenet.features
        self.channel_attention = ChannelAttention(576, reduction=16)
        self.spatial_attention = SpatialAttention(kernel_size=7)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(576, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Input image batch of shape (B, 3, 224, 224).

        Returns:
            Logits of shape (B, num_classes).
        """
        x = self.features(x)
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        x = self.avgpool(x)
        # (B, 576, 1, 1) -> (B, 576); flatten also copes with non-contiguous input.
        x = torch.flatten(x, 1)
        return self.classifier(x)
class DrowsinessDetector:
    """End-to-end, per-frame drowsiness detection pipeline.

    For each frame the pipeline does the following:
    - enhance it with ``LowLightEnhancer`` (CLAHE + MSR blend);
    - convert BGR -> RGB;
    - resize to 224x224 and normalise with the ImageNet mean/std;
    - classify it with ``DualAttentionDrowsinessNet``.

    The softmax probability of class index 1 is reported as the drowsiness
    probability (presumably index 1 == "drowsy" in the training labels;
    confirm against the training code).
    """

    def __init__(self, model_path: str = None, device: str = 'cuda',
                 threshold: float = 0.5):
        """
        Args:
            model_path: Optional path to a saved ``state_dict``. Without it the
                classifier head keeps its random initialisation, so outputs
                are not meaningful.
            device: Torch device string. If a CUDA device is requested but
                CUDA is unavailable, falls back to ``'cpu'`` with a warning
                instead of crashing.
            threshold: ``is_drowsy`` is True when the drowsy probability is
                strictly greater than this value. Defaults to 0.5.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            warnings.warn(
                f"CUDA requested (device={device!r}) but not available; "
                "falling back to CPU.",
                RuntimeWarning,
            )
            device = 'cpu'
        self.device = device
        self.threshold = threshold
        self.model = DualAttentionDrowsinessNet(num_classes=2)
        if model_path:
            # SECURITY: torch.load unpickles the file and can execute arbitrary
            # code; only load checkpoints from trusted sources (consider
            # weights_only=True on torch versions that support it).
            state_dict = torch.load(model_path, map_location=device)
            self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()
        self.enhancer = LowLightEnhancer(clip_limit=3.0, tile_size=(8, 8))
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet statistics, matching the pretrained backbone.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @staticmethod
    def _to_bgr(frame: np.ndarray) -> np.ndarray:
        """Return *frame* as 3-channel BGR, converting grayscale or BGRA input."""
        if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        if frame.ndim == 3 and frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
        return frame

    def detect(self, frame: np.ndarray) -> dict:
        """Estimate the drowsiness state of a single frame.

        Args:
            frame: BGR image (H, W, 3). Grayscale (H, W) / (H, W, 1) and BGRA
                (H, W, 4) frames are converted to BGR first.

        Returns:
            dict with:
                'drowsy_prob': softmax probability of class 1 (float);
                'is_drowsy': ``drowsy_prob > self.threshold`` (bool);
                'enhanced_frame': the enhanced BGR frame fed to the model.
        """
        frame = self._to_bgr(frame)
        enhanced = self.enhancer(frame)
        image_rgb = cv2.cvtColor(enhanced, cv2.COLOR_BGR2RGB)
        image_pil = Image.fromarray(image_rgb)
        input_tensor = self.transform(image_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(input_tensor)
            probs = F.softmax(logits, dim=1)
            drowsy_prob = probs[0, 1].item()
        return {
            'drowsy_prob': drowsy_prob,
            'is_drowsy': drowsy_prob > self.threshold,
            'enhanced_frame': enhanced,
        }
def main() -> None:
    """Run the pipeline on a synthetic dark frame and display the result.

    The detector is built without a checkpoint, so the printed probability
    comes from an untrained head and only demonstrates the data flow.
    """
    # Use the GPU when present instead of crashing on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    detector = DrowsinessDetector(device=device)

    # Local generator: reproducible without touching NumPy's global RNG state.
    rng = np.random.default_rng(42)
    dark_image = rng.integers(30, 80, size=(480, 640, 3), dtype=np.uint8)

    result = detector.detect(dark_image)
    print(f"疲劳概率: {result['drowsy_prob']:.2%}")
    print(f"是否疲劳: {'是' if result['is_drowsy'] else '否'}")

    cv2.imshow("Original", dark_image)
    cv2.imshow("Enhanced", result['enhanced_frame'])
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()