AI学习者 · 2021年09月18日

【ReID学习笔记】Auto-ReID:ReID结构搜索首次尝试(附单卡代码下载)

1、摘要

目前流行的用于行人ReID的深度卷积神经网络(deep convolutional neural network, CNNs),通常是建立在ResNet或VGG的BackBone上的,这些BackBone原本是为了分类Classification而设计的。因为ReID任务不同于分类,所以体系结构应该进行相应的修改。我们建议自动搜索一个特别适合ReID任务的CNN架构。有三个方面需要处理。

第一,身体结构信息在ReID中起着重要的作用,但它并不能在BackBone中被Encode编码。

第二,神经架构搜索(Neural Architecture Search, NAS)在不需要人工干预的情况下,实现了架构设计过程的自动化,但是目前还没有一种NAS方法将输入图像的结构信息纳入其中。

第三,ReID本质上是一个检索任务,而目前的NAS算法仅仅是为了分类而设计的。

为了解决这些问题,我们提出了一种基于检索的搜索算法。我们的Auto-ReID使自动化方法能够为ReID找到一个高效、有效的CNN架构。大量的实验表明,搜索的架构在减少50%参数和53%失败的同时,达到了最先进的性能

2、研究现状介绍

最新的ReID模型是基于深度CNNs的。它们通常建立在用于图像分类,如VGG、Inception和ResNet的卷积神经网络的基础上。这些骨架可以很容易地用于检索,因为这两个任务的输入都是图像。但是,在ReID任务和分类任务之间仍然存在一些差异。例如,在图像分类中,两个物体的外观可能不同,例如,猫看起来不同于树。相比之下,ReID任务的所有输入示例都是具有不同属性的人物图像,例如服装或发型。一个专注于识别超过1000个物体的CNN在应用到ReID任务时应该进行修改。
image.png

一种直接的方法是手工设计一个专门针对ReID问题的面向reID的CNN架构。然而,为ReID任务手工设计一个精致的架构可能需要几个月的时间,甚至对于人类专家来说也是如此。这是低效和劳动密集型的。在这篇文章中,我们提出了一种自动的方法来搜索最适合ReID任务的CNN架构。我们的前提是,为分类而设计的CNN主干可能有冗余和缺失的检索组件(ReID任务),例如:

    (1)、较少的池层有利于ReID的准确性;
    (2)、没有明确捕获体结构信息的分类组件;

对于ReID来说,自动化神经架构搜索(NAS)仍然存在3个挑战。

首先,没有现有的NAS方法搜索并保存身体结构信息的CNN架构。身体的结构信息在ReID中起着非常重要的作用,这也是ReID与分类的主要区别。

其次,ReID方法通常以一种依赖于Backboone的方式对身体结构信息进行编码。当采用不同的主干网时,它们需要大量的超参数手动调整。

最后,ReID本质上是一个检索任务,但是大多数NAS算法都是为分类而设计的。由于检索和分类的目标不同,现有的NAS算法并不直接适用于ReID问题。

3、Auto-ReID的贡献

Auto-ReID的贡献如下:

这是第一个为ReID任务搜索神经结构结构的方法,减轻了人工为ReID设计CNN模型的工作。

我们提出了一种新颖的里德搜索空间,将身体结构作为可训练和可操作的CNN组件。提出的ReID搜索空间结合了(1)明确捕获行人身体部位信息的模块和(2)标准NAS搜索空间中使用的典型模块。

为了更好地拟合ReID任务,我们将检索损失集成到可微NAS算法中。根据新的检索目标,采用改进的搜索策略和批量数据抽样方法。

大量实验表明,所搜索的CNN与ReID基线相比具有较强的精度,而该CNN的ReID基线参数不足40%。

通过在ImageNet上对CNN进行初始化的预训练,我们可以在只有一半参数的情况下实现三个ReID基准测试的最新性能。

4、Auto-ReID的相关工作

4.1、网络结构搜索

