你已经学会如何去定义一个神经网络,计算损失值和更新网络的权重。
你现在可能在思考:数据哪里来呢?
关于数据
通常,当你处理图像,文本,音频和视频数据时,你可以使用标准的Python包来加载数据到一个numpy数组中.然后把这个数组转换成
torch.*Tensor
。
对于图像,有诸如Pillow,OpenCV包等非常实用
对于文本,可以用原始Python和Cython来加载,或者使用NLTK和SpaCy 对于视觉,我们创建了一个
torchvision
包,包含常见数据集的数据加载,比如Imagenet,CIFAR10,MNIST等,和图像转换器,也就是
torchvision.datasets
和
torch.utils.data.DataLoader
。
这提供了巨大的便利,也避免了代码的重复。
在这个教程中,我们使用CIFAR10数据集,它有如下10个类别:’airplane’,’automobile’,’bird’,’cat’,’deer’,’dog’,’frog’,’horse’,’ship’,’truck’。这个数据集中的图像大小为3*32*32,即,3通道,32*32像素。
训练一个图像分类器
我们将按照下列顺序进行:
使用
torchvision
加载和归一化CIFAR10训练集和测试集.
1. 加载和归一化CIFAR10
使用
torchvision
加载CIFAR10是非常容易的。
%matplotlib inlineimport torchimport torchvisionimport torchvision.transforms as transforms
torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。
注意
如果在Windows上运行时出现BrokenPipeError,尝试将torch.utils.data.DataLoader()的num_worker设置为0。
transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))]) trainset = torchvision.datasets.CIFAR10(root='./data' , train=True , download=True , transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4 , shuffle=True , num_workers=2 ) testset = torchvision.datasets.CIFAR10(root='./data' , train=False , download=True , transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4 , shuffle=False , num_workers=2 ) classes = ('plane' , 'car' , 'bird' , 'cat' , 'deer' , 'dog' , 'frog' , 'horse' , 'ship' , 'truck' )#这个过程有点慢,会下载大约340mb图片数据。
我们展示一些有趣的训练图像。
import matplotlib.pyplot as pltimport numpy as np# functions to show an image def imshow (img) : img = img / 2 + 0.5 # unnormalize npimg = img.numpy() plt.imshow(np.transpose(npimg, (1 , 2 , 0 ))) plt.show()# get some random training images dataiter = iter(trainloader) images, labels = dataiter.next()# show images imshow(torchvision.utils.make_grid(images))# print labels print(' ' .join('%5s' % classes[labels[j]] for j in range(4 )))
2. 定义一个卷积神经网络
从之前的神经网络一节复制神经网络代码,并修改为接受3通道图像取代之前的接受单通道图像。
import torch.nn as nnimport torch.nn.functional as Fclass Net (nn.Module) : def __init__ (self) : super(Net, self).__init__() self.conv1 = nn.Conv2d(3 , 6 , 5 ) self.pool = nn.MaxPool2d(2 , 2 ) self.conv2 = nn.Conv2d(6 , 16 , 5 ) self.fc1 = nn.Linear(16 * 5 * 5 , 120 ) self.fc2 = nn.Linear(120 , 84 ) self.fc3 = nn.Linear(84 , 10 ) def forward (self, x) : x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1 , 16 * 5 * 5 ) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net()
3. 定义损失函数和优化器
我们使用交叉熵作为损失函数,使用带动量的随机梯度下降。
import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001 , momentum=0.9 )
4. 训练网络
这是开始有趣的时刻,我们只需在数据迭代器上循环,把数据输入给网络,并优化。
for epoch in range(2 ): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0 ): # get the inputs; data is a list of [inputs, labels] inputs, labels = data # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # print statistics running_loss += loss.item() if i % 2000 == 1999 : # print every 2000 mini-batches print('[%d, %5d] loss: %.3f' % (epoch + 1 , i + 1 , running_loss / 2000 )) running_loss = 0.0 print('Finished Training' )
保存一下我们的训练模型
PATH = './cifar_net.pth' torch.save(net.state_dict(), PATH)
点击这里查看关于保存模型的详细介绍
5. 在测试集上测试网络
我们在整个训练集上训练了两次网络,但是我们还需要检查网络是否从数据集中学习到东西。
我们通过预测神经网络输出的类别标签并根据实际情况进行检测,如果预测正确,我们把该样本添加到正确预测列表。
第一步,显示测试集中的图片一遍熟悉图片内容。
dataiter = iter(testloader) images, labels = dataiter.next()# print images imshow(torchvision.utils.make_grid(images)) print('GroundTruth: ' , ' ' .join('%5s' % classes[labels[j]] for j in range(4 )))
接下来,让我们重新加载我们保存的模型(注意:保存和重新加载模型在这里不是必要的,我们只是为了说明如何这样做):
net = Net() net.load_state_dict(torch.load(PATH))
现在我们来看看神经网络认为以上图片是什么?
outputs = net(images)
输出是10个标签的概率。一个类别的概率越大,神经网络越认为他是这个类别。所以让我们得到最高概率的标签。
_, predicted = torch.max(outputs, 1 ) print('Predicted: ' , ' ' .join('%5s' % classes[predicted[j]] for j in range(4 )))
这结果看起来非常的好。
接下来让我们看看网络在整个测试集上的结果如何。
correct = 0