Linux 之父出席、干货分享、圆桌讨论,精彩尽在 OpenCloudOS 社区开放日,报名戳 了解详情
写点什么

基于 Pytorch 的细粒度图像分类实战

  • 2019 年 9 月 25 日
  • 本文字数:0 字

    阅读完需:约 1 分钟

基于Pytorch的细粒度图像分类实战


今天讲述基于 pytorch 的细粒度图像分类实战。


1 简介

针对传统的多类别图像分类任务,经典的 CNN 网络已经取得了非常优异的成绩,但在处理细粒度图像数据时,往往无法发挥自身的最大威力。


这是因为细粒度图像间存在更加相似的外观和特征,同时在采集中存在姿态、视角、光照、遮挡、背景干扰等影响,导致数据呈现类间差异性大、类内差异性小的现象,从而使分类更加具有难度。


为了改善经典 CNN 网络在细粒度图像分类中的表现,同时不借助其他标注信息,人们提出了双线性网络(Bilinear CNN)这一非常具有创意的结构,并在细粒度图像分类中取得了相当可观的进步。


本次实战将通过 CUB-200 数据集进行训练,对比经典 CNN 网络结构和双线性网络结构间的差异性。


2 数据集


首先我们回顾一下在多类别图像分类实战中所提出的图像分类任务的五个步骤。其中,在整个任务中最基础的一环就是根据数据集的构成编写相应的读取代码,这也是整个训练的关键所在。


本次实战选择的数据集为 CUB-200 数据集,该数据集是细粒度图像分类领域最经典,也是最常用的一个数据集。共包括 annotations、attributes、attributes-yaml、images、lists 五个文件夹。


此次实战中,我们只利用数据集提供的类别标注信息。因此只需要关注 lists 文件夹下的 train.txt 和 test.txt 文件即可。



通过图片我们可以看到,两个 txt 文件中给出了不同图片的相对路径,而开头数字则代表了对应的标记信息,但是 pytorch 中的标签必须从 0 开始,因此我们只需要借助 strip 和 split 函数即可完成图像和标签信息的获取。


txt 文件路径

path = ‘/media/by/Udata/Datasets/bird/lists/train.txt’


打开图像

txt = open(path,‘r’)


构建存储列表

imgs = []


读取每行信息

for line in txt:


   line = line.strip(’\n’)


   line = line.rstrip()


   # 将每行内容以’.‘为标记划分


   words = line.split(’.’)


   # 添加至列表


   imgs.append((line, int(words[0])-1))


输出结果示例如下图所示:



此时我们只需要将上述模块融合进 pytorch 的数据集读取模块即可,代码如下:


class cub_dataset(Dataset):


   def init(self, transform):


       fh = open(


        ‘/media/by/Udata/Datasets/bird/lists/train.txt’, ‘r’)


       imgs = []


       for line in fh:


           line = line.strip(’\n’)


           line = line.rstrip()


           words = line.split(’.’)


           imgs.append((line, int(words[0])-1))


self.imgs = imgs


       self.transform = transform


def getitem(self, index):


       fn, label = self.imgs[index]


       img = Image.open(


        ‘/media/by/Udata/Datasets/bird/images/’ + fn)


       img = self.transform(img)


return img, label


def len(self):


       return len(self.imgs)


3 网络搭建

本次实战主要选取了经典 Resnet 50 网络结构和基于 Resnet 50 的双线性网络结构。


Resnet 50 作为经典的分类网络,其结构不再赘述,在此详细介绍一下双线性网络的构建。



如上图所示,双线性网络包括两个分支 CNN 结构,这两个分支可以是相同的网络,也可以是不同的网络,本次实战使用 Resnet 50 做为相同的分支网络,以保证对比的客观性。


在此网络下将图像送入两个分支 Resnet 50 之后,把获取到的两个特征分支进行相应的融合操作。


具体代码如下:


class Net(nn.Module):


   def init(self):


       super(Net, self).init()


       self.features = nn.Sequential(resnet50().conv1, 


                                                      resnet50().bn1, 


                                                      resnet50().relu, 


                                                      resnet50().maxpool, 


                                                      resnet50().layer1,


                                                      resnet50().layer2,


                                                      resnet50().layer3,


                                                      resnet50().layer4)


       self.classifiers = nn.Sequential(


           nn.Linear(2048 ** 2, 200))