Auto-ReID是受到最近NAS研究的启发,工作重点是寻找一个高性能的ReID模型,而不是一个分类模型。大多数的NAS方法都是在一个小的代理任务上搜索CNN,然后将搜索得到的CNN结构迁移到另一个大的目标任务上。Zoph等人将强化学习应用于搜索CNN,而搜索成本却超过了GPU数百天。Real等人通过引入年龄属性来修改Tournament Selection Evolutionary Algorithm,使其更倾向于年轻的CNN候选人。Brock等和Bender等人探讨了一次性NAS方法。Liu等人对离散搜索空间进行了松弛,使得CNN的搜索是可微的。Dong等人提出了一种可微抽样方法来改进。得益于参数共享技术,我们抛弃了代理范式,直接在目标ReID数据集上搜索一个鲁棒的CNN。

此外,以往的NAS算法主要针对分类问题。它们是通用的,可以很容易地应用于ReID问题。但是,如果不考虑ReID的语义、遮挡、姿态、部分等具体信息,一般的NAS方法并不能保证搜索到的CNN适合ReID任务。

image.png

本工作中,在一个有效的NAS算法的基础上,采用两种技术来修改它的里德搜索空间问题:

1、修改了目标函数和训练策略以适应reID问题。
2、设计了一个Part-aware模块,并将其整合到标准的NAS搜索空间中;

这样可以找到更好的CNN,推进NAS搜索空间的研究。

5、 Auto-ReID算法

5.1、Preliminaries

大多数NAS方法是将一个Neural Cell的多个拷贝堆叠起来构建一个CNN模型。一个Neural Cell由几种不同的层组成,从以前的Cell中获取输出张量并生成一个新的输出张量。我们遵循之前的NAS方法来搜索Neural Cell的拓扑结构。

具体来说,一个Neural Cell可以看做是一个directed acyclic graph(DAG),假设有B个blocks。

每一个block有如下三个步骤:

1). 将 2 tensors 作为输入;
2). 在输入 tensors 上分别采用 two operations;
3). 将这两个 tensors 进行 sum。

 而所选择的操作,就是从一个候选操作集合 O 上选择。

本文中,采用如下的 operations:

(1) 3×3 max pooling,
(2) 3×3 average pooling,
(3) 3×3 depth-wise separable convolution,
(4) 3×3 dilated convolution,
(5) zero operation (none),
(6) identity mapping.

第c个cell的第i个block可以表达为如下的四元组:
image.png

另外,第c个cell的第i个block的输出是:

image.png
其中O^c_i1和O^c_i2是从O中为第i个块选择的操作。I^c_i1和I^c_i2是从候选输入张量I^c_i中选择的,I^c_i由最后两个神经细胞(I^c-1和I^c-2)的输出张量和当前细胞中前一个块的输出张量组成。

为了搜索上述公式中的操作符 O 以及 I,我们将操作符选择的问题进行松弛,用 softmax 来对每一个 operation 进行打分:

image.png

其中,αα代表对于一个Neural Cell的拓扑结构,称为:Architecture Parameters。我们将所有O中的参数记为w,称为 Operation Parameters,一个典型的可微分的NAS算法,联合的在Training Set上进行w的训练,在Validation Set上进行αα的训练。在训练之后,H和I之间的强度(strength)定义为:
image.png
其中,带有最大强度的H被选择作为I,做大权重的操作符,被选为Oci1Oi1c。上述常规的NAS搜索方法DARTS是被设计用于分类任务的,所以作者考虑将该机制结合到ReID任务中。

5.2、ReID Search Algorithm

作者尝试将The ReID Specific Knowledge结合到搜索算法中。从网络结构的角度来说,作者利用ResNet的macro structure 作为ReID的backbone,每一个Residual Layer被Neural Cell 所替换。而这种Neural Cell的结构,就是搜索出来的。
image.png
上述过程,简述了特征变换的过程。另外一个重要的问题是损失函数的定义,在分类任务中Cross-Entropy Loss Function 当然是首选,那么问题是ReID任务并非简单的分类问题。所以,作者在这里做了些许的改变,引入了Triplet Loss 来进一步改善网络的训练过程。联合的损失函数表示如下:

image.png
其中,交叉熵损失函数Ls和Lt三元组损失函数的定义分别如下所示:

image.png
上述就是损失函数的定义。算法流程如下所示:
image.png

5.3、ReID Search Space with Part-Aware Module

作者设计的part-aware module来改善搜索空间的问题。作者也给出了一个pipeline,来说明该过程:

