class KnowledgeDistillation:
    """Knowledge distillation: train a small student to mimic a large teacher.

    The training objective is

        alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(student_T || teacher_T)

    where T is the softmax temperature (Hinton et al., 2015). The KL term is
    scaled by T**2 so the soft-target gradient magnitude stays comparable
    across temperature settings.
    """

    def __init__(self, teacher_model, student_model, temperature=3.0, alpha=0.5):
        """
        Args:
            teacher_model: trained (large) network; frozen during distillation.
            student_model: smaller network to be trained.
            temperature: softmax temperature used to soften both logit
                distributions (> 0); higher values produce softer targets.
            alpha: weight of the hard (ground-truth) loss in [0, 1]; the soft
                (teacher-matching) loss gets weight 1 - alpha.
        """
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha
        # Freeze the teacher: no gradients flow into it, and eval() fixes
        # dropout/batchnorm to inference behavior.
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def distillation_loss(self, student_output, teacher_output, labels):
        """Combined soft (KL) + hard (cross-entropy) distillation loss.

        Args:
            student_output: student logits, shape (batch, num_classes).
            teacher_output: teacher logits, same shape as student_output.
            labels: ground-truth class indices, shape (batch,).

        Returns:
            Scalar loss tensor:
            alpha * hard_loss + (1 - alpha) * soft_loss.
        """
        T = self.temperature
        # kl_div expects log-probabilities as input and probabilities as
        # target; 'batchmean' is the mathematically correct reduction.
        # Functional API avoids re-constructing a loss module every batch.
        soft_loss = torch.nn.functional.kl_div(
            torch.nn.functional.log_softmax(student_output / T, dim=1),
            torch.nn.functional.softmax(teacher_output / T, dim=1),
            reduction='batchmean',
        ) * (T ** 2)
        # Standard cross-entropy against the ground-truth labels.
        hard_loss = torch.nn.functional.cross_entropy(student_output, labels)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss

    def train(self, train_loader, epochs=50, lr=1e-4):
        """Distill the teacher into the student.

        Args:
            train_loader: iterable yielding (images, labels) batches.
            epochs: number of full passes over train_loader.
            lr: Adam learning rate (previously hard-coded; default unchanged).

        Returns:
            The trained student model.
        """
        optimizer = torch.optim.Adam(self.student.parameters(), lr=lr)
        # Fix: put the student in training mode so dropout/batchnorm layers
        # (if any) behave correctly during distillation.
        self.student.train()
        for epoch in range(epochs):
            for images, labels in train_loader:
                # Teacher is frozen — skip autograd graph construction.
                with torch.no_grad():
                    teacher_output = self.teacher(images)
                student_output = self.student(images)
                loss = self.distillation_loss(student_output, teacher_output, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        return self.student
# --- Distillation entry point -------------------------------------------
# Load the pretrained (large) teacher and build the compact student.
teacher = load_large_model()
student = create_small_model()

# Distill the teacher's knowledge into the student over the training set.
distiller = KnowledgeDistillation(teacher, student)
distilled_student = distiller.train(train_loader)