基于 Pytorch 的细粒度图像分类实战

阅读数:321 2019 年 9 月 25 日 16:55

基于Pytorch的细粒度图像分类实战

基于Pytorch的细粒度图像分类实战

今天讲述基于 pytorch 的细粒度图像分类实战。

1 简介

针对传统的多类别图像分类任务,经典的 CNN 网络已经取得了非常优异的成绩,但在处理细粒度图像数据时,往往无法发挥自身的最大威力。

这是因为细粒度图像间存在更加相似的外观和特征,同时在采集中存在姿态、视角、光照、遮挡、背景干扰等影响,导致数据呈现类间差异性大、类内差异性小的现象,从而使分类更加具有难度。

为了改善经典 CNN 网络在细粒度图像分类中的表现,同时不借助其他标注信息,人们提出了双线性网络(Bilinear CNN)这一非常具有创意的结构,并在细粒度图像分类中取得了相当可观的进步。

本次实战将通过 CUB-200 数据集进行训练,对比经典 CNN 网络结构和双线性网络结构间的差异性。

2 数据集

基于Pytorch的细粒度图像分类实战
首先我们回顾一下在多类别图像分类实战中所提出的图像分类任务的五个步骤。其中,在整个任务中最基础的一环就是根据数据集的构成编写相应的读取代码,这也是整个训练的关键所在。

本次实战选择的数据集为 CUB-200 数据集,该数据集是细粒度图像分类领域最经典,也是最常用的一个数据集。共包括 annotations、attributes、attributes-yaml、images、lists 五个文件夹。

此次实战中,我们只利用数据集提供的类别标注信息。因此只需要关注 lists 文件夹下的 train.txt 和 test.txt 文件即可。
基于Pytorch的细粒度图像分类实战

通过图片我们可以看到,两个 txt 文件中给出了不同图片的相对路径,而开头数字则代表了对应的标记信息,但是 pytorch 中的标签必须从 0 开始,因此我们只需要借助 strip 和 split 函数即可完成图像和标签信息的获取。

txt 文件路径

path = ‘/media/by/Udata/Datasets/bird/lists/train.txt’

打开图像

txt = open(path,‘r’)

构建存储列表

imgs = []

读取每行信息

for line in txt:
   line = line.strip(’\n’)
   line = line.rstrip()
   # 将每行内容以’.‘为标记划分
   words = line.split(’.’)
   # 添加至列表
   imgs.append((line, int(words[0])-1))

输出结果示例如下图所示:

基于Pytorch的细粒度图像分类实战

此时我们只需要将上述模块融合进 pytorch 的数据集读取模块即可,代码如下:

class cub_dataset(Dataset):
   def init(self, transform):
       fh = open(
        ‘/media/by/Udata/Datasets/bird/lists/train.txt’, ‘r’)
       imgs = []
       for line in fh:
           line = line.strip(’\n’)
           line = line.rstrip()
           words = line.split(’.’)
           imgs.append((line, int(words[0])-1))

self.imgs = imgs
       self.transform = transform

def getitem(self, index):
       fn, label = self.imgs[index]
       img = Image.open(
        ‘/media/by/Udata/Datasets/bird/images/’ + fn)
       img = self.transform(img)

return img, label

def len(self):
       return len(self.imgs)

3 网络搭建

本次实战主要选取了经典 Resnet 50 网络结构和基于 Resnet 50 的双线性网络结构。

Resnet 50 作为经典的分类网络,其结构不再赘述,在此详细介绍一下双线性网络的构建。
基于Pytorch的细粒度图像分类实战

如上图所示,双线性网络包括两个分支 CNN 结构,这两个分支可以是相同的网络,也可以是不同的网络,本次实战使用 Resnet 50 做为相同的分支网络,以保证对比的客观性。

在此网络下将图像送入两个分支 Resnet 50 之后,把获取到的两个特征分支进行相应的融合操作。

具体代码如下:

class Net(nn.Module):
   def init(self):
       super(Net, self).init()
       self.features = nn.Sequential(resnet50().conv1, 
                                                      resnet50().bn1, 
                                                      resnet50().relu, 
                                                      resnet50().maxpool, 
                                                      resnet50().layer1,
                                                      resnet50().layer2,
                                                      resnet50().layer3,
                                                      resnet50().layer4)
       self.classifiers = nn.Sequential(
           nn.Linear(2048 ** 2, 200))

def forward(self, x):
       x = self.features(x)
       batch_size = x.size(0)
       x = x.view(batch_size, 2048, x.size(2) ** 2)
       x = (torch.bmm(x, 
       torch.transpose(x, 1, 2)) / 28 ** 2).view(batch_size, -1)
       x = torch.nn.functional.normalize(torch.sign(x) * 
              torch.sqrt(torch.abs(x) + 1e-10))
       x = self.classifiers(x)
       return x

4 训练及参数调试

损失函数选择交叉熵损失函数,优化方式选择 SGD 优化。初始学习率设置为 0.01,batch size 设置为 8,衰减率设置为 0.00001,迭代周期为 20,采用 top-5 评价指标

最终的训练结果如下图所示:
基于Pytorch的细粒度图像分类实战

Resnet 50 最终取得的准确率约 52% 左右,而基于 Resnet 50 的双线性网络取得了近 80% 的准确率,由此可见不同的网络在细粒度分类任务上的性能差异非常巨大。

作者介绍

郭冰洋,公众号“有三 AI”作者。该公号聚焦于让大家能够系统性地完成 AI 各个领域所需的专业知识的学习。

原文链接

https://mp.weixin.qq.com/s/5Y4sQlt6DvgkAYtncByjzw

评论

发布