
I used SubsetRandomSampler to split my training data into a training set (80%) and a validation set (20%). But both loaders report the same number of images (4996) after the split:

>>> print('len(train_data): ', len(train_loader.dataset))
>>> print('len(valid_data): ', len(validation_loader.dataset))
len(train_data):  4996
len(valid_data):  4996

Full code:

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])
dataset = datasets.ImageFolder('/data/images/train', transform=train_transforms)

validation_split = .2
shuffle_dataset = True
random_seed= 42
batch_size = 20

dataset_size = len(dataset) #4996
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)

1 Answer


train_loader.dataset and validation_loader.dataset are attributes that return the underlying original dataset the loaders sample from (i.e., the full dataset of size 4996), not the subset each loader draws.

However, if you iterate over the loaders themselves, you will see that each one yields only as many samples as you put into its sampler's index list (accounting for batching). You can also check len(loader.sampler) directly to see the split sizes.
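A minimal sketch illustrating the point, using a small synthetic TensorDataset of 100 samples in place of the ImageFolder from the question (the sampler logic is otherwise the same as the question's code):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Stand-in dataset: 100 samples, one feature each.
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

# Same 80/20 split logic as in the question.
indices = list(range(len(dataset)))
np.random.seed(42)
np.random.shuffle(indices)
split = int(np.floor(0.2 * len(dataset)))
train_indices, val_indices = indices[split:], indices[:split]

train_loader = DataLoader(dataset, batch_size=20,
                          sampler=SubsetRandomSampler(train_indices))
valid_loader = DataLoader(dataset, batch_size=20,
                          sampler=SubsetRandomSampler(val_indices))

# .dataset is the full underlying dataset for BOTH loaders:
print(len(train_loader.dataset), len(valid_loader.dataset))   # 100 100

# The samplers reflect the actual split:
print(len(train_loader.sampler), len(valid_loader.sampler))   # 80 20

# Counting samples while iterating confirms it:
n_train = sum(x.size(0) for (x,) in train_loader)
n_valid = sum(x.size(0) for (x,) in valid_loader)
print(n_train, n_valid)                                       # 80 20
```

So len(loader.dataset) will always report the full dataset size; use len(loader.sampler) (or iterate) to see the per-split counts.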

Answered 2021-03-20T09:05:20.407