Pytorch 划分数据集torch.utils.data.dataset.Dataset
关键字:划分数据集、torch.utils.data.dataset.Dataset、torch.utils.data.random_split
实现效果:将Dataset类按指定比例划分为训练集、验证集(和测试集)
为实现此目的,我们使用到pytorch的函数torch.utils.data.random_split()
函数尝试:
random_split(dataset, lengths, generator)
输入参数:
dataset,类型:torch.utils.data.dataset.Dataset,欲划分的数据集
lengths,类型:sequence,例如[7, 3],注意意思不是把数据集按7:3比例划分,而是划分为大小分别为7和3的数据集。
generator,类型:torch.Generator,随机数生成器,带默认参数generator = default_generator
返回参数:
torch.utils.data.dataset.Subset的数组
官方例子:
random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

返回的是2个dataset.Subset
笔者例子:
# load dataset
total_dataset = mrds.TxtDataset(neg_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.neg", pos_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.pos", word2index=glove_vocab_dic) # 返回了一个Dataset
train_length = int(total_dataset.__len__() * 0.7) # 获取训练集长度
val_length = total_dataset.__len__() - train_length # 获取验证集长度
train_dataset, val_dataset = random_split(dataset=total_dataset, lengths=[train_length, val_length])

观察train_dataset、val_dataset两个Subset的成员,发现内由dataset成员,这便是我们要的、分割后的训练集。

加上两端源码
train_dataset = train_dataset.dataset
val_dataset = val_dataset.dataset
大功告成。
总结:
直接搬砖可用、总源码:
# load dataset
total_dataset = mrds.TxtDataset(neg_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.neg", pos_loc="../raw_data/rt-polaritydata/rt-polaritydata/rt-polarity.pos", word2index=glove_vocab_dic) # 返回了一个Dataset
train_length = int(total_dataset.__len__() * 0.7) # 获取训练集长度
val_length = total_dataset.__len__() - train_length # 获取验证集长度
train_dataset, val_dataset = random_split(dataset=total_dataset, lengths=[train_length, val_length])
train_dataset = train_dataset.dataset
val_dataset = val_dataset.dataset