写点什么

使用 AWS EC2 上的 Apache MXNet 和 Multimedia Commons 数据集来估计图像位置

  • 2019 年 11 月 11 日
  • 本文字数:4408 字

    阅读完需:约 14 分钟

使用 AWS EC2 上的 Apache MXNet 和 Multimedia Commons 数据集来估计图像位置

这是由国际计算机科学研究院的 Jaeyoung Choi 和加州大学伯克利分校的 Kevin Li 所著的一篇访客文章。本项目演示学术研究人员如何利用我们的 AWS Cloud Credits for Research Program 实现科学突破。


当您拍摄照片时,现代移动设备可以自动向图像分配地理坐标。不过,网络上的大多数图像仍缺少该位置元数据。图像定位是估计图像位置并应用位置标签的过程。根据您的数据集大小以及提出问题的方式,分配的位置标签可以是建筑物或地标名称或实际地理坐标 (纬度、经度)。


在本文中,我们会展示如何使用通过 Apache MXNet 创建的预训练模型对图像进行地理分类。我们使用的数据集包含拍摄于全球各地的数百万张 Flickr 图像。我们还会展示如何将结果制成地图以直观地显示结果。


我们的方法

图像定位方法可以分为两类:图像检索搜索法和分类法。(该博文将对这两个类别中最先进的方法进行比较。)


Weyand 等人近期的作品提出图像定位是一个分类问题。在这种方法中,作者将地球表面细分为数千个地理单元格,并利用带地理标记的图像训练了深层神经网路。有关他们的试验更通俗的描述,请参阅该文章


由于作者没有公开他们的训练数据或训练模型 (即 PlaNet),因此我们决定训练我们自己的图像定位器。我们训练模型的场景灵感来自于 Weyand 等人描述的方法,但是我们对几个设置作了改动。


我们在单个 p2.16xlarge 实例上使用 MXNet 来训练我们的模型 LocationNet,该实例包含来自 AWS Multimedia Commons 数据集的带有地理标记的图像。


我们将训练、验证和测试图像分离,以便同一人上传的图像不会出现在多个集合中。我们使用 Google 的 S2 Geometry Library 通过训练数据创建类。该模型经过 12 个训练周期后收敛,完成 p2.16xlarge 实例训练大约花了 9 天时间。GitHub 上提供了采用 Jupyter Notebook 的完整教程


下表对用于训练和测试 LocationNet 和 PlaNet 的设置进行了比较。


col 1col 2col 3


    | LocationNet                                                    | PlaNet                           
复制代码


数据集来源 | Multimedia Commons | 从网络抓取的图像


训练集 | 3390 万 | 9100 万


验证 | 180 万 | 3400 万


S2 单元分区 | t1=5000, t2=500


→ 15,527 个单元格 | t1=10,000, t2=50


→ 26,263 个单元格


模型 | ResNet-101 | GoogleNet


优化 | 使用动量和 LR 计划的 SGD | Adagrad


训练时间 | 采用 16 个 NVIDIA K80 GPU (p2.16xlarge EC2 实例) 时为 9 天


12 个训练周期 | 采用 200 个 CPU 内核时为两个半月


框架 | MXNet | DistBelief


测试集 | Placing Task 2016 测试集 (150 万张 Flickr 图像) | 230 万张有地理标记的 Flickr 图像


在推理时,LocationNet 会输出地理单元格间的概率分布。单元格中概率最高的图像的质心地理坐标会被分配为查询图像的地理坐标。


LocationNet 会在 MXNet Model Zoo 中公开分享。


下载 LocationNet

现在下载 LocationNet 预训练模型。LocationNet 已使用 AWS Multimedia Commons 数据集中带地理标记的图像子集进行了训练。Multimedia Commons 数据集包含 3900 多万张图像和 15000 个地理单元格 (类)。


LocationNet 包括两部分:一个包含模型定义的 JSON 文件和一个包含参数的二进制文件。我们从 S3 加载必要的软件包并下载文件。


Java


import os
import urllib
import mxnet as mx
import logging
import numpy as np
from skimage import io, transform
from collections import namedtuple
from math import radians, sin, cos, sqrt, asin
path = 'https://s3.amazonaws.com/mmcommons-tutorial/models/'
model_path = 'models/'
if not os.path.exists(model_path):
os.mkdir(model_path)
urllib.urlretrieve(path+'RN101-5k500-symbol.json', model_path+'RN101-5k500-symbol.json')
urllib.urlretrieve(path+'RN101-5k500-0012.params', model_path+'RN101-5k500-0012.params')
复制代码


然后,加载下载的模型。如果您没有可用 GPU,请将 mx.gpu() 替换为 mx.cpu():


Java


# Load the pre-trained model
prefix = "models/RN101-5k500"
load_epoch = 12
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, load_epoch)
mod = mx.mod.Module(symbol=sym, context=mx.gpu())
mod.bind([('data', (1,3,224,224))], for_training=False)mod.set_params(arg_params, aux_params, allow_missing=True)
复制代码


grids.txt 文件包含用于训练模型的地理单元格。


