写点什么

基于 Pytorch 的细粒度图像分类实战

  • 2019 年 9 月 25 日
  • 本文字数:0 字

    阅读完需:约 1 分钟

基于Pytorch的细粒度图像分类实战


今天讲述基于 pytorch 的细粒度图像分类实战。


1 简介

针对传统的多类别图像分类任务,经典的 CNN 网络已经取得了非常优异的成绩,但在处理细粒度图像数据时,往往无法发挥自身的最大威力。


这是因为细粒度图像间存在更加相似的外观和特征,同时在采集中存在姿态、视角、光照、遮挡、背景干扰等影响,导致数据呈现类间差异性大、类内差异性小的现象,从而使分类更加具有难度。


为了改善经典 CNN 网络在细粒度图像分类中的表现,同时不借助其他标注信息,人们提出了双线性网络(Bilinear CNN)这一非常具有创意的结构,并在细粒度图像分类中取得了相当可观的进步。


本次实战将通过 CUB-200 数据集进行训练,对比经典 CNN 网络结构和双线性网络结构间的差异性。


2 数据集


首先我们回顾一下在多类别图像分类实战中所提出的图像分类任务的五个步骤。其中,在整个任务中最基础的一环就是根据数据集的构成编写相应的读取代码,这也是整个训练的关键所在。


本次实战选择的数据集为 CUB-200 数据集,该数据集是细粒度图像分类领域最经典,也是最常用的一个数据集。共包括 annotations、attributes、attributes-yaml、images、lists 五个文件夹。


此次实战中,我们只利用数据集提供的类别标注信息。因此只需要关注 lists 文件夹下的 train.txt 和 test.txt 文件即可。



通过图片我们可以看到,两个 txt 文件中给出了不同图片的相对路径,而开头数字则代表了对应的标记信息,但是 pytorch 中的标签必须从 0 开始,因此我们只需要借助 strip 和 split 函数即可完成图像和标签信息的获取。


txt 文件路径

path = ‘/media/by/Udata/Datasets/bird/lists/train.txt’


打开图像

txt = open(path,‘r’)


构建存储列表

imgs = []


读取每行信息

for line in txt:


   line = line.strip(’\n’)


   line = line.rstrip()


   # 将每行内容以’.‘为标记划分


   words = line.split(’.’)


   # 添加至列表


   imgs.append((line, int(words[0])-1))


输出结果示例如下图所示:



此时我们只需要将上述模块融合进 pytorch 的数据集读取模块即可,代码如下:


class cub_dataset(Dataset):


   def init(self, transform):


       fh = open(


        ‘/media/by/Udata/Datasets/bird/lists/train.txt’, ‘r’)


       imgs = []


       for line in fh:


           line = line.strip(’\n’)


           line = line.rstrip()


           words = line.split(’.’)


           imgs.append((line, int(words[0])-1))


self.imgs = imgs


       self.transform = transform


def getitem(self, index):


       fn, label = self.imgs[index]


       img = Image.open(


        ‘/media/by/Udata/Datasets/bird/images/’ + fn)


       img = self.transform(img)


return img, label


def len(self):


       return len(self.imgs)


3 网络搭建

本次实战主要选取了经典 Resnet 50 网络结构和基于 Resnet 50 的双线性网络结构。


Resnet 50 作为经典的分类网络,其结构不再赘述,在此详细介绍一下双线性网络的构建。



如上图所示,双线性网络包括两个分支 CNN 结构,这两个分支可以是相同的网络,也可以是不同的网络,本次实战使用 Resnet 50 做为相同的分支网络,以保证对比的客观性。


在此网络下将图像送入两个分支 Resnet 50 之后,把获取到的两个特征分支进行相应的融合操作。


具体代码如下:


class Net(nn.Module):


   def init(self):


       super(Net, self).init()


       self.features = nn.Sequential(resnet50().conv1, 


                                                      resnet50().bn1, 


                                                      resnet50().relu, 


                                                      resnet50().maxpool, 


                                                      resnet50().layer1,


                                                      resnet50().layer2,


                                                      resnet50().layer3,


                                                      resnet50().layer4)


       self.classifiers = nn.Sequential(


           nn.Linear(2048 ** 2, 200))


def forward(self, x):


       x = self.features(x)


       batch_size = x.size(0)


       x = x.view(batch_size, 2048, x.size(2) ** 2)


       x = (torch.bmm(x, 


       torch.transpose(x, 1, 2)) / 28 ** 2).view(batch_size, -1)


       x = torch.nn.functional.normalize(torch.sign(x) * 


              torch.sqrt(torch.abs(x) + 1e-10))


       x = self.classifiers(x)


       return x


