本节课主要介绍CIFAR10数据集
登录http://www.cs.toronto.edu/~kriz/cifar.html网站,可以自行下载数据集。
打开页面后
前讲的MNIST数据集为0~9的数字识别,而这里的为10类物品识别。由上可见物品包含有飞机、汽车、鸟、猫等。照片大小为32*32的彩色图片。每一个类别大概有6000张照片,其中随机筛选出5000作为学习,余下的1000用于测试。
首先在pycharm软件中新建文件夹,并创建main.py文件。
首先引入一部分相关的工具包
代码语言:javascript
复制
import torch
from torchvision import datasets
# 引入pytorch、datasets工具包
定义main函数
代码语言:javascript
复制
def main():
if name == 'main':
main()
下面开始在里面写入代码
首先开始加载数据集
代码语言:javascript
复制
def main():
cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([ transforms.Resize((32, 32)),</code></pre></div></div><p>继续</p><p>在前面引入工具包处加入代码</p><div class="rno-markdown-code"><div class="rno-markdown-code-toolbar"><div class="rno-markdown-code-toolbar-info"><div class="rno-markdown-code-toolbar-item is-type"><span class="is-m-hidden">代码语言:</span>javascript</div></div><div class="rno-markdown-code-toolbar-opt"><div class="rno-markdown-code-toolbar-copy"><i class="icon-copy"></i><span class="is-m-hidden">复制</span></div></div></div><div class="developer-code-block"><pre class="prism-token token line-numbers language-javascript"><code class="language-javascript" style="margin-left:0">from torchvision import transforms
引入数据变换工具包
继续定义数据集代码
代码语言:javascript
复制
def main():
cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([ transforms.Resize((32, 32)), # .Compose相当于一个数据转换的集合 # 进行数据转换,首先将图片统一为32*32 transforms.ToTensor() # 将数据转化到Tensor中 ])) # 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中</code></pre></div></div><p>这里暂时不写Normalize函数</p><p>写到这里别忘了让pytorch自己下载数据集</p><p>在代码后面加入download=True即可实现</p><div class="rno-markdown-code"><div class="rno-markdown-code-toolbar"><div class="rno-markdown-code-toolbar-info"><div class="rno-markdown-code-toolbar-item is-type"><span class="is-m-hidden">代码语言:</span>javascript</div></div><div class="rno-markdown-code-toolbar-opt"><div class="rno-markdown-code-toolbar-copy"><i class="icon-copy"></i><span class="is-m-hidden">复制</span></div></div></div><div class="developer-code-block"><pre class="prism-token token line-numbers language-javascript"><code class="language-javascript" style="margin-left:0">]), download=True)</code></pre></div></div><p>Cifar_train 的代码部分已经写好</p><p>写到这里要注意这里只是建立了一次加载一张的代码</p><p>若想一次性加载一批,则要利用其多线程的特性</p><p>继续在引入工具包部分加入相关工具包</p><div class="rno-markdown-code"><div class="rno-markdown-code-toolbar"><div class="rno-markdown-code-toolbar-info"><div class="rno-markdown-code-toolbar-item is-type"><span class="is-m-hidden">代码语言:</span>javascript</div></div><div class="rno-markdown-code-toolbar-opt"><div class="rno-markdown-code-toolbar-copy"><i class="icon-copy"></i><span class="is-m-hidden">复制</span></div></div></div><div class="developer-code-block"><pre class="prism-token token line-numbers language-javascript"><code class="language-javascript" style="margin-left:0">from torch.utils.data import DataLoader
多线程数据读取
继续书写数据读取部分代码
按照其提示,写入相关参数
cifar_train = DataLoader(cifar_train, batch_size=batchsz, )