【专题推荐】AI大模型落地的前景和痛点,技术人面临哪些机会和挑战? 了解详情
写点什么

使用 AWS EC2 上的 Apache MXNet 和 Multimedia Commons 数据集来估计图像位置

  • 2019-11-11
  • 本文字数:4408 字

    阅读完需:约 14 分钟

使用 AWS EC2 上的 Apache MXNet 和 Multimedia Commons 数据集来估计图像位置

这是由国际计算机科学研究院的 Jaeyoung Choi 和加州大学伯克利分校的 Kevin Li 所著的一篇访客文章。本项目演示学术研究人员如何利用我们的 AWS Cloud Credits for Research Program 实现科学突破。


当您拍摄照片时,现代移动设备可以自动向图像分配地理坐标。不过,网络上的大多数图像仍缺少该位置元数据。图像定位是估计图像位置并应用位置标签的过程。根据您的数据集大小以及提出问题的方式,分配的位置标签可以是建筑物或地标名称或实际地理坐标 (纬度、经度)。


在本文中,我们会展示如何使用通过 Apache MXNet 创建的预训练模型对图像进行地理分类。我们使用的数据集包含拍摄于全球各地的数百万张 Flickr 图像。我们还会展示如何将结果制成地图以直观地显示结果。

我们的方法

图像定位方法可以分为两类:图像检索搜索法和分类法。(该博文将对这两个类别中最先进的方法进行比较。)


Weyand 等人近期的作品提出图像定位是一个分类问题。在这种方法中,作者将地球表面细分为数千个地理单元格,并利用带地理标记的图像训练了深层神经网路。有关他们的试验更通俗的描述,请参阅该文章


由于作者没有公开他们的训练数据或训练模型 (即 PlaNet),因此我们决定训练我们自己的图像定位器。我们训练模型的场景灵感来自于 Weyand 等人描述的方法,但是我们对几个设置作了改动。


我们在单个 p2.16xlarge 实例上使用 MXNet 来训练我们的模型 LocationNet,该实例包含来自 AWS Multimedia Commons 数据集的带有地理标记的图像。


我们将训练、验证和测试图像分离,以便同一人上传的图像不会出现在多个集合中。我们使用 Google 的 S2 Geometry Library 通过训练数据创建类。该模型经过 12 个训练周期后收敛,完成 p2.16xlarge 实例训练大约花了 9 天时间。GitHub 上提供了采用 Jupyter Notebook 的完整教程


下表对用于训练和测试 LocationNet 和 PlaNet 的设置进行了比较。


col 1col 2col 3


    | LocationNet                                                    | PlaNet                           
复制代码


数据集来源 | Multimedia Commons | 从网络抓取的图像


训练集 | 3390 万 | 9100 万


验证 | 180 万 | 3400 万


S2 单元分区 | t1=5000, t2=500


→ 15,527 个单元格 | t1=10,000, t2=50


→ 26,263 个单元格


模型 | ResNet-101 | GoogleNet


优化 | 使用动量和 LR 计划的 SGD | Adagrad


训练时间 | 采用 16 个 NVIDIA K80 GPU (p2.16xlarge EC2 实例) 时为 9 天


12 个训练周期 | 采用 200 个 CPU 内核时为两个半月


框架 | MXNet | DistBelief


测试集 | Placing Task 2016 测试集 (150 万张 Flickr 图像) | 230 万张有地理标记的 Flickr 图像


在推理时,LocationNet 会输出地理单元格间的概率分布。单元格中概率最高的图像的质心地理坐标会被分配为查询图像的地理坐标。


LocationNet 会在 MXNet Model Zoo 中公开分享。

下载 LocationNet

现在下载 LocationNet 预训练模型。LocationNet 已使用 AWS Multimedia Commons 数据集中带地理标记的图像子集进行了训练。Multimedia Commons 数据集包含 3900 多万张图像和 15000 个地理单元格 (类)。


LocationNet 包括两部分:一个包含模型定义的 JSON 文件和一个包含参数的二进制文件。我们从 S3 加载必要的软件包并下载文件。


Java


import os
import urllib
import mxnet as mx
import logging
import numpy as np
from skimage import io, transform
from collections import namedtuple
from math import radians, sin, cos, sqrt, asin
path = 'https://s3.amazonaws.com/mmcommons-tutorial/models/'
model_path = 'models/'
if not os.path.exists(model_path):
os.mkdir(model_path)
urllib.urlretrieve(path+'RN101-5k500-symbol.json', model_path+'RN101-5k500-symbol.json')
urllib.urlretrieve(path+'RN101-5k500-0012.params', model_path+'RN101-5k500-0012.params')
复制代码


然后,加载下载的模型。如果您没有可用 GPU,请将 mx.gpu() 替换为 mx.cpu():


Java


# Load the pre-trained model
prefix = "models/RN101-5k500"
load_epoch = 12
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, load_epoch)
mod = mx.mod.Module(symbol=sym, context=mx.gpu())
mod.bind([('data', (1,3,224,224))], for_training=False)mod.set_params(arg_params, aux_params, allow_missing=True)
复制代码


grids.txt 文件包含用于训练模型的地理单元格。