第 i 行是第 i 个类,列分别代表:S2 单元格标记、纬度和经度。我们将标签加载到名为 grids 的列表中。


Java


# Download and load grids file
urllib.urlretrieve('https://raw.githubusercontent.com/multimedia-berkeley/tutorials/master/grids.txt','grids.txt')
# Load labels.
grids = []
with open('grids.txt', 'r') as f:
for line in f:
line = line.strip().split('\t')
lat = float(line[1])
lng = float(line[2])
grids.append((lat, lng))
复制代码


该模型使用半径公式来测量点 p1 和 p2 之间的大圆弧距离,以千米为单位:


Java


def distance(p1, p2):
R = 6371 # Earth radius in km
lat1, lng1, lat2, lng2 = map(radians, (p1[0], p1[1], p2[0], p2[1]))
dlat = lat2 - lat1
dlng = lng2 - lng1
a = sin(dlat * 0.5) ** 2 + cos(lat1) * cos(lat2) * (sin(dlng * 0.5) ** 2)
复制代码


Java


return 2 * R * asin(sqrt(a))
复制代码


在将图像提供给深度学习网络之前,该模型会通过裁剪以及减去均值来预处理图像:


Java


# mean image for preprocessing
mean_rgb = np.array([123.68, 116.779, 103.939])
mean_rgb = mean_rgb.reshape((3, 1, 1))
def PreprocessImage(path, show_img=False):
# load image.
img = io.imread(path)
# We crop image from center to get size 224x224.
short_side = min(img.shape[:2])
yy = int((img.shape[0] - short_side) / 2)
xx = int((img.shape[1] - short_side) / 2)
crop_img = img[yy : yy + short_side, xx : xx + short_side]
resized_img = transform.resize(crop_img, (224,224))
if show_img:
io.imshow(resized_img)
# convert to numpy.ndarray
sample = np.asarray(resized_img) * 256
# swap axes to make image from (224, 224, 3) to (3, 224, 224)
sample = np.swapaxes(sample, 0, 2)
sample = np.swapaxes(sample, 1, 2)
# sub mean
normed_img = sample - mean_rgb
normed_img = normed_img.reshape((1, 3, 224, 224))return [mx.nd.array(normed_img)]
复制代码


评估并比较模型

为了进行评估,我们使用两个数据集:IM2GPS 数据集和 Flickr 图像测试数据集,后者用于 MediaEval Placing 2016 基准测试


IM2GPS 测试集结果


以下值表示 IM2GPS 测试集中正确位于与实际位置的每个距离内的图像的百分比。



Flickr 图像结果


由于 PlaNet 中使用的测试集图像尚未公开发布,因此不能直接比较这些结果。这些值表示测试集中正确位于与实际位置的每个距离内的图像的百分比。



通过目测检查定位图像,我们可以看到该模型不仅在地标位置方面表现出色,而且也能准确定位非标志性场景。


使用 URL 估算图像的地理位置

现在我们试着用 URL 对网页上的图像进行定位。


Java


Batch = namedtuple('Batch', ['data'])
def predict(imgurl, prefix='images/'):
download_url(imgurl, prefix)
imgname = imgurl.split('/')[-1]
batch = PreprocessImage(prefix + imgname, True)
#predict and show top 5 results
mod.forward(Batch(batch), is_train=False)
prob = mod.get_outputs()[0].asnumpy()[0]
pred = np.argsort(prob)[::-1]
result = list()
for i in range(5):
pred_loc = grids[int(pred[i])]
res = (i+1, prob[pred[i]], pred_loc)
print('rank=%d, prob=%f, lat=%s, lng=%s' \
% (i+1, prob[pred[i]], pred_loc[0], pred_loc[1]))
result.append(res[2])
return result
def download_url(imgurl, img_directory):
if not os.path.exists(img_directory):
os.mkdir(img_directory)
imgname = imgurl.split('/')[-1]
filepath = os.path.join(img_directory, imgname)
if not os.path.exists(filepath):
filepath, _ = urllib.urlretrieve(imgurl, filepath)
statinfo = os.stat(filepath)
print('Succesfully downloaded', imgname, statinfo.st_size, 'bytes.')
复制代码


Java


return filepath
复制代码


来看看我们的模型如何处理东京塔图片。以下代码从 URL 下载图像,并输出模型的位置预测。


Java


#download and predict geo-location of an image of Tokyo Tower
url = 'https://farm5.staticflickr.com/4275/34103081894_f7c9bfa86c_k_d.jpg'
复制代码


Java


result = predict(url)
复制代码


结果列出了置信度分数 (概率) 排在前五位的输出以及地理坐标:


Java


rank=1, prob=0.139923, lat=35.6599344486, lng=139.728919109
rank=2, prob=0.095210, lat=35.6546613641, lng=139.745685815
rank=3, prob=0.042224, lat=35.7098435803, lng=139.810458528
rank=4, prob=0.032602, lat=35.6641725688, lng=139.746648114
rank=5, prob=0.023119, lat=35.6901996892, lng=139.692857396
复制代码


