> 文档中心 > 知识蒸馏DEiT算法实战:使用RegNet蒸馏DEiT模型

知识蒸馏DEiT算法实战:使用RegNet蒸馏DEiT模型


摘要

论文翻译:【第58篇】DEiT:通过注意力训练数据高效的图像transformer &蒸馏
DEiT通过引入一个蒸馏token实现蒸馏,蒸馏的方式有两种:

  • 1、将蒸馏token作为Teacher标签。两个token通过注意力在transformer中相互作用。实现蒸馏。用法参考:
    DEiT实战:使用DEiT实现图像分类任务(一)
  • 2、通过卷积神经网络去蒸馏蒸馏token,让transformer从卷积神经网络学习一些卷积特征,比如归纳偏置这样的特征。这一点作者也是表示疑问。

这篇文章就是从第二点入手,使用卷积神经网络蒸馏DEiT。
讲解视频:https://www.zhihu.com/zvideo/1588881049425276928

最终结论

先把结论说了吧! Teacher网络使用RegNet的regnetx_160网络,Student网络使用DEiT的deit_tiny_distilled_patch16_224模型。如下表

网络 epochs ACC
DEiT 100 94%
RegNet 100 96%
DEiT+Hard 100 95%

知识蒸馏DEiT算法实战:使用RegNet蒸馏DEiT模型

项目结构

DeiT_dist_demo
├─data
│  ├─train
│  │  ├─Black-grass
│  │  ├─Charlock
│  │  ├─Cleavers
│  │  ├─Common Chickweed
│  │  ├─Common wheat
│  │  ├─Fat Hen
│  │  ├─Loose Silky-bent
│  │  ├─Maize
│  │  ├─Scentless Mayweed
│  │  ├─Shepherds Purse
│  │  ├─Small-flowered Cranesbill
│  │  └─Sugar beet
│  └─val
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet
├─models
│  └─models.py
├─losses.py
├─teacher_train.py
├─student_train.py
├─train_kd.py
└─test.py

data:数据集,分为train和val。
models:存放模型文件。
losses.py:loss文件,计算外部蒸馏loss。
teacher_train.py:训练Teacher模型
student_train.py:训练Student模型
train_kd.py:训练蒸馏模型
test.py:测试结果。

模型和loss

模型脚本models.py和loss脚本losses.py需要从官方仓库获取,链接:https://github.com/facebookresearch/deit。

models.py代码

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""DeiT model definitions: plain and distilled vision transformers.

Every public factory is registered with timm through ``register_model`` and
builds a ``VisionTransformer`` (or ``DistilledVisionTransformer``) with a
fixed architecture, optionally loading the matching pretrained checkpoint.
"""
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    """Vision transformer with an extra distillation token and a second head.

    On top of the parent model it adds ``dist_token`` (a learnable token placed
    right after the class token), a ``pos_embed`` with two extra positions and
    ``head_dist``, a classifier applied to the distillation token output.

    NOTE(review): relies on the parent exposing ``embed_dim``, ``num_classes``,
    ``patch_embed``, ``cls_token``, ``pos_drop``, ``blocks``, ``norm``, ``head``
    and ``_init_weights`` -- confirm against the installed timm version.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        # +2 positions: one for the class token, one for the distillation token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        """Return the normalised class-token and distillation-token features."""
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        # Sequence layout: [class token, distillation token, patch tokens...]
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        """Classify ``x``.

        Returns:
            In training mode a tuple ``(head output, head_dist output)``;
            otherwise the element-wise average of the two outputs.
        """
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2