第 i 行是第 i 个类,列分别代表:S2 单元格标记、纬度和经度。我们将标签加载到名为 grids 的列表中。


Java


# Download and load grids file
urllib.urlretrieve('https://raw.githubusercontent.com/multimedia-berkeley/tutorials/master/grids.txt','grids.txt')
# Load labels.
grids = []
with open('grids.txt', 'r') as f:
for line in f:
line = line.strip().split('\t')
lat = float(line[1])
lng = float(line[2])
grids.append((lat, lng))
复制代码


该模型使用半径公式来测量点 p1 和 p2 之间的大圆弧距离,以千米为单位:


Java


def distance(p1, p2):
R = 6371 # Earth radius in km
lat1, lng1, lat2, lng2 = map(radians, (p1[0], p1[1], p2[0], p2[1]))
dlat = lat2 - lat1
dlng = lng2 - lng1
a = sin(dlat * 0.5) ** 2 + cos(lat1) * cos(lat2) * (sin(dlng * 0.5) ** 2)
复制代码


Java


return 2 * R * asin(sqrt(a))
复制代码


在将图像提供给深度学习网络之前,该模型会通过裁剪以及减去均值来预处理图像:


Java


# mean image for preprocessing
mean_rgb = np.array([123.68, 116.779, 103.939])
mean_rgb = mean_rgb.reshape((3, 1, 1))
def PreprocessImage(path, show_img=False):
# load image.
img = io.imread(path)
# We crop image from center to get size 224x224.
short_side = min(img.shape[:2])
yy = int((img.shape[0] - short_side) / 2)
xx = int((img.shape[1] - short_side) / 2)
crop_img = img[yy : yy + short_side, xx : xx + short_side]
resized_img = transform.resize(crop_img, (224,224))
if show_img:
io.imshow(resized_img)
# convert to numpy.ndarray
sample = np.asarray(resized_img) * 256
# swap axes to make image from (224, 224, 3) to (3, 224, 224)
sample = np.swapaxes(sample, 0, 2)
sample = np.swapaxes(sample, 1, 2)
# sub mean
normed_img = sample - mean_rgb
normed_img = normed_img.reshape((1, 3, 224, 224))return [mx.nd.array(normed_img)]
复制代码

评估并比较模型

为了进行评估,我们使用两个数据集:IM2GPS 数据集和 Flickr 图像测试数据集,后者用于 MediaEval Placing 2016 基准测试


IM2GPS 测试集结果


以下值表示 IM2GPS 测试集中正确位于与实际位置的每个距离内的图像的百分比。



Flickr 图像结果


由于 PlaNet 中使用的测试集图像尚未公开发布,因此不能直接比较这些结果。这些值表示测试集中正确位于与实际位置的每个距离内的图像的百分比。



通过目测检查定位图像,我们可以看到该模型不仅在地标位置方面表现出色,而且也能准确定位非标志性场景。

使用 URL 估算图像的地理位置

现在我们试着用 URL 对网页上的图像进行定位。


Java


Batch = namedtuple('Batch', ['data'])
def predict(imgurl, prefix='images/'):
download_url(imgurl, prefix)
imgname = imgurl.split('/')[-1]
batch = PreprocessImage(prefix + imgname, True)
#predict and show top 5 results
mod.forward(Batch(batch), is_train=False)
prob = mod.get_outputs()[0].asnumpy()[0]
pred = np.argsort(prob)[::-1]
result = list()
for i in range(5):
pred_loc = grids[int(pred[i])]
res = (i+1, prob[pred[i]], pred_loc)
print('rank=%d, prob=%f, lat=%s, lng=%s' \
% (i+1, prob[pred[i]], pred_loc[0], pred_loc[1]))
result.append(res[2])
return result
def download_url(imgurl, img_directory):
if not os.path.exists(img_directory):
os.mkdir(img_directory)
imgname = imgurl.split('/')[-1]
filepath = os.path.join(img_directory, imgname)
if not os.path.exists(filepath):
filepath, _ = urllib.urlretrieve(imgurl, filepath)
statinfo = os.stat(filepath)
print('Succesfully downloaded', imgname, statinfo.st_size, 'bytes.')
复制代码


Java


return filepath
复制代码


来看看我们的模型如何处理东京塔图片。以下代码从 URL 下载图像,并输出模型的位置预测。


Java


#download and predict geo-location of an image of Tokyo Tower
url = 'https://farm5.staticflickr.com/4275/34103081894_f7c9bfa86c_k_d.jpg'
复制代码


Java


result = predict(url)
复制代码


结果列出了置信度分数 (概率) 排在前五位的输出以及地理坐标:


Java


rank=1, prob=0.139923, lat=35.6599344486, lng=139.728919109
rank=2, prob=0.095210, lat=35.6546613641, lng=139.745685815
rank=3, prob=0.042224, lat=35.7098435803, lng=139.810458528
rank=4, prob=0.032602, lat=35.6641725688, lng=139.746648114
rank=5, prob=0.023119, lat=35.6901996892, lng=139.692857396
复制代码


