PyTorch Learning Notes [2]

Full project source code and dataset: http://netdisk.scutvk.cn/pytorch_lesson_1.zip

Using DataLoader

First, let's load a dataset:

import torchvision
from torch.utils.data import DataLoader

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

This raises the following error:

Figure 1

It complains that the dataset still needs to be downloaded, so we add the parameter download=True:

import torchvision
from torch.utils.data import DataLoader


test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)

After that, another error appears:

Figure 2

It is an SSL certificate issue: the download source uses HTTPS and certificate verification fails, so we disable the verification.

Add the following two lines to the source code:

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

Then the program runs normally.

Now let's inspect the data in the variable test_data:

import torchvision
from torch.utils.data import DataLoader

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

# extract a single image
img, target = test_data[0]
print(img.shape)
print(target)

Output:

Figure 3

After adding a DataLoader:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# import ssl
# ssl._create_default_https_context = ssl._create_unverified_context

# test_data = torchvision.datasets.CIFAR10("./dataset", train=False, download=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=48, shuffle=False, num_workers=0, drop_last=False)

# extract a single image
img, target = test_data[0]
print(img.shape)
print(target)

print(len(test_data))

# use SummaryWriter & DataLoader to extract and display images from the dataset
writer = SummaryWriter("dataloader-logs")

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs)
        # print(targets)
        writer.add_images("shuffle_False: {}".format(epoch), imgs, step)
        step += 1

writer.close()

The DataLoader constructor, as described in the official documentation:

Figure 4: DataLoader in the official PyTorch documentation

batch_size: the number of samples the DataLoader draws from the dataset per batch

shuffle: whether the sample order is re-randomized each time the DataLoader is iterated (with shuffle=False, every epoch yields the data in the same order)

drop_last: when the dataset size is not divisible by batch_size, whether to keep the remainder (if kept, the final batch holds fewer than batch_size samples)
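The interplay of batch_size and drop_last can be seen on a toy dataset (a sketch; the ten-sample dataset here is made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# A toy dataset of 10 samples, each a single float
data = TensorDataset(torch.arange(10).float().unsqueeze(1))

# drop_last=False: 10 samples / batch_size 4 -> 3 batches, the last holding only 2
loader_keep = DataLoader(data, batch_size=4, drop_last=False)
sizes_keep = [batch[0].shape[0] for batch in loader_keep]
print(sizes_keep)  # [4, 4, 2]

# drop_last=True: the incomplete final batch is discarded
loader_drop = DataLoader(data, batch_size=4, drop_last=True)
sizes_drop = [batch[0].shape[0] for batch in loader_drop]
print(sizes_drop)  # [4, 4]
```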

Key points of the program flow:

At for data in test_loader:, each iteration takes one batch from test_loader into data, and that batch contains 48 images together with their target labels.

With imgs, targets = data we unpack those 48 images into imgs, which are then written to TensorBoard via add_images().
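The batching described above stacks the per-sample tensors along a new leading dimension. A sketch with random stand-in data (shapes match CIFAR-10, but the images and labels are fabricated for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake 3x32x32 images with integer labels in [0, 10)
images = torch.rand(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=48)

# One batch: 48 images stacked into a single 4-D tensor, plus 48 labels
imgs, targets = next(iter(loader))
print(imgs.shape)     # torch.Size([48, 3, 32, 32])
print(targets.shape)  # torch.Size([48])
```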

The outer for loop is there to observe the effect of shuffle:

shuffle=True

Figure 5

shuffle=False

Figure 6
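The difference between Figures 5 and 6 can be reproduced without TensorBoard on a toy dataset (a sketch; the eight-sample dataset is illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))

# shuffle=False: every epoch yields the batches in the same order
loader = DataLoader(dataset, batch_size=4, shuffle=False)
epoch1 = [batch[0].tolist() for batch in loader]
epoch2 = [batch[0].tolist() for batch in loader]
print(epoch1 == epoch2)  # True

# shuffle=True: the order is re-randomized at each epoch,
# but every epoch still visits all samples exactly once
loader_shuffled = DataLoader(dataset, batch_size=4, shuffle=True)
shuffled_epoch = [batch[0].tolist() for batch in loader_shuffled]
print(sorted(sum(shuffled_epoch, [])))  # [0, 1, 2, 3, 4, 5, 6, 7]
```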
