Pytorch 划分数据集torch.utils.data.dataset.Dataset

关键字:划分数据集、torch.utils.data.dataset.Dataset、torch.utils.data.random_split

实现效果:将Dataset类按指定比例划分为训练集、验证集(和测试集)

为实现此目的,我们使用到pytorch的函数torch.utils.data.random_split()


函数尝试:

random_split(dataset, lengths, generator)

输入参数:

dataset,类型:torch.utils.data.dataset.Dataset,欲划分的数据集

lengths,类型:sequence,例如[7, 3],注意意思不是把数据集按7:3比例划分,而是划分为大小分别为7和3的数据集。

generator,类型:torch.Generator,随机数生成器,带默认参数generator = default_generator

返回参数:

torch.utils.data.dataset.Subset的数组

官方例子:

random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

返回的是2个dataset.Subset

笔者例子:

# load dataset
total_dataset = mrds.TxtDataset(neg_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.neg", pos_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.pos", word2index=glove_vocab_dic) # 返回了一个Dataset
train_length = int(total_dataset.__len__() * 0.7) #  获取训练集长度
val_length = total_dataset.__len__() - train_length #  获取验证集长度
train_dataset, val_dataset = random_split(dataset=total_dataset, lengths=[train_length, val_length])

观察train_dataset、val_dataset两个Subset的成员,发现内由dataset成员,这便是我们要的、分割后的训练集。

加上两端源码

train_dataset = train_dataset.dataset
val_dataset = val_dataset.dataset

大功告成。


总结:

直接搬砖可用、总源码:

# load dataset
total_dataset = mrds.TxtDataset(neg_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.neg", pos_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.pos", word2index=glove_vocab_dic) # 返回了一个Dataset
train_length = int(total_dataset.__len__() * 0.7) #  获取训练集长度
val_length = total_dataset.__len__() - train_length #  获取验证集长度
train_dataset, val_dataset = random_split(dataset=total_dataset, lengths=[train_length, val_length])
train_dataset = train_dataset.dataset
val_dataset = val_dataset.dataset

发表回复