4 训练及参数调试

损失函数选择交叉熵损失函数,优化方式选择 SGD 优化。初始学习率设置为 0.01,batch size 设置为 8,衰减率设置为 0.00001,迭代周期为 20,采用 top-5 评价指标


最终的训练结果如下图所示:



Resnet 50 最终取得的准确率约 52%左右,而基于 Resnet 50 的双线性网络取得了近 80%的准确率,由此可见不同的网络在细粒度分类任务上的性能差异非常巨大。


作者介绍


郭冰洋,公众号“有三 AI”作者。该公号聚焦于让大家能够系统性地完成 AI 各个领域所需的专业知识的学习。


原文链接


https://mp.weixin.qq.com/s/5Y4sQlt6DvgkAYtncByjzw


2019 年 9 月 25 日 16:551617

评论 1 条评论

发布
用户头像
你给我看这个 我可就不困了
2021 年 07 月 11 日 14:18
回复
没有更多了
发现更多内容

什么是接口的幂等性,如何实现接口幂等性?,java微服务架构视频下载

Java 程序员 后端

什么?这个岗位薪资秒杀一众程序员?,java技术面试常见问题

Java 程序员 后端

从 0 到 1,带你解剖 MVP 的神秘之处,并自己动手实现 MVP !

Java 程序员 后端

从腾讯T3-3大佬手上获得的Java架构进阶PDF文档,图文并茂,真香

Java 程序员 后端

微服务可观测平台-总体设计

X

微服务 性能 metrics 业务 可观测

下班约会时来了新需求,咋办?

华为云开发者社区

ide 开发 代码 华为云 华为云DevStar

优化技术专题-线程间的高性能消息框架-深入浅出Disruptor的使用和原理

Java 程序员 后端

什么是 MySQL 全局锁、表锁、行锁,Java高级开发岗必问知识点

Java 程序员 后端

从Mybatis源码到Spring动态数据源底层原理分析系列二、Mybatis执行器源码分析

Java 程序员 后端

从腾讯T3-3大佬手上获得的Java架构进阶PDF文档,图文并茂,真香(1)

Java 程序员 后端

作为java程序员,在金三银四季你遇到过哪些质量很高的java面试?

Java 程序员 后端

作为分布式服务框架,我用大白话给你解释Zookeeper的选举机制!

Java 程序员 后端

人性的弱点-读书笔记,java面试题大汇总小山博客

Java 程序员 后端

什么神仙笔记!阿里P9用39实例+1项目讲明白了Spring Cloud家族

Java 程序员 后端

从一次线下读书会获得的收获,linux使用教程

Java 程序员 后端

从小公司跳槽到阿里,靠着刷多套面试题,成功拿到蚂蚁金服P7Offer

Java 程序员 后端

从某度外包逆袭成为阿里架构师,分享我的Java进阶成长笔记

Java 程序员 后端

以GraalVM原生镜像的方式运行Spring Boot应用程序,mybatisjoin原理

Java 程序员 后端

从Java小白到拿到30k offer,分享自己的学习路程,java基础案例教程pdf百度云

Java 程序员 后端

从单体式架构迁移到微服务架构,3年Java开发工程师面试经验分享

Java 程序员 后端

今日头条一面:十道经典面试题解析,Redis如何实现高可扩展

Java 程序员 后端

架构实战营-毕业总结

王晓宇

架构实战营

代码简洁之道--笔记,2021华为Java面试真题

Java 程序员 后端

今年,我在字节跳动面试了九次【已意向书,mongodb入门pdf

Java 程序员 后端

从构建小系统到架构分布式大系统,Spring Boot2的精髓全在这里了

Java 程序员 后端

人到中年的焦虑,Java面试第一问就是做过什么最有难度的项目

Java 程序员 后端

人工智能 - 语音识别的技术原理是什么,Java理论知识思维导图

Java 程序员 后端

优质高效!基于Spring-boot-admin的微服务监控系统实现

Java 程序员 后端

从Mybatis源码到Spring动态数据源底层原理分析系列一、Mybatis初始化源码浅析

Java 程序员 后端

CANN 5.0黑科技解密 | 算力虚拟化,让AI算力“物尽其用”

华为云开发者社区

AI 算力 CANN 昇腾 算力虚拟化

优秀如我毅然和女票分手,面试字节跳动技术四面吊打面试官,终获取到Offer

Java 程序员 后端

基于Pytorch的细粒度图像分类实战_AI_郭冰洋_InfoQ精选文章