ronghuaiyang · 2021年05月07日

使用PyTorch来进展不平衡数据集的图像分类

首发:AI公园公众号
作者:Marek Paulik
编译:ronghuaiyang

导读

一个非常简单和容易上手的例子。

image.png

对于教程中使用的大多数人工数据集,每个类都有相同数量的数据。然而,在实际应用中,这种情况很少发生。今天,我将给你介绍来自Kaggle的木薯叶分类,并告诉你当类频率有很大差异时该怎么做。

处理类别的不平衡

有两种方法可以解决这个问题。

  • WeightedRandomSampler
  • loss函数中的weight参数

下一步是创建一个有5个方法的CassavaClassifier类:load\_data()、load\_model()、fit\_one\_epoch()、val\_one\_epoch()和fit()。

在load\_data()中,将构造一个train和验证数据集,并返回数据加载器以供进一步使用。

在load\_model()中定义了体系结构、损失函数和优化器。

fit方法包含一些初始化和对fit\_one\_epoch()和val\_one\_epoch()的循环。

早期停止

早期停止类有助于根据验证损失跟踪最佳模型,并保存检查点。

#Callbacks# Early stoppingclass EarlyStopping:  def __init__(self, patience=1, delta=0, path='checkpoint.pt'):    self.patience = patience    self.delta = delta    self.path= path    self.counter = 0    self.best_score = None    self.early_stop = False  def __call__(self, val_loss, model):    if self.best_score is None:      self.best_score = val_loss      self.save_checkpoint(model)    elif val_loss > self.best_score:      self.counter +=1      if self.counter >= self.patience:        self.early_stop = True     else:      self.best_score = val_loss      self.save_checkpoint(model)      self.counter = 0        def save_checkpoint(self, model):    torch.save(model.state_dict(), self.path)

Init

我们首先初始化CassavaClassifier类。

class CassavaClassifier():    def __init__(self, data_dir, num_classes, device, Transform=None, sample=False, loss_weights=False, batch_size=16,     lr=1e-4, stop_early=True, freeze_backbone=True):    #############################################################################################################    # data_dir - directory with images in subfolders, subfolders name are categories    # Transform - data augmentations    # sample - if the dataset is imbalanced set to true and RandomWeightedSampler will be used    # loss_weights - if the dataset is imbalanced set to true and weight parameter will be passed to loss function    # freeze_backbone - if using pretrained architecture freeze all but the classification layer    ###############################################################################################################        self.data_dir = data_dir        self.num_classes = num_classes        self.device = device        self.sample = sample        self.loss_weights = loss_weights        self.batch_size = batch_size        self.lr = lr        self.stop_early = stop_early        self.freeze_backbone = freeze_backbone        self.Transform = Transform

Load Data

训练图像被组织在子文件夹中,子文件夹名称表示图像的类。这是图像分类问题的典型情况,幸运的是,不需要编写自定义数据集类。在这种情况下,可以立即使用torchvision中的ImageFolder。如果你想使用WeightedRandomSampler,你需要为数据集的每个元素指定一个权重。通常,总图像总比上类别数被用作一个权重。

def load_data(self):    train_full = torchvision.datasets.ImageFolder(self.data_dir, transform=self.Transform)    train_set, val_set = random_split(train_full, [math.floor(len(train_full)*0.8), math.ceil(len(train_full)*0.2)])    self.train_classes = [label for _, label in train_set]    if self.sample:        # Need to get weight for every image in the dataset        class_count = Counter(self.train_classes)        class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])         # Can't iterate over class_count because dictionary is unordered        sample_weights = [0] * len(train_set)        for idx, (image, label) in enumerate(train_set):            class_weight = class_weights[label]            sample_weights[idx] = class_weight        sampler = WeightedRandomSampler(weights=sample_weights,                                        num_samples = len(train_set), replacement=True)          train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=sampler)    else:        train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)    val_loader = DataLoader(val_set, batch_size=self.batch_size)    return train_loader, val_loader

Load Model

在该方法中,我使用迁移学习,架构参数从预先训练的resnet50和efficientnet-b7中选择。CrossEntropyLoss和许多其他损失函数都有权重参数。这是一个手动调整参数,用于处理不平衡。在这种情况下,不需要为每个参数定义权重,只需为每个类定义权重。

def load_model(self, arch='resnet'):    ##############################################################################################################    # arch - choose the pretrained architecture from resnet or efficientnetb7    ##############################################################################################################     if arch == 'resnet':        self.model = torchvision.models.resnet50(pretrained=True)        if self.freeze_backbone:            for param in self.model.parameters():                param.requires_grad = False        self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=self.num_classes)    elif arch == 'efficient-net':        self.model = EfficientNet.from_pretrained('efficientnet-b7')        if self.freeze_backbone:            for param in self.model.parameters():                param.requires_grad = False        self.model._fc = nn.Linear(in_features=self.model._fc.in_features, out_features=self.num_classes)        self.model = self.model.to(self.device)    self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)     if self.loss_weights:        class_count = Counter(self.train_classes)        class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])        # Cant iterate over class_count because dictionary is unordered        class_weights = class_weights.to(self.device)          self.criterion = nn.CrossEntropyLoss(class_weights)    else:        self.criterion = nn.CrossEntropyLoss() 

Fit One Epoch

这个方法只包含一个经典的训练循环,带有训练损失记录和tqdm进度条。