def _build_deit(model_cls, checkpoint_url, pretrained, arch, kwargs):
    """Instantiate ``model_cls`` with the shared DeiT settings.

    Args:
        model_cls: ``VisionTransformer`` or ``DistilledVisionTransformer``.
        checkpoint_url: URL of the checkpoint loaded when ``pretrained`` is true.
        pretrained: Whether to download and load the checkpoint.
        arch: Architecture keyword arguments fixed by the calling factory.
        kwargs: Extra keyword arguments forwarded from the factory's caller.
            A key that is also in ``arch`` raises ``TypeError``, exactly as a
            duplicated keyword did in the original per-factory code.

    Returns:
        The constructed model with ``default_cfg`` set to ``_cfg()``.
    """
    model = model_cls(
        **arch, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url=checkpoint_url,
            map_location="cpu", check_hash=True
        )
        # The downloaded file stores the weights under the "model" key.
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """DeiT-tiny: patch 16, embed_dim 192, depth 12, 3 heads."""
    return _build_deit(
        VisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
        pretrained, dict(patch_size=16, embed_dim=192, depth=12, num_heads=3), kwargs)


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    """DeiT-small: patch 16, embed_dim 384, depth 12, 6 heads."""
    return _build_deit(
        VisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
        pretrained, dict(patch_size=16, embed_dim=384, depth=12, num_heads=6), kwargs)


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    """DeiT-base: patch 16, embed_dim 768, depth 12, 12 heads."""
    return _build_deit(
        VisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
        pretrained, dict(patch_size=16, embed_dim=768, depth=12, num_heads=12), kwargs)


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """Distilled DeiT-tiny: patch 16, embed_dim 192, depth 12, 3 heads."""
    # The stray debug print of ``default_cfg`` that only this factory had
    # was removed so that all factories behave alike.
    return _build_deit(
        DistilledVisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
        pretrained, dict(patch_size=16, embed_dim=192, depth=12, num_heads=3), kwargs)


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """Distilled DeiT-small: patch 16, embed_dim 384, depth 12, 6 heads."""
    return _build_deit(
        DistilledVisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
        pretrained, dict(patch_size=16, embed_dim=384, depth=12, num_heads=6), kwargs)


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """Distilled DeiT-base: patch 16, embed_dim 768, depth 12, 12 heads."""
    return _build_deit(
        DistilledVisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
        pretrained, dict(patch_size=16, embed_dim=768, depth=12, num_heads=12), kwargs)


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """DeiT-base at image size 384: patch 16, embed_dim 768, depth 12, 12 heads."""
    return _build_deit(
        VisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
        pretrained,
        dict(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12), kwargs)


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """Distilled DeiT-base at image size 384: patch 16, embed_dim 768, depth 12, 12 heads."""
    return _build_deit(
        DistilledVisionTransformer,
        "https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
        pretrained,
        dict(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12), kwargs)