image.png

设计了一个部分感知模块,并将其与一个公共搜索空间(O)相结合来构造我们的里德搜索空间Oreid:

(1) part-aware module
(2) 3 × 3 max pooling
(3) 3 × 3average pooling
(4) 3 × 3 depth-wise separable convolution
(5) 3 × 3 dilated convolution
(6) zero operation, and
(7) identity mapping

部件感知模块如上图所示。

给定一个输入特征张量F,首先把它垂直分割成M个部分,在图中展示了M=4的例子。在得到零件特征后,对每个零件特征在空间维度上进行平均,并对合并后的特征进行线性变换,从而得到M个局部零件特征向量。

然后,在这些M部分特征向量上应用了一个Attention机制。通过这种方式,可以将全局信息整合到每个部分向量中,以增强其身体结构线索。

再然后,将每个部分向量融合到它原来的空间形状中,并将重复的部分特征垂直地连接到一个体结构增强的特征张量中。

最后,将部分感知张量与原输入特征张量进行融合通道级联的方式,在这个融合张量上应用1x1的卷积层来生成输出张量。设计的部件感知模块可以捕获有用的身体部件线索,并将这些结构信息融入到输入特征中。

此外,参数的大小和数量所提出的part-aware模块的计算类似于3x3深度可分卷积,因此与使用标准的NAS搜索空间相比,不会影响发现的CNN的效率。

这么做的优势是什么呢?作者提到本文所设计的 part-aware module 可以捕获有用的 body part cue,并将这种结构化信息融合到 input feature 中。本文所提出的part-aware module 的参数大小和数量 和3x3的depth-wise separable convlution 是相当的。所以,并不会显著的影响 NAS 的效率。

from __future__ import print_function, division

import os
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable

version = torch.__version__

import time
import utils.distributed as dist
## visualization
from tensorboardX import SummaryWriter
import yaml
import argparse
from utils.configurations import visualize_configurations, transfer_txt
from utils import loggers

from data import tripletsample_dataset as base_dataset
from models import baseline_cls, optimizers, losses

from models.DARTS.archetect import Architect
from models.DARTS.search_cnn import SearchCNNController

import utils.metrics as metrics

try:
    from utils.visualization import plot
except:
    print('\nNo graphic visualization supports...')

try:
    from apex.fp16_utils import *
    from apex import amp, optimizers
    # fp16 = True
except:
    print('\nNo apex supports, using default setting in pytorch {} \n'.format(version))

######## solve multi-thread crash in IDEs #########
import multiprocessing

multiprocessing.set_start_method('spawn', True)

parser = argparse.ArgumentParser(description='Re-Implementation of Darts Based Partial Aware People Re-ID')
parser.add_argument('--config', default='configs/baseline_classification_DARTS_distributed.yaml')
parser.add_argument("--verbose", default=False, help='whether verbose each stage')
parser.add_argument('--port', default=10530, type=int, help='port of server')
parser.add_argument('--distributed', default=False, help='switch to distributed training on slurm')
# parser.add_argument('--world-size', default=1, type=int)
# parser.add_argument('--rank', default=0, type=int)
parser.add_argument('--resume', default=False, help='resume')

parser.add_argument('--fix_gpu_id', default=False, help='for extreme condition, some are not working')
parser.add_argument('--sync_grad_sum', default=True, help='sychronize sum or mean')
parser.add_argument('--fp16', default=False, help='whether use apex quantization')
args = parser.parse_args()

if args.fix_gpu_id == False and torch.cuda.is_available():
    device = torch.device("cuda")
    use_gpu = True
    try:
        if len(args.fix_gpu_id) > 0:
            torch.cuda.set_device(args.fix_gpu_id[0])
    except:
        print('Not fixing GPU ids...')
else:
    device = torch.device("cpu")

#####################################################################
### history for draw graph

y_loss = {}  # loss history
y_loss['train'] = []
y_loss['val'] = []
y_err = {}
y_err['train'] = []
y_err['val'] = []

best_top1 = 0


