AI 安全之对抗样本入门 (35):打造对抗样本工具箱 3.4

阅读数:12 2019 年 11 月 30 日 15:14

AI安全之对抗样本入门(35):打造对抗样本工具箱 3.4

(PyTorch)

内容简介
第 1 章介绍了深度学习的基础知识,重点介绍了与对抗样本相关的梯度、优化器、反向传递等知识点。
第 2 章介绍了如何搭建学习对抗样本的软硬件环境,虽然 GPU 不是必需的,但是使用 GPU 可以更加快速地验证你的想法。
第 3 章概括介绍了常见的深度学习框架,从 TensorFlow、Keras、PyTorch 到 MXNet。
第 4 章介绍了图像处理领域的基础知识,这部分知识对于理解对抗样本领域的一些常见图像处理技巧非常有帮助。
第 5 章介绍了常见的白盒攻击算法,从基础的 FGSM、DeepFool 到经典的 JSMA 和 CW。
第 6 章介绍了常见的黑盒攻击算法。
第 7 章介绍了对抗样本在目标识别领域的应用。
第 8 章介绍了对抗样本的常见抵御算法,与对抗样本一样,抵御对抗样本的技术也非常有趣。
第 9 章介绍了常见的对抗样本工具以及如何搭建 NIPS 2017 对抗防御环境和轻量级攻防对抗环境 robust-ml,通过这章读者可以了解如何站在巨人的肩膀上,快速生成自己的对抗样本,进行攻防对抗。

PyTorch 是 torch 的 Python 版本,是由 Facebook 开源的神经网络框架。PyTorch 虽然是深度学习框架中的后起之秀,但是发展极其迅猛。PyTorch 提供了 NumPy 风格的 Tensor 操作,熟悉 NumPy 操作的用户非常容易上手。我们以解决经典的手写数字识别的问题为例,介绍 PyTorch 的基本使用方法,代码路径为:

复制代码
https://github.com/duoergun0729/adversarial_examples/blob/master/code/2-pytorch.ipynb
  1. 加载相关库

加载处理经典的手写数字识别问题相关的 Python 库:

复制代码
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data
  1. 加载数据集

PyTorch 中针对常见的数据集进行了封装,免去了用户手工下载的过程并简化了预处理的过程。这里需要特别指出的是,PyTorch 中每个 Tensor 包括输入节点,并且都可以有自己的梯度值,因此训练数据集要设置为 train=True,测试数据集要设置为 train=False:

复制代码
train_data = torchvision.datasets.MNIST(
'dataset/mnist-pytorch', train=True,
transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
'dataset/mnist-pytorch', train=False,
transform=torchvision.transforms.ToTensor()
)

如果需要对数据进行归一化,可以进一步使用 transforms.Normalize 方法:

复制代码
transform=transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5], [0.5])])

第一次运行该程序时,PyTorch 会从互联网直接下载数据集并处理:

复制代码
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing... Done!
  1. 定义网络结构

使用与 Keras 类似的网络结构,即两层隐藏层结构,不过使用 BatchNorm 层替换了 Dropout 层,在抵御过拟合的同时加快了训练的收敛速度。在 PyTorch 中定义网络结构,通常需要继承 torch.nn.Module 类,重点是在 forward 中完成前向传播的定义,在 init 中完成主要网络层的定义:

复制代码
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.dense = torch.nn.Sequential(
#全连接层
torch.nn.Linear(784, 512),
#BatchNorm 层
torch.nn.BatchNorm1d(512),
torch.nn.ReLU(),
torch.nn.Linear(512, 10),
torch.nn.ReLU()
) def forward(self, x):
#把输出转换成大小为 784 的一维向量
x = x.view(-1, 784)
x=self.dense(x)
return torch.nn.functional.log_softmax(x, dim=1)

最后可视化网络结构,细节如图 3-7 所示。

AI安全之对抗样本入门(35):打造对抗样本工具箱 3.4

图 3-7 PyTorch 处理 MNIST 的网络结构图
  1. 定义损失函数和优化器

损失函数使用交叉熵 CrossEntropyLoss,优化器使用 Adam,优化的对象是全部网络参数:

复制代码
optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()
  1. 训练与验证

PyTorch 的训练和验证过程是分开的,在训练阶段需要把训练数据进行前向传播后,使用损失函数计算训练数据的真实标签与预测标签之间损失值,然后显示调用反向传递 backward(),使用优化器来调整参数,这一操作需要调用 optimizer.step():

复制代码
for i, data in enumerate(train_loader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
# 梯度清零
optimizer.zero_grad()
# 前向传播
outputs = model(inputs)
loss = loss_func(outputs, labels)
#反向传递
loss.backward()
optimizer.step()

每轮训练需要花费较长的时间,为了让训练过程可视化,可以打印训练的中间结果,比如每 100 个批次打印下平均损失值:

复制代码
# 每训练 100 个批次打印一次平均损失值
sum_loss += loss.item()
if (i+1) % 100 == 0:
print('epoch=%d, batch=%d loss: %.04f'% (epoch + 1, i+1, sum_loss / 100))
sum_loss = 0.0

验证阶段要手工关闭反向传递,需要通过 torch.no_grad() 实现:

复制代码
# 每跑完一次 epoch,测试一下准确率进入测试模式,禁止梯度传递
with torch.no_grad():
correct = 0
total = 0
for data in test_loader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = model(images)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('epoch=%d accuracy=%.02f%%' % (epoch + 1, (100 * correct /
total)))

经过 20 轮训练,在测试集上准确度达到了 97.00%:

复制代码
epoch=20, batch=100 loss: 0.0035
epoch=20, batch=200 loss: 0.0049
epoch=20, batch=300 loss: 0.0040
epoch=20, batch=400 loss: 0.0042
epoch=20 accuracy=97.00%

PyTorch 保存的模型文件后缀为 pth:

复制代码
torch.save(model.state_dict(), 'models/pytorch-mnist.pth')

AI安全之对抗样本入门(35):打造对抗样本工具箱 3.4

购书地址 https://item.jd.com/12532163.html?dist=jd

评论

发布