def fit_one_epoch(self, train_loader, epoch, num_epochs ):     step_train = 0    train_losses = list() # Every epoch check average loss per batch     train_acc = list()    self.model.train()    for i, (images, targets) in enumerate(tqdm(train_loader)):        images = images.to(self.device)        targets = targets.to(self.device)        logits = self.model(images)        loss = self.criterion(logits, targets)        loss.backward()        self.optimizer.step()        self.optimizer.zero_grad()        train_losses.append(loss.item())        #Calculate running train accuracy        predictions = torch.argmax(logits, dim=1)        num_correct = sum(predictions.eq(targets))        running_train_acc = float(num_correct) / float(images.shape[0])        train_acc.append(running_train_acc)            train_loss = torch.tensor(train_losses).mean()        print(f'Epoch {epoch}/{num_epochs-1}')      print(f'Training loss: {train_loss:.2f}')

Validate one epoch

与上面类似,但此方法在验证数据加载器上迭代。在每一个epoch'之后,平均batch损失和准确性被打印出来。

def val_one_epoch(self, val_loader, scaler):        val_losses = list()        val_accs = list()        self.model.eval()        step_val = 0        with torch.no_grad():            for (images, targets) in val_loader:                images = images.to(self.device)                targets = targets.to(self.device)                logits = self.model(images)                loss = self.criterion(logits, targets)                val_losses.append(loss.item())                                  predictions = torch.argmax(logits, dim=1)                num_correct = sum(predictions.eq(targets))                running_val_acc = float(num_correct) / float(images.shape[0])                val_accs.append(running_val_acc)                        self.val_loss = torch.tensor(val_losses).mean()            val_acc = torch.tensor(val_accs).mean() # Average acc per batch                    print(f'Validation loss: {self.val_loss:.2f}')              print(f'Validation accuracy: {val_acc:.2f}') 

Fit

Fit方法在训练和验证过程中经历了许多阶段和循环。如果预训练模型的参数在开始时被冻结,那么unfreeze\_after定义了整个模型在多少个epoch之后开始训练。在此之前,只训练全连接层(分类器)。

def fit(self, train_loader, val_loader, num_epochs=10, unfreeze_after=5, checkpoint_dir='checkpoint.pt'):    if self.stop_early:        early_stopping = EarlyStopping(        patience=5,         path=checkpoint_dir)      for epoch in range(num_epochs):        if self.freeze_backbone:            if epoch == unfreeze_after:  # Unfreeze after x epochs                for param in self.model.parameters():                    param.requires_grad = True        self.fit_one_epoch(train_loader, scaler, epoch, num_epochs)        self.val_one_epoch(val_loader, scaler)        if self.stop_early:            early_stopping(self.val_loss, self.model)            if early_stopping.early_stop:                print('Early Stopping')                print(f'Best validation loss: {early_stopping.best_score}')                break

Run

现在,可以初始化CassavaClassifier类、创建dataloaders、设置模型并运行整个过程了。

Transform = T.Compose(                    [T.ToTensor(),                    T.Resize((256, 256)),                    T.RandomRotation(90),                    T.RandomHorizontalFlip(p=0.5),                    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')data_dir = "Data/cassava-disease/train/train"classifier = CassavaClassifier(data_dir=data_dir, num_classes=5, device=device, sample=True, Transform=Transform)train_loader, val_loader = classifier.load_data()classifier.load_model()classifier.fit(num_epochs=20, unfreeze_after=5, train_loader=train_loader, val_loader=val_loader)

Inference

使用ImageFolder加载测试数据是不可能的,因为显然没有带有类的子文件夹。因此,我创建了一个返回图像和图像id的自定义数据集。随后,加载模型检查点,通过推理循环运行它,并将预测保存到数据帧中。将数据帧导出为CSV并提交结果。

# Inferencemodel = torchvision.models.resnet50()#model = EfficientNet.from_name('efficientnet-b7')model.fc = nn.Linear(in_features=model.fc.in_features, out_features=5)model = model.to(device)checkpoint = torch.load('Data/cassava-disease/sampler_checkpoint.pt')model.load_state_dict(checkpoint)model.eval()# Dataset for test dataclass Cassava_Test(Dataset):  def __init__(self, dir, transform=None):    self.dir = dir    self.transform = transform    self.images = os.listdir(self.dir)    def __len__(self):    return len(self.images)  def __getitem__(self, idx):    img = Image.open(os.path.join(self.dir, self.images[idx]))    return self.transform(img), self.images[idx] test_dir = 'Data/cassava-disease/test/test/0'test_set = Cassava_Test(test_dir, transform=Transform)test_loader = DataLoader(test_set, batch_size=4)  # Test loopsub = pd.DataFrame(columns=['category', 'id'])id_list = []pred_list = []model = model.to(device)with torch.no_grad():  for (image, image_id) in test_loader:    image = image.to(device)    logits = model(image)    predicted = list(torch.argmax(logits, 1).cpu().numpy())    for id in image_id:      id_list.append(id)      for prediction in predicted:      pred_list.append(prediction)sub['category'] = pred_listsub['id'] = id_listmapping = {0:'cbb', 1:'cbsd', 2:'cgm', 3:'cmd', 4:'healthy'}sub['category'] = sub['category'].map(mapping)sub = sub.sort_values(by='id')sub.to_csv('Cassava_sub.csv', index=False)

如果在方案中包含WeightedRandomSampler或损失权值,则测试集的精度会提高2%。对于仅仅几行代码来说,这是一个很好的改进。对于这个数据集,我没有看到这两种方法在精度上的巨大差异,但WeightedRandomSampler的表现要好一些。

不同的学习速度、优化器和数据扩展肯定有自己的发展空间。然而,对于这种简单的方法来说,86%的准确率似乎足够好了。

—END—

英文原文:https://marekpaulik.medium.co...

推荐阅读

关注图像处理,自然语言处理,机器学习等人工智能领域,请点击关注AI公园专栏
欢迎关注微信公众号
AI公园 公众号二维码.jfif
推荐阅读
关注数
8259
内容数
210
关注图像处理,NLP,机器学习等人工智能领域
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息