######################################################################
# Save model
# ---------------------------
def save_network(args, network, epoch_label, top1, isbest=False):
    if isbest:
        save_filename = 'best.pth'
    else:
        save_filename = 'net_%s.pth' % epoch_label
    save_path = os.path.join(args.checkpoint, args.task_name, save_filename)
    if not os.path.isdir(os.path.join(args.checkpoint, args.task_name)):
        os.makedirs(os.path.join(args.checkpoint, args.task_name))
    checkpoint = {}
    checkpoint['network'] = network.cpu().state_dict()
    checkpoint['epoch'] = epoch_label
    checkpoint['top1'] = top1
    torch.save(checkpoint, save_path)


def train(args, train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr_scheduler, epoch=0):
    print('-------------------training_start at epoch {}---------------------'.format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    cur_step = epoch * len(train_loader)

    lr_scheduler.step()
    lr = lr_scheduler.get_lr()[0]

    if args.distributed:
        if rank == 0:
            writer.add_scalar('train/lr', lr, cur_step)
    else:
        writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    running_loss = 0.0
    running_corrects = 0.0
    # step = 0
    model.to(device)

    for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(train_loader, valid_loader)):
        # step = step+1
        now_batch_size, c, h, w = trn_X.shape
        trn_X, trn_y = trn_X.to(device, non_blocking=True), trn_y.to(device, non_blocking=True)
        val_X, val_y = val_X.to(device, non_blocking=True), val_y.to(device, non_blocking=True)

        if args.distributed:
            if now_batch_size < int(args.batch_size // world_size):
                continue
        else:
            if now_batch_size < args.batch_size:  # skip the last batch
                continue

        alpha_optim.zero_grad()
        architect.unrolled_backward(trn_X, trn_y, val_X, val_y, lr, w_optim)
        alpha_optim.step()

        w_optim.zero_grad()
        logits = model(trn_X)
        loss = model.criterion(logits, trn_y)
        loss.backward()

        # gradient clipping\
        if args.w_grad_clip != False:
            nn.utils.clip_grad_norm_(model.weights(), args.w_grad_clip)

        if args.distributed:
            if args.sync_grad_sum:
                dist.sync_grad_sum(model)
            else:
                dist.sync_grad_mean(model)

        w_optim.step()

        if args.distributed:
            dist.sync_bn_stat(model)

        prec1, prec5, prec10 = metrics.accuracy(logits, trn_y, topk=(1, 5, 10))

        if args.distributed:
            dist.simple_sync.allreducemean_list([loss, prec1, prec5, prec10])

        losses.update(loss.item(), now_batch_size)
        top1.update(prec1.item(), now_batch_size)
        top5.update(prec5.item(), now_batch_size)
        top10.update(prec10.item(), now_batch_size)

        # running_loss += loss.item() * now_batch_size

        # y_loss['train'].append(losses)
        # y_err['train'].append(1.0-top1)

        if args.distributed:
            if rank == 0:
                if step % args.print_freq == 0 or step == len(train_loader) - 1:
                    logger.info(
                        "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, args.epochs, step, len(train_loader) - 1, losses=losses,
                            top1=top1, top5=top5))

                writer.add_scalar('train/loss', loss.item(), cur_step)
                writer.add_scalar('train/top1', prec1.item(), cur_step)
                writer.add_scalar('train/top5', prec5.item(), cur_step)
                writer.add_scalar('train/top10', prec10.item(), cur_step)
        else:
            if step % args.print_freq == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, args.epochs, step, len(train_loader) - 1, losses=losses,
                        top1=top1, top5=top5))

            writer.add_scalar('train/loss', loss.item(), cur_step)
            writer.add_scalar('train/top1', prec1.item(), cur_step)
            writer.add_scalar('train/top5', prec5.item(), cur_step)
            writer.add_scalar('train/top10', prec10.item(), cur_step)

        cur_step += 1
    if args.distributed:
        if rank == 0:
            logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, args.epochs, top1.avg))
    else:
        logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, args.epochs, top1.avg))

    if epoch % args.forcesave == 0:
        save_network(args, model, epoch, top1)


