PyTorch Study Notes [2]
Full project source code and dataset: http://netdisk.scutvk.cn/pytorch_lesson_1.zip
Using DataLoader
First, we load a dataset:
import torchvision
from torch.utils.data import DataLoader
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
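This raises an error.
The error says the dataset has not been found and still needs to be downloaded, so we pass one more argument, download=True: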
import torchvision
from torch.utils.data import DataLoader
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
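Then another error appears.
This one is about SSL certificates: the download source is served over HTTPS, and certificate verification fails. As a workaround we turn off certificate verification (the download still goes over HTTPS, but the server's certificate is no longer checked, so only do this for a source you trust).
Add the following two lines to the script, before the dataset is created: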
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
After that, the script runs normally.
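The two lines above change the default HTTPS context for the whole process. A minimal sketch of a narrower variant, assuming you only want verification turned off while the dataset downloads: the helper name unverified_https is hypothetical (it is not part of ssl or torchvision), and it relies on the same private ssl attributes used above.

```python
import ssl
from contextlib import contextmanager


@contextmanager
def unverified_https():
    """Hypothetical helper: disable certificate verification only inside the with-block."""
    original = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        # Always restore the verifying context, even if the download fails
        ssl._create_default_https_context = original


before = ssl._create_default_https_context
with unverified_https():
    inside = ssl._create_default_https_context
    # The CIFAR10(..., download=True) call from the text would go here
after = ssl._create_default_https_context

print(inside is ssl._create_unverified_context, after is before)  # True True
```

Once the files are in ./dataset, later runs can omit download=True and need no SSL workaround at all.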
Next, let's look at what the variable test_data contains:
import torchvision
from torch.utils.data import DataLoader
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# Take only the first sample: one image and its label
img, target = test_data[0]
print(img.shape)
print(target)
Output:
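Indexing a dataset returns one (image, label) pair, which is why img, target = test_data[0] works. The sketch below illustrates that with a made-up dataset, so it runs without downloading anything. RandomImages is a hypothetical stand-in for CIFAR10, filled with random tensors of the same 3x32x32 shape; it is not how torchvision implements CIFAR10.

```python
import torch
from torch.utils.data import Dataset


class RandomImages(Dataset):
    """Hypothetical stand-in for CIFAR10: random 3x32x32 tensors with integer labels."""

    def __init__(self, num_samples=8, num_classes=10):
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 3, 32, 32, generator=generator)
        self.targets = torch.randint(
            0, num_classes, (num_samples,), generator=generator
        ).tolist()

    def __len__(self):
        # Number of samples in the dataset
        return len(self.targets)

    def __getitem__(self, index):
        # One sample is an (image tensor, integer label) pair
        return self.images[index], self.targets[index]


toy_data = RandomImages()
img, target = toy_data[0]
print(img.shape)      # torch.Size([3, 32, 32])
print(type(target))   # <class 'int'>
print(len(toy_data))  # 8
```

The shape reads as (channels, height, width), which is the layout ToTensor() produces for an image.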
After adding DataLoader:
import torchvision
from torch.utils.data import DataLoader
#import ssl
#ssl._create_default_https_context = ssl._create_unverified_context
#test_data = torchvision.datasets.CIFAR10("./dataset", train=False, download=True, transform=torchvision.transforms.ToTensor())
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=48, shuffle=False, num_workers=0, drop_last=False)
# Take only the first sample: one image and its label
img, target = test_data[0]
print(img.shape)
print(target)
print(len(test_data))
# Use SummaryWriter and DataLoader to read the images from the dataset and display them
writer = SummaryWriter("dataloader-logs")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs)
        # print(targets)
        writer.add_images("shuffle_False: {}".format(epoch), imgs, step)
        step += 1
# Flush pending events to disk and close the log file
writer.close()
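For the DataLoader constructor, see the official documentation. The parameters used here are:
batch_size: the number of samples DataLoader takes from the dataset for each batch
shuffle: whether the data is reshuffled on every pass. With shuffle=False the samples come out in the same order each time the DataLoader is iterated; with shuffle=True the order differs from pass to pass
drop_last: whether to drop the leftover samples when the dataset size is not divisible by batch_size. With drop_last=False they are kept, so the last batch holds fewer than batch_size samples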
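The effect of batch_size and drop_last can be checked on a small example. This is a minimal sketch that assumes a synthetic 10-sample TensorDataset in place of CIFAR10 (the names toy_data, keep_loader and drop_loader are made up for illustration), so it runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for a dataset: 10 samples with labels 0..9
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(10)
toy_data = TensorDataset(features, labels)

# 10 is not divisible by 4, so 2 samples are left over after two full batches
keep_loader = DataLoader(toy_data, batch_size=4, shuffle=False, drop_last=False)
drop_loader = DataLoader(toy_data, batch_size=4, shuffle=False, drop_last=True)

keep_sizes = [len(targets) for _, targets in keep_loader]
drop_sizes = [len(targets) for _, targets in drop_loader]

print(keep_sizes)  # [4, 4, 2]
print(drop_sizes)  # [4, 4]
```

With drop_last=False every sample is seen once per pass. With drop_last=True all batches have the same size, at the cost of skipping the leftover samples.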
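Key steps of the program:
At for data in test_loader:, each iteration takes one batch from test_loader and assigns it to data. That batch contains 48 images together with their target labels (the final batch holds only the remaining 16, because the 10000 test images are not divisible by 48 and drop_last=False).
When we unpack it with imgs, targets = data, the 48 images go into imgs, which we then add to TensorBoard through add_images().
The outer for loop is there to show the effect of shuffle: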
shuffle=True
shuffle=False
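The same comparison can be made without TensorBoard by recording the order in which labels come out of the loader on two consecutive passes. This is a minimal sketch that assumes a synthetic 100-sample dataset whose label equals the sample index; epoch_order is a hypothetical helper written for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for a dataset: 100 samples whose label equals their index
indices = torch.arange(100)
toy_data = TensorDataset(indices.float().unsqueeze(1), indices)


def epoch_order(loader):
    """Return the labels in the order the loader yields them during one pass."""
    return torch.cat([targets for _, targets in loader]).tolist()


ordered_loader = DataLoader(toy_data, batch_size=10, shuffle=False)
shuffled_loader = DataLoader(toy_data, batch_size=10, shuffle=True)

# Two passes over each loader, like the two epochs in the loop above
ordered_epochs = [epoch_order(ordered_loader) for _ in range(2)]
shuffled_epochs = [epoch_order(shuffled_loader) for _ in range(2)]

# shuffle=False: both passes return the samples in dataset order
print(ordered_epochs[0] == ordered_epochs[1])  # True

# shuffle=True: each pass is a permutation of the same 100 samples
print(sorted(shuffled_epochs[0]) == list(range(100)))  # True

# The two shuffled passes will almost never come out in the same order
print(shuffled_epochs[0] == shuffled_epochs[1])
```

In both cases every sample appears exactly once per pass. Only the order changes.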