仅通过原始纬度和经度值,很难判断地理位置输出的质量。我们可以通过将输出制成地图来直观地显示结果。


在 Jupyter Notebook 上使用 Google Maps 直观显示结果

为了直观地显示预测结果,我们可以在 Jupyter Notebook 中使用 Google Maps。它让您能够看到预测是否有意义。我们使用一个名为 gmaps 的插件,它允许我们在 Jupyter Notebook 中使用 Google Maps。要安装 gmaps,请按照 gmaps GitHub 页面上的安装说明操作。


使用 gmaps 直观显示结果只需几行代码。请在您的 Notebook 输入以下内容:


Java


import gmaps
gmaps.configure(api_key="") # Fill in with your API key
fig = gmaps.figure()
for i in range(len(result)):
marker = gmaps.marker_layer([result[i]], label=str(i+1))
fig.add_layer(marker)
fig
复制代码



事实上,排在第一位的定位估算结果就是东京塔所在的位置。


现在,试着对您选择的图像进行定位吧!


鸣谢

在 AWS 上训练 LocationNet 的工作得到了 AWS 研究与教育计划的大力支持。我们还要感谢 AWS 公共数据集计划托管 Multimedia Commons 数据集以供公众使用。我们的工作也得到了劳伦斯·利弗莫尔国家实验室领导的合作 LDRD 的部分支持 (美国能源部合同 DE-AC52-07NA27344)。


本文转载自 AWS 技术博客。


原文链接:


https://amazonaws-china.com/cn/blogs/china/estimating-image-locations-using-the-apache-mxnet-and-multimedia-commons-datasets-on-aws-ec2/


2019 年 11 月 11 日 08:00259

评论

发布
暂无评论
发现更多内容

Netty进阶:手把手教你如何编写一个NIO服务端,java集合容器面试

Java 程序员 后端

Redis常用命令总结,java项目实例教程详细

Java 程序员 后端

Redis应用之缓存实现,java异步编程实战pdf

Java 程序员 后端

Nginx + Tomcat 搭建负载均衡,大牛带你直击优秀开源框架灵魂

Java 程序员 后端

【架构实战营】模块三

Henry | 衣谷

架构实战营

Red5搭建直播平台,java淘宝客教程

Java 程序员 后端

Nginx配置反向代理和负载均衡,疯狂java讲义pdf百度云

Java 程序员 后端

Redis 笔记之 Java 操作 Redis(Jedis),springcloud实战pdf

Java 程序员 后端

Redis-数据库、键过期的实现,跟面试官侃半小时MySQL事务隔离性

Java 程序员 后端

Redis源码剖析——客户端和服务器,springboot入门程序

Java 后端

RabbitMQ不讲武德,发个消息也这么多花招,nginx实现负载均衡原理

Java 程序员 后端

Redis 配置文件重要属性介绍,java面试项目经验

Java 程序员 后端

RocketMQ ACL版本升级过程中的曲折经历(大厂线上环境大规模MQ升级开启ACL实战)

Java 程序员 后端

Rpc与RMI服务,java面试笔试题代码

Java 程序员 后端

Nginx服务不行了怎么办,网商银行java面试

Java 程序员 后端

Oracle最新的Sql笔试题及答案,Java面试真题解析火爆全网

Java 程序员 后端

Redis-中会涉及那么多数据结构,那你数据对象的底层实现方式你都了解吗?

Java 程序员 后端

Redis持久化--Redis宕机或者出现意外删库导致数据丢失--解决方案

Java 程序员 后端

Netty进阶:手把手教你如何编写一个NIO服务端(1),Java笔试常见编程题

Java 程序员 后端

new-Object()到底占用几个字节,看完这篇彻底明白了!,springboot微服务架构书籍

Java 程序员 后端

OpenFaaS实战之二:函数入门,mysql集群数据同步原理

Java 程序员 后端

OpenYurt v0,linuxshell学习

Java 程序员 后端

Redis、MongoDB及Memcached的区别,老男孩linux运维54期视频

Java 程序员 后端

Netty案例介绍-群聊案例实现,java架构师教程百度云

Java 程序员 后端

P8级大佬整理在Github上45K+star手册,吃透消化,java算法面试题及答案pdf

Java 程序员 后端

Peter-Java 8中的Lambda表达式,java领域的相关技术领域

Java 程序员 后端

Redis分布式锁的原理以及如何续期,java程序设计实验实训教程答案

Java 程序员 后端

redis数据迁移之redis-shake,java高级技术经理面试题

Java 程序员 后端

Redis的各种用途以及使用场景,mybatis技术原理

Java 程序员 后端

Netty相关面试题汇总,java从入门到精通第五版电子书下载微盘

Java 程序员 后端

Redis 变慢了?那你这样试试,不行就捶我,mybatis工作原理图

Java 程序员 后端

使用 AWS EC2 上的 Apache MXNet 和 Multimedia Commons 数据集来估计图像位置_语言 & 开发_亚马逊云科技 (Amazon Web Services)_InfoQ精选文章