def forward(self, x):


       x = self.features(x)


       batch_size = x.size(0)


       x = x.view(batch_size, 2048, x.size(2) ** 2)


       x = (torch.bmm(x, 


       torch.transpose(x, 1, 2)) / 28 ** 2).view(batch_size, -1)


       x = torch.nn.functional.normalize(torch.sign(x) * 


              torch.sqrt(torch.abs(x) + 1e-10))


       x = self.classifiers(x)


       return x


4 训练及参数调试

损失函数选择交叉熵损失函数,优化方式选择 SGD 优化。初始学习率设置为 0.01,batch size 设置为 8,衰减率设置为 0.00001,迭代周期为 20,采用 top-5 评价指标


最终的训练结果如下图所示:



Resnet 50 最终取得的准确率约 52%左右,而基于 Resnet 50 的双线性网络取得了近 80%的准确率,由此可见不同的网络在细粒度分类任务上的性能差异非常巨大。


作者介绍


郭冰洋,公众号“有三 AI”作者。该公号聚焦于让大家能够系统性地完成 AI 各个领域所需的专业知识的学习。


原文链接


https://mp.weixin.qq.com/s/5Y4sQlt6DvgkAYtncByjzw


2019 年 9 月 25 日 16:551655

评论 1 条评论

发布
用户头像
你给我看这个 我可就不困了
2021 年 07 月 11 日 14:18
回复
没有更多了
发现更多内容

链路层的封装成帧和透明传输基本问题

Regan Yue

计算机网络 10月日更

架构实战营_模块六作业_拆分电商系统为微服务

Rabbit

限时开源!阿里内部爆款的顶配版Spring Security笔记

Java spring 编程 架构 面试

同事跳槽阿里,临走甩给一份上千页的Linux源码笔记,真香

Java 程序员 架构 面试 后端

Java通过socket和DTU,RTU连接工业传感器通信

叫练

socket Modbus协议 java DTU RTU

SpringBoot 实战:JUnit5+MockMvc+Mockito 做好单元测试

看山

Java Spring Boot 10月日更 Effective Spring

《第2章 开始学习C++》

IT蜗壳-Tango

10月日更

千万级学生管理系统的考试试卷存储方案

刘琦Logan

在线最大公因数计算器

入门小站

工具

Prometheus 基本查询(二)时序数据的瞬时向量

耳东@Erdong

Prometheus 10月日更

【Android构建新工具】Bazel构建工具介绍

轻口味

android 构建工具 10月日更

linux之sudo使用技巧汇总

入门小站

Linux

太厉害了,阿里大佬用一篇神文把《数据结构与算法》讲的明明白白

程序员小呆

Java 程序员 架构师

百度智能云布局粤港澳大湾区,打造AI+工业互联网新高地

百度大脑

人工智能 百度

图解分布式之:最终一致性,一致只会迟到,但绝不缺席

普普通通程序员

区块链与智能革命的未来

CECBC

真香!肝完Alibaba这份面试通关宝典,我成功拿下今年第15个Offer

收到请回复

Java 面试 大厂Offer 20+大厂面经

阿里P8高级架构师开发高并发系统经验总结

Java 程序员 架构 面试 后端

ThreadPoolExecutor学习笔记

六维

ThreadPoolExecutor 10月日更

为何实现碳中和已刻不容缓?

CECBC

阿里内部教程:千页Redis源码笔记,涨薪必备

Java 程序员 架构 面试 后端

架构实战营 - 模块五作业

Alex.Wu

汽车的新能源之变,不仅在一块电池

脑极体

Mock Service Worker:可用于浏览器的Mock服务

devpoint

Vue Mock 10月日更 msw

CSS架构之Acss层

Augus

CSS 10月日更

Leetcode 题目解析:287. 寻找重复数

程序员架构进阶

算法 LeetCode 10月日更

【Vuex 源码学习】第十三篇 - Vuex 辅助函数的实现

Brave

源码 vuex 10月日更

生命中不重要的九件事情

石云升

10月日更

深入理解Java虚拟机之JVM内存布局篇

普普通通程序员

Go 中 Nil 理论上有类型,实践中无类型

baiyutang

golang 10月日更

绿色电力交易是一场迫在眉睫,区块链记录每一笔绿色电力交易

CECBC

GPU容器虚拟化:用户态和内核态的技术和实践详解

GPU容器虚拟化:用户态和内核态的技术和实践详解

基于Pytorch的细粒度图像分类实战_AI_郭冰洋_InfoQ精选文章