仅通过原始纬度和经度值,很难判断地理位置输出的质量。我们可以通过将输出制成地图来直观地显示结果。

在 Jupyter Notebook 上使用 Google Maps 直观显示结果

为了直观地显示预测结果,我们可以在 Jupyter Notebook 中使用 Google Maps。它让您能够看到预测是否有意义。我们使用一个名为 gmaps 的插件,它允许我们在 Jupyter Notebook 中使用 Google Maps。要安装 gmaps,请按照 gmaps GitHub 页面上的安装说明操作。


使用 gmaps 直观显示结果只需几行代码。请在您的 Notebook 输入以下内容:


Java


import gmaps
gmaps.configure(api_key="") # Fill in with your API key
fig = gmaps.figure()
for i in range(len(result)):
marker = gmaps.marker_layer([result[i]], label=str(i+1))
fig.add_layer(marker)
fig
复制代码



事实上,排在第一位的定位估算结果就是东京塔所在的位置。


现在,试着对您选择的图像进行定位吧!

鸣谢

在 AWS 上训练 LocationNet 的工作得到了 AWS 研究与教育计划的大力支持。我们还要感谢 AWS 公共数据集计划托管 Multimedia Commons 数据集以供公众使用。我们的工作也得到了劳伦斯·利弗莫尔国家实验室领导的合作 LDRD 的部分支持 (美国能源部合同 DE-AC52-07NA27344)。


本文转载自 AWS 技术博客。


原文链接:


https://amazonaws-china.com/cn/blogs/china/estimating-image-locations-using-the-apache-mxnet-and-multimedia-commons-datasets-on-aws-ec2/


2019-11-11 08:00396

评论

发布
暂无评论
发现更多内容

Week10总结

熊威

智能汽车安全风险及防护技术分析

几维安全

移动应用安全

微软看上的Rust 语言,安全性真的很可靠吗

华为云开发者联盟

数据库 开源 rust 安全 代码

Django查看操作数据库的执行命令

BigYoung

数据库 django 操作

易观CTO郭炜:如何构建企业级大数据Ad-hoc查询引擎

易观大数据

看前谷歌工程师是如何副业赚钱的?

非著名程序员

程序员 个人成长 副业赚钱 提升认知

奋斗在一线大城市的年轻人的生活工作实录(工厂蓝领篇)

Learun

程序员 软件开发 故事 企业信息化 短片小说

DSN 主流项目调研 2——Sia和SAFE Network

AIbot

区块链 分布式存储 分布式文件存储 Sia SAFENetwork

LeetCode题解:88. 合并两个有序数组,for循环合并数组+sort排序,JavaScript,详细注释

Lee Chen

大前端 LeetCode

SpringBoot系列(三):SpringBoot特性_SpringApplication类(自定义Banner)

xcbeyond

Java 微服务 springboot Banner

安卓移动应用代码安全加固系统设计及实现

几维安全

android 安全评估 移动应用安全

HTML5CSS3前端入门教程---从0开始通过一个商城实例手把手教你学习PC端和移动端页面开发第10章有路网PC端主页实战整合

Geek_8dbdc1

核心稳定、易扩展——开放关闭原则(The Open-Closed Principle)

晃来晃去的萨麦尔

编程习惯 架构分析 软件设计原则

Windows AD日志分析告警平台—WatchAD安装教程

BigYoung

监控 windows 日志 AD 告警

SpringBoot 系列(一):SpringBoot项目搭建

xcbeyond

Java 微服务 springboot

Django中的session的使用

BigYoung

django session Cookie

神经网络的学习为何要设定损失函数?

王坤祥

神经网络 学习 损失函数

致远互联A6+Cloud C位出道 赋能中小企业乘风破浪

爱极客侠

普通工程师简史

郭华

Cobra 命令自动补全指北

郭旭东

cobra Go 语言

架构师训练营 第 10 周 作业&总结

Jam

关于微服务架构的一些思考

俊俊哥

微服务

《深度工作》学习笔记(完)

石云升

读书笔记 时间管理 专注 深度工作

有限数据量如何最大化提升模型效果?百度工程师构建数据增强服务

百度大脑

人工智能 数据 模型训练 百度大脑

JAVA位运算

彭阿三

Java 位运算

云图说丨手把手教你为容器应用配置弹性伸缩策略

华为云开发者联盟

Docker 云计算 Kubernetes 容器

HTML5+CSS3前端入门教程---从0开始通过一个商城实例手把手教你学习PC端和移动端页面开发第11章有路网移动端主页实战

Geek_8dbdc1

《深度工作》学习笔记(6)

石云升

读书笔记 专注 深度工作

SpringBoot系列(二):如何灵活使用SpringBoot

xcbeyond

Java 微服务 springboot

解析中美数字货币竞争战略 | 构建属于“人类命运共同体”的货币体系

CECBC

数字货币 人民币

DSN 主流项目调研 3——Orbit数据库的故事

AIbot

区块链 分布式存储 IPFS 分布式文件 Orbit

使用 AWS EC2 上的 Apache MXNet 和 Multimedia Commons 数据集来估计图像位置_语言 & 开发_亚马逊云科技 (Amazon Web Services)_InfoQ精选文章