def validate(args, valid_loader, model, epoch=0, cur_step=0):
    print('-------------------validation_start at epoch {}---------------------'.format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    model.eval()
    model.to(device)
    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            N = X.size(0)

            ### 必须加分布式判断,否则validation跳过一直为真。
            if args.distributed:
                if N < int(args.batch_size // world_size):
                    continue
            else:
                if N < args.batch_size:  # skip the last batch
                    continue

            logits = model(X)
            loss = model.criterion(logits, y)

            prec1, prec5, prec10 = metrics.accuracy(logits, y, topk=(1, 5, 10))

            if args.distributed:
                dist.simple_sync.allreducemean_list([loss, prec1, prec5, prec10])

            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)
            top10.update(prec10.item(), N)

            if args.distributed:
                if rank == 0:
                    if step % args.print_freq == 0 or step == len(valid_loader) - 1:
                        logger.info(
                            "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                            "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                                epoch + 1, args.epochs, step, len(valid_loader) - 1, losses=losses,
                                top1=top1, top5=top5))

            else:
                if step % args.print_freq == 0 or step == len(valid_loader) - 1:
                    logger.info(
                        "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, args.epochs, step, len(valid_loader) - 1, losses=losses,
                            top1=top1, top5=top5))

    if args.distributed:
        if rank == 0:
            writer.add_scalar('val/loss', losses.avg, cur_step)
            writer.add_scalar('val/top1', top1.avg, cur_step)
            writer.add_scalar('val/top5', top5.avg, cur_step)
            writer.add_scalar('val/top10', top10.avg, cur_step)

            logger.info(
                "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}".format(epoch + 1, args.epochs,
                                                                                              top1.avg, top5.avg,
                                                                                              top10.avg))
    else:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)
        writer.add_scalar('val/top10', top10.avg, cur_step)

        logger.info(
            "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}".format(epoch + 1, args.epochs,
                                                                                          top1.avg, top5.avg,
                                                                                          top10.avg))

    return top1.avg

    # optimizer, scheduler


