基于Pytorch的多类别图像分类实战

2019 年 10 月 15 日

基于Pytorch的多类别图像分类实战


本篇基于Pytorch完成一个多类别图像分类实战。


1 简介



实现一个完整的图像分类任务,大致需要分为五个步骤:


1、选择开源框架


目前常用的深度学习框架主要包括 tensorflow、caffe、pytorch、mxnet 等;


2、构建并读取数据集


根据任务需求搜集相关图像搭建相应的数据集,常见的方式包括:网络爬虫、实地拍摄、公共数据使用等。随后根据所选开源框架读取数据集。


3、框架搭建


选择合适的网络模型、损失函数以及优化方式,以完成整体框架的搭建


4、训练并调试参数


通过训练选定合适超参数


5、测试准确率


在测试集上验证模型的最终性能


本文利用 Pytorch 框架,按照上述结构实现一个基本的图像分类任务,并详细阐述其中的细节及注意事项。


2 数据集



本次实战选择的数据集为 Kaggle 竞赛中的细胞数据集,共包含 9961 个训练样本,2491 个测试样本,可以分为嗜曙红细胞、淋巴细胞、单核细胞、中性白细胞 4 个类别,图片大小为 320x240。


Pytorch 中封装了相应的数据读取的类函数,通过调用 torch.utils.data.Datasets 函数,则可以实现读取功能。



init()模块用来定义相关的参数,len()模块用来获取训练样本个数,getitem()模块则用来获取每张具体的图片,在读取图片时其可以通过 opencv 库、PIL 库等进行读取,具体代码如下:


数据集


class dataset(data.Dataset):


# 参数预定义


def init(self, anno_pd, transforms=None):


self.paths = anno_pd[‘ImageName’].tolist()


self.labels = anno_pd[‘label’].tolist()


self.transforms = transforms


# 返回图片个数


def len(self):


return len(self.paths)


# 获取每个图片


def getitem(self, item):


img_path =self.paths[item]


img_id =img_path.split("/")[-1]


img =cv2.imread(img_path)


img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)


if self.transforms is not None:


img = self.transforms(img)


label = self.labels[item]


return torch.from_numpy(img).float(), int(label)


此外,需要定义图像增强模块,即上述代码中的 transform,通常采取的操作为翻转、剪切等,关于图像增强的具体介绍可以参考公众号前作。


【技术综述】深度学习中的数据增强方法都有哪些?


需要特别强调的是对图像进行去均值处理,很多同学不明白为何要减去均值,其主要的原因是图像作为一种平稳的数据分布,通过减去数据对应维度的统计平均值,可以消除公共部分,以凸显个体之间的特征和差异。进行去均值前后操作后的图像对比如下:



3 框架搭建


本次实战主要选取了 VGG16、Resnet50、InceptionV4 三个经典网络,也是对前篇文章的一个总结。


损失函数则选择交叉熵损失函数:【技术综述】一文道尽softmax loss及其变种


优化方式选择 SGD、Adam 优化两种:【模型训练】SGD的那些变种,真的比SGD强吗


4 训练及参数调试


初始学习率设置为 0.01,batch size 设置为 8,衰减率设置为 0.00001,迭代周期为 15,在不同框架组合下的最佳准确率和最低 loss 如下图所示:




可以发现在验证集上 Resnet-50+SGD+Cross Entropy 的组合下取得了 99%左右的准确率,相反 VGG-16 结果则稍微差一些。


最佳组合下的准确率走势曲线如下图所示:



5 测试


对上述模型分别在测试集上进行测试,所获得的结果如下图所示,整体精度比训练集上约下降了一个百分点:



总结


以上就是整个多类别图像分类实战的过程,由于时间限制,本次实战并没有对多个数据集进行练,因此没有列出同一模型在不同数据集上的表现。


作者介绍


郭冰洋,公众号“有三 AI”作者。该公号聚焦于让大家能够系统性地完成 AI 各个领域所需的专业知识的学习。


原文链接


https://mp.weixin.qq.com/s/jPpZLYXQBX7l5AUfFV5n3g


2019 年 10 月 15 日 16:411585

评论

发布
暂无评论
发现更多内容

你不知道的java对象序列化的秘密

程序那些事

Java java序列化 序列化的秘密

区块链需与5G等技术打好“组合拳”

CECBC区块链专委会

区块链 5G

一个草根的日常杂碎(10月11日)

刘新吾

随笔杂谈 生活记录 社会百态

优质数据库管理工具盘点,看看这三个软件的区别

CloudQuery社区

数据库 sql 云原生 工具 编辑器

技术解读丨分布式缓存数据库Redis大KEY问题定位及优化建议

华为云开发者社区

云计算 华为 技术

通俗易懂和你聊聊寄存器那些事(精美图文)

cxuan

后端 计算机 汇编

一个草根的日常杂碎(10月9日)

刘新吾

随笔杂谈 生活记录 社会百态

JDK14性能管理工具:jmap和jhat使用介绍

程序那些事

内存泄露 JDK14 jmap jhat

华为程序员发现孩子不是自己的,怒提离婚!女方不要孩子!绿他的竟然是个酒吧混混!

程序员生活志

华为 程序员

每个数据科学家都应该知道的5个概念

计算机与AI

学习 数据科学

Apple Developer 开发者账号申请&实名认证【2020】

iHTC

Apple Developer iOS Developer 苹果实名认证

正则表达式知识总结

iHTC

正则表达式

商业模式和盈利模式的思考

iHTC

商业模式 盈利模式 地摊经济

当我们在谈论跨平台的时候 ——— 我们在说什么

iHTC

跨平台

随想

Nydia

我们可以把Adapter精简到什么地步

mengxn

RecyclerView BetterAdapter Adapter

为什么学Go(二)

soolaugust

go

CECBC区块链专委会副主任吴桐主讲全国社保基金数字货币讲座

CECBC区块链专委会

区块链 数字货币

融合与共生之下,区块链都能“+”什么?

CECBC区块链专委会

区块链 大数据

第四周 系统架构学习总结

钟杰

极客大学架构师训练营

极客时间 - 架构师一期 - 第四周作业

_

第四周作业 架构师一期

OpenResty 项目脚手架

小铁匠

lua nginx openresty

iOS Handle Refunds 处理退款 --- WWDC20(Session 10661)

iHTC

WWDC2020 wwdc iap 苹果退款 iOS退款

面经手册 · 第13篇《除了JDK、CGLIB,还有3种类代理方式?面试又卡住!》

小傅哥

Java 字节码编程 asm 动态代理 cglib

数字货币交易所系统开发app,交易所搭建源码

WX13823153201

数字货币交易所系统开发

优秀开源项目、博客、书籍整理

小铁匠

收藏教程 资源汇总

JVM系列笔记 - 寄存器

朱华

JVM

第四周总结

_

极客大学架构师训练营 第四周总结

华为云专家带你解读文本情感分析任务

华为云开发者社区

内容 数据 分析

Guava-技术专题-Cache用法介绍

李浩宇/Alex

一个草根的日常杂碎(10月10日)

刘新吾

随笔杂谈 生活记录 社会百态

基于Pytorch的多类别图像分类实战-InfoQ