losses.py代码

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""Knowledge distillation loss built on top of an ordinary criterion."""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """Combine a base criterion with a teacher-driven distillation term.

    The final loss is ``base * (1 - alpha) + distillation * alpha`` where the
    distillation term compares the student's distillation output with the
    prediction of ``teacher_model`` on the same inputs.
    """

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        """Store the criterion, the teacher and the distillation settings.

        Args:
            base_criterion: Loss applied to the student's main output and the labels.
            teacher_model: Model queried (without gradients) for the targets.
            distillation_type: One of ``'none'``, ``'soft'`` or ``'hard'``.
            alpha: Weight of the distillation term in the final loss.
            tau: Temperature used by the ``'soft'`` variant.
        """
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """Compute the (optionally distilled) training loss.

        Args:
            inputs: The original inputs, fed to the teacher model.
            outputs: Output of the model being trained: either a Tensor, or a
                pair ``(main output, distillation output)``.
            labels: Targets for the base criterion.

        Returns:
            The base loss when ``distillation_type`` is ``'none'``; otherwise
            the alpha-weighted mix of base loss and distillation loss.

        Raises:
            ValueError: If distillation is enabled but ``outputs`` is a single
                Tensor.
        """
        if isinstance(outputs, torch.Tensor):
            student_logits, student_kd_logits = outputs, None
        else:
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            student_logits, student_kd_logits = outputs

        supervised_loss = self.base_criterion(student_logits, labels)
        if self.distillation_type == 'none':
            return supervised_loss

        if student_kd_logits is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")

        # don't backprop through the teacher
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            temperature = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            student_log_probs = F.log_softmax(student_kd_logits / temperature, dim=1)
            # The teacher's targets are given as log-probabilities because
            # log_target=True is used (as recommended in pytorch
            # https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719);
            # plain probabilities with log_target=False are possible too.
            teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=1)
            summed_kl = F.kl_div(
                student_log_probs,
                teacher_log_probs,
                reduction='sum',
                log_target=True
            )
            # Dividing by numel() reproduces the legacy PyTorch behavior;
            # size(0) was also tried, see
            # https://github.com/facebookresearch/deit/issues/61
            kd_loss = summed_kl * (temperature * temperature) / student_kd_logits.numel()
        elif self.distillation_type == 'hard':
            # The teacher's arg-max class is used as a hard label.
            kd_loss = F.cross_entropy(student_kd_logits, teacher_logits.argmax(dim=1))

        return supervised_loss * (1 - self.alpha) + kd_loss * self.alpha

训练Teacher模型

Teacher选用regnetx_160,这个模型的预训练模型比较大,如果不能直接下载下来,可以借助下载工具,比如某雷下载。

步骤

新建teacher_train.py,插入代码:

导入需要的库

import torch.optim as optimimport torchimport torch.nn as nnimport torch.nn.parallelimport torch.utils.dataimport torch.utils.data.distributedimport torchvision.transforms as transformsfrom torchvision import datasetsfrom torch.autograd import Variablefrom timm.models import regnetx_160import jsonimport os# 定义训练过程

定义训练和验证函数

# Seed the random number generators
def seed_everything(seed=42):
    """Seed torch's CPU and CUDA generators and make cuDNN deterministic.

    Args:
        seed: Integer seed; also exported as ``PYTHONHASHSEED``.
    """
    # Fixed typo: the variable is PYTHONHASHSEED (was 'PYHTONHASHSEED').
    # Python reads it at interpreter start-up, so the export only affects
    # interpreters launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training routine
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and print the running and the average loss.

    The loss function is the module-level ``criterion`` that the ``__main__``
    section defines before calling this function.

    Args:
        model: Network to optimise; switched to training mode here.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        # Tensors are moved directly; the deprecated Variable wrapper is a no-op.
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), total_num,
                100. * (batch_idx + 1) / len(train_loader), print_loss))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Best validation accuracy seen so far; updated by val().
Best_ACC = 0


# Validation routine
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate ``model`` and save it when it beats the best accuracy so far.

    Reads the module-level ``criterion`` and ``file_dir`` defined in the
    ``__main__`` section and updates the global ``Best_ACC``.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        device: Device the batches are moved to.
        test_loader: Loader yielding ``(data, target)`` batches.

    Returns:
        Accuracy over ``test_loader.dataset`` as a float in ``[0, 1]``.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    # Kept as a Python int so that an empty loader does not fail on ``int.data``.
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    # The decorator already disables gradients; no nested no_grad() is needed.
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        pred = out.argmax(dim=1)
        correct += (pred == target).sum().item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    if acc > Best_ACC:
        # The whole model object is pickled, not just its state_dict.
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

定义全局参数

if __name__ == '__main__':
    # Folder that receives the saved models. exist_ok=True covers both the
    # "already exists" and the "missing" case, so no separate exists() check
    # (and no race between the check and the creation) is needed.
    file_dir = 'TeacherModel'
    os.makedirs(file_dir, exist_ok=True)
    # Global settings
    modellr = 1e-4  # learning rate handed to the optimizer
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)

图像预处理与增强

 # 数据预处理7    transform = transforms.Compose([ transforms.RandomRotation(10), transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)), transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])    ])    transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])    ])

读取数据

使用pytorch默认读取数据的方式。

   # 读取数据    dataset_train = datasets.ImageFolder('data/train', transform=transform)    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)    with open('class.txt', 'w') as file: file.write(str(dataset_train.class_to_idx))    with open('class.json', 'w', encoding='utf-8') as file: file.write(json.dumps(dataset_train.class_to_idx))    # 导入数据    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型和Loss

    # Build the model and move it to the target device
    criterion = nn.CrossEntropyLoss()
    model_ft = regnetx_160(pretrained=True)
    # The classifier size follows the dataset instead of a hard-coded 12.
    model_ft.reset_classifier(num_classes=len(dataset_train.classes))
    model_ft.to(DEVICE)
    # Plain Adam optimizer with a low learning rate
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    # NOTE(review): T_max=20 is shorter than EPOCHS, so the learning rate
    # rises again after epoch 20 -- confirm this cycling is intended.
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # Training
    val_acc_list = {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc = val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch] = acc
        # The file is rewritten after every epoch with the history so far.
        with open('result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    # Save the whole model object as it stands after the last epoch.
    torch.save(model_ft, 'TeacherModel/model_final.pth')

完成上面的代码就可以开始训练Teacher网络了。

学生网络

学生网络选用deit_tiny_distilled_patch16_224,是一个比较小一点的网络了,模型的大小有20M。训练100个epoch。

步骤

新建student_train.py,插入代码:

导入需要的库

import torch.optim as optimimport torchimport torch.nn as nnimport torch.nn.parallelimport torch.utils.dataimport torch.utils.data.distributedimport torchvision.transforms as transformsfrom torchvision import datasetsfrom torch.autograd import Variablefrom models.models import deit_tiny_distilled_patch16_224import jsonimport os

定义训练和验证函数

# Seed the random number generators
def seed_everything(seed=42):
    """Seed torch's CPU and CUDA generators and make cuDNN deterministic.

    Args:
        seed: Integer seed; also exported as ``PYTHONHASHSEED``.
    """
    # Fixed typo: the variable is PYTHONHASHSEED (was 'PYHTONHASHSEED').
    # Python reads it at interpreter start-up, so the export only affects
    # interpreters launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training routine
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch of the student without a teacher.

    The loss function is the module-level ``criterion`` defined in the
    ``__main__`` section. In training mode the model returns a pair and only
    its first element is used for the loss.

    NOTE(review): the second element (the ``head_dist`` output) never enters
    the loss here, yet the model's eval-mode forward averages both heads --
    confirm that this baseline evaluation is intended.

    Args:
        model: Network to optimise; switched to training mode here.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        # Tensors are moved directly; the deprecated Variable wrapper is a no-op.
        data, target = data.to(device), target.to(device)
        out = model(data)[0]
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), total_num,
                100. * (batch_idx + 1) / len(train_loader), print_loss))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Best validation accuracy seen so far; updated by val().
Best_ACC = 0


# Validation routine
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate ``model`` and save it when it beats the best accuracy so far.

    Reads the module-level ``criterion`` and ``file_dir`` defined in the
    ``__main__`` section and updates the global ``Best_ACC``. In evaluation
    mode the model returns a single tensor, so no indexing is needed.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        device: Device the batches are moved to.
        test_loader: Loader yielding ``(data, target)`` batches.

    Returns:
        Accuracy over ``test_loader.dataset`` as a float in ``[0, 1]``.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    # Kept as a Python int so that an empty loader does not fail on ``int.data``.
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    # The decorator already disables gradients; no nested no_grad() is needed.
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        pred = out.argmax(dim=1)
        correct += (pred == target).sum().item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    if acc > Best_ACC:
        # The whole model object is pickled, not just its state_dict.
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

这里要注意一点,由于我们使用的官方的模型,在做正常的训练时,返回值有两个,分别是x和x_dist。
知识蒸馏DEiT算法实战:使用RegNet蒸馏DEiT模型
loss计算只需前一个值,即:

 out = model(data)[0]

在验证的时候,返回一个值。所以不用做上面的操作了,即:

 out = model(data)

定义全局参数

if __name__ == '__main__':
    # Folder that receives the saved models. exist_ok=True covers both the
    # "already exists" and the "missing" case, so no separate exists() check
    # (and no race between the check and the creation) is needed.
    file_dir = 'StudentModel'
    os.makedirs(file_dir, exist_ok=True)
    # Global settings
    modellr = 1e-4  # learning rate handed to the optimizer
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)

图像预处理与增强

 # 数据预处理7    transform = transforms.Compose([ transforms.RandomRotation(10), transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)), transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])    ])    transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])    ])

读取数据

使用pytorch默认读取数据的方式。

    # Read the datasets from the one-folder-per-class layout
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    # Persist the class-name -> index mapping, once as a Python repr and once as JSON
    class_mapping = dataset_train.class_to_idx
    with open('class.txt', 'w') as file:
        file.write(str(class_mapping))
    with open('class.json', 'w', encoding='utf-8') as file:
        json.dump(class_mapping, file)
    # Wrap the datasets in loaders; only the training data is shuffled
    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型和Loss

 # 实例化模型并且移动到GPU    criterion = nn.CrossEntropyLoss()    model_ft = deit_tiny_distilled_patch16_224(pretrained=True)    num_ftrs = model_ft.head.in_features    model_ft.head = nn.Linear(num_ftrs, 12)    num_ftrs_dist = model_ft.head_dist.in_features    model_ft.head_dist = nn.Linear(num_ftrs_dist, 12)    print(model_ft)    model_ft.to(DEVICE)    # 选择简单暴力的Adam优化器,学习率调低    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)    # 训练    val_acc_list= {}    for epoch in range(1, EPOCHS + 1): train(model_ft, DEVICE, train_loader, optimizer, epoch) cosine_schedule.step() acc=val(model_ft, DEVICE, test_loader) val_acc_list[epoch]=acc with open('result_student.json', 'w', encoding='utf-8') as file:     file.write(json.dumps(val_acc_list))    torch.save(model_ft, 'StudentModel/model_final.pth')

完成上面的代码就可以开始训练Student网络了。

蒸馏学生网络

学生网络继续选用deit_tiny_distilled_patch16_224,使用Teacher网络蒸馏学生网络,训练100个epoch。

步骤

新建train_kd.py,插入代码:

导入需要的库

import torch.optim as optimimport torchimport torch.nn as nnimport torch.nn.parallelimport torch.utils.dataimport torch.utils.data.distributedimport torchvision.transforms as transformsfrom timm.loss import LabelSmoothingCrossEntropyfrom torchvision import datasetsfrom models.models import deit_tiny_distilled_patch16_224import jsonimport osfrom losses import DistillationLoss

定义训练和验证函数

# Seed the random number generators
def seed_everything(seed=42):
    """Seed torch's CPU and CUDA generators and make cuDNN deterministic.

    Args:
        seed: Integer seed; also exported as ``PYTHONHASHSEED``.
    """
    # Fixed typo: the variable is PYTHONHASHSEED (was 'PYHTONHASHSEED').
    # Python reads it at interpreter start-up, so the export only affects
    # interpreters launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training routine
def train(s_net, t_net, device, criterionKD, train_loader, optimizer, epoch):
    """Run one distillation epoch of the student network.

    Args:
        s_net: Student network; switched to training mode here.
        t_net: Teacher network. Not used inside this function (the teacher is
            called by ``criterionKD``); kept so existing callers keep working.
        device: Device the batches are moved to.
        criterionKD: Loss called as ``criterionKD(data, student_output, target)``.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepping the student parameters.
        epoch: Epoch number, used only in the printed messages.
    """
    s_net.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out_s = s_net(data)
        # The raw inputs are passed along so the loss can query the teacher.
        loss = criterionKD(data, out_s, target)
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), total_num,
                100. * (batch_idx + 1) / len(train_loader), print_loss))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Best validation accuracy seen so far; updated by val().
Best_ACC = 0


# Validation routine
@torch.no_grad()
def val(model, device, criterionCls, test_loader):
    """Evaluate ``model`` and save it when it beats the best accuracy so far.

    Reads the module-level ``file_dir`` defined in the ``__main__`` section
    and updates the global ``Best_ACC``.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        device: Device the batches are moved to.
        criterionCls: Loss applied to the model output and the targets.
        test_loader: Loader yielding ``(data, target)`` batches.

    Returns:
        Accuracy over ``test_loader.dataset`` as a float in ``[0, 1]``.
    """
    global Best_ACC
    model.eval()
    test_loss = 0
    # Kept as a Python int so that an empty loader does not fail on ``int.data``.
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    # The decorator already disables gradients; no nested no_grad() is needed.
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out_s = model(data)
        loss = criterionCls(out_s, target)
        pred = out_s.argmax(dim=1)
        correct += (pred == target).sum().item()
        test_loss += loss.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    if acc > Best_ACC:
        # The whole model object is pickled, not just its state_dict.
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, total_num, 100 * acc))
    return acc

定义全局参数

if __name__ == '__main__':
    # Folder that receives the saved models. exist_ok=True covers both the
    # "already exists" and the "missing" case, so no separate exists() check
    # (and no race between the check and the creation) is needed.
    file_dir = 'KDModel'
    os.makedirs(file_dir, exist_ok=True)
    # Global settings
    modellr = 1e-4  # learning rate handed to the optimizer
    BATCH_SIZE = 4
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
    # Settings forwarded to DistillationLoss
    distillation_type = 'hard'  # one of 'none', 'soft', 'hard'
    distillation_alpha = 0.5  # weight of the distillation term in the loss
    distillation_tau = 1.0  # temperature; only read by the 'soft' variant

distillation_type:蒸馏的类型,本文选用hard。
distillation_alpha:α系数,蒸馏loss的权重系数。
distillation_tau:T,蒸馏温度的意思。

图像预处理与增强

 # 数据预处理7    transform = transforms.Compose([ transforms.RandomRotation(10), transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)), transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])    ])    transform_test = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])    ])

读取数据

使用pytorch默认读取数据的方式。

    # Read the datasets from the one-folder-per-class layout
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    # Persist the class-name -> index mapping, once as a Python repr and once as JSON
    class_mapping = dataset_train.class_to_idx
    with open('class.txt', 'w') as file:
        file.write(str(class_mapping))
    with open('class.json', 'w', encoding='utf-8') as file:
        json.dump(class_mapping, file)
    # Wrap the datasets in loaders; only the training data is shuffled
    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型和Loss

 model_ft = deit_tiny_distilled_patch16_224(pretrained=True)    num_ftrs = model_ft.head.in_features    model_ft.head = nn.Linear(num_ftrs, 12)    num_ftrs_dist = model_ft.head_dist.in_features    model_ft.head_dist = nn.Linear(num_ftrs_dist, 12)    print(model_ft)    model_ft.to(DEVICE)    # 选择简单暴力的Adam优化器,学习率调低    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)    teacher_model=torch.load('TeacherModel/best.pth')    teacher_model.eval()    # 实例化模型并且移动到GPU    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)    criterionKD = DistillationLoss( criterion, teacher_model, distillation_type, distillation_alpha, distillation_tau    )    criterionCls = nn.CrossEntropyLoss()    # 训练    val_acc_list= {}    for epoch in range(1, EPOCHS + 1): train(model_ft,teacher_model, DEVICE,criterionKD, train_loader, optimizer, epoch) cosine_schedule.step() acc=val(model_ft,DEVICE,criterionCls , test_loader) val_acc_list[epoch]=acc with open('result_kd.json', 'w', encoding='utf-8') as file:     file.write(json.dumps(val_acc_list))    torch.save(model_ft, 'KDModel/model_final.pth')

完成上面的代码就可以开始蒸馏模式!!!

结果比对

加载保存的结果,然后绘制acc曲线。

import numpy as np
from matplotlib import pyplot as plt
import json

# Accuracy histories written by the three training scripts.
teacher_file = 'result.json'
student_file = 'result_student.json'
student_kd_file = 'result_kd.json'


def read_json(file):
    """Load the JSON document stored at ``file``, print it and return it."""
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
        print(json_data)
    return json_data


def _epochs_and_scores(history):
    """Split an ``{epoch: accuracy}`` mapping into parallel x/y lists."""
    # JSON object keys are strings, so the epochs are converted back to int.
    return [int(epoch) for epoch in history], list(history.values())


teacher_data = read_json(teacher_file)
student_data = read_json(student_file)
student_kd_data = read_json(student_kd_file)

x, teacher_acc = _epochs_and_scores(teacher_data)
print(x)
plt.plot(x, teacher_acc, label='teacher')
# Each curve uses its own epochs, so runs of different length still plot.
# The labels now name this article's method (KD) rather than "IRG", which
# was left over from a different distillation approach.
plt.plot(*_epochs_and_scores(student_data), label='student without KD')
plt.plot(*_epochs_and_scores(student_kd_data), label='student with KD')
plt.title('Test accuracy')
plt.legend()
plt.show()

知识蒸馏DEiT算法实战:使用RegNet蒸馏DEiT模型

总结

本文重点讲解了如何使用外部模型蒸馏算法对DeiT模型进行蒸馏。希望能帮助到大家,如果觉得有用欢迎收藏、点赞和转发;如果有问题也可以留言讨论。
本次实战用到的代码和数据集详见:

https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/87323531

手机爆料