def main():
    global args, use_gpu, writer, rank, logger, best_top1, world_size, rank
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f)

    #######  visualize configs ######
    visualize_configurations(config)
    #######  set args ######
    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)
    if args.verbose:
        print('Config parsing complete')

    #######  world initial ######
    if args.distributed:
        rank, world_size = dist.dist_init(args.port, 'nccl')
        logger = loggers.get_logger(os.path.join(args.logpath, '{}.distlog'.format(args.task_name)))
        if rank == 0:
            tbpath = os.path.join(args.logpath, 'tb', args.task_name)
            if os.path.isdir(tbpath):
                writer = SummaryWriter(log_dir=tbpath)
            else:
                os.makedirs(tbpath)
                writer = SummaryWriter(log_dir=tbpath)
            writer.add_text('config_infomation', transfer_txt(args))

            logger.info("Logger is set ")
            logger.info("Logger with distribution")
    else:

        tbpath = os.path.join(args.logpath, 'tb', args.task_name)
        if os.path.isdir(tbpath):
            writer = SummaryWriter(log_dir=tbpath)
        else:
            os.makedirs(tbpath)
            writer = SummaryWriter(log_dir=tbpath)
        writer.add_text('config_infomation', transfer_txt(args))
        logger = loggers.get_logger(os.path.join(args.logpath, '{}.log'.format(args.task_name)))
        logger.info("Logger is set ")
        logger.info("Logger without distribution")

    ######## initial random setting #######

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    ######## test data reading ########

    since = time.time()
    dataset_train_val = base_dataset.baseline_dataset(args)
    train_loader, val_loader = dataset_train_val.get_loader()
    logger.info("Initializing dataset used {} basic time unit".format(time.time() - since))

    logger.info("The training classes labels length :  {}".format(len(dataset_train_val.train_classnames)))
    since = time.time()
    inputs, classes = next(iter(train_loader))
    logger.info('batch loading time example is {}'.format(time.time() - since))

    ######### Init model ############
    # woptimizer =  optimizers.get_optimizer(args, model)
    # lr_schedular = optimizers.get_lr_scheduler(args, woptimizer)
    criterion = losses.get_loss(args)

    criterion.to(device)

    if args.model_name == 'Darts_normal':
        model = SearchCNNController(args.input_channels, args.init_channels, len(dataset_train_val.train_classnames),
                                    args.Search_layers, criterion)
    else:
        model = SearchCNNController(args.input_channels, args.init_channels, len(dataset_train_val.train_classnames),
                                    args.Search_layers, criterion)

    model = model.to(device)
    if args.distributed:
        dist.sync_state(model)

    w_optim = torch.optim.SGD(model.weights(), args.w_lr, momentum=args.w_momentum,
                              weight_decay=args.w_weight_decay)

    alpha_optim = torch.optim.Adam(model.alphas(), args.alpha_lr, betas=(0.5, 0.999),
                                   weight_decay=args.alpha_weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, args.epochs, eta_min=args.w_lr_min)
    architect = Architect(model, args.w_momentum, args.w_weight_decay, args)

    ########## lauch training ###########

    if args.resume != '' and os.path.isfile(args.resume):
        if args.distributed:
            if rank == 0:
                print('resuem from [%s]' % config.resume)
            checkpoint = torch.load(
                args.resume,
                map_location='cuda:%d' % torch.cuda.current_device()
            )
        else:
            print('resuem from [%s]' % config.resume)
            checkpoint = torch.load(config.resume, map_location="cpu")

        model.load_state_dict(checkpoint['network'])
        # woptimizer.load_state_dict(checkpoint['optimizer'])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        epoch_offset = checkpoint['epoch']
    else:
        epoch_offset = 0

    model.to(device)

    if args.fp16:
        model, w_optim = amp.initialize(model, w_optim, opt_level="O1")

    for epoch in range(epoch_offset, args.epochs):
        if args.distributed:
            if rank == 0:
                model.print_alphas(logger)
        else:
            model.print_alphas(logger)

        # train
        if epoch % args.real_val_freq == 0:
            train(args, train_loader, val_loader, model, architect, w_optim, alpha_optim, lr_scheduler, epoch=epoch)
        else:
            train(args, train_loader, train_loader, model, architect, w_optim, alpha_optim, lr_scheduler, epoch=epoch)
            # validation
        cur_step = (epoch + 1) * len(train_loader)

        top1 = validate(args, val_loader, model, epoch=epoch, cur_step=cur_step)

        if args.distributed:
            if rank == 0:
                if best_top1 < top1:
                    best_top1 = top1
                    save_network(args, model, epoch, top1, isbest=True)
                else:
                    if epoch % args.forcesave == 0:
                        save_network(args, model, epoch, top1)
                writer.add_scalar('val/best_top1', best_top1, cur_step)

        else:
            if best_top1 < top1:
                best_top1 = top1
                save_network(args, model, epoch, top1, isbest=True)
            else:
                if epoch % args.forcesave == 0:
                    save_network(args, model, epoch, top1)

            writer.add_scalar('val/best_top1', best_top1, cur_step)

        if args.distributed:
            if rank == 0:
                logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
                # logger.info("Best Genotype = {}".format(best_genotype))
        else:
            logger.info("Final best Prec@1 = {:.4%}".format(best_top1))

        genotype = model.genotype()

        if args.distributed:

            if rank == 0:
                logger.info("genotype = {}".format(genotype))

                if args.plot_path != False:

                    plot_path = os.path.join(args.plot_path, args.task_name, "EP{:02d}".format(epoch + 1))
                    if not os.path.isdir(os.path.join(args.plot_path, args.task_name)):
                        os.makedirs(os.path.join(args.plot_path, args.task_name))
                    caption = "Epoch {}".format(epoch + 1)
                    plot(genotype.normal, plot_path + "-normal", caption)
                    plot(genotype.reduce, plot_path + "-reduce", caption)

                    writer.add_image(plot_path + '.png')

        else:
            logger.info("genotype = {}".format(genotype))

            if args.plot_path != False:
                if not os.path.isdir(os.path.join(args.plot_path, args.task_name)):
                    os.makedirs(os.path.join(args.plot_path, args.task_name))
                plot_path = os.path.join(args.plot_path, args.task_name, "EP{:02d}".format(epoch + 1))
                caption = "Epoch {}".format(epoch + 1)
                plot(genotype.normal, plot_path + "-normal", caption)
                plot(genotype.reduce, plot_path + "-reduce", caption)

                writer.add_image(plot_path + '.png')


if __name__ == '__main__':
    main()

6、Experiments

image.png
image.png
image.png

原文:集智书童
作者: ChaucerG

推荐阅读

推荐阅读
关注数
18838
内容数
1372
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息