1

我在下面有以下数据加载器:

def load_dataset(size_batch, size):
    data_path = "/home/bledc/dataset/test_set/crops_BSD"
    transformations = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
        ])

    train_dataset = datasets.ImageFolder(
        root=data_path,
        transform=transformations
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=size_batch,
        shuffle=True,
        num_workers=0,
        drop_last=True
    )
    return train_loader

我在训练循环中使用以下内容迭代它:

data_loader = load_dataset(batch_size, width)
for data in data_loader:
        model.zero_grad()
        optimizer.zero_grad()
        img, _ = data
        img = img.to(device)

有人可以向我解释将 load_dataset() 函数写入类有什么好处吗?原因是我一直在使用上面的模板从我在网上找到的代码中加载数据,但似乎大多数代码库使用class LoaderName(Dataset)后跟定义初始条件和 super()。

谢谢你。

4

1 回答 1

0

用@Berriel 的话来说:

在你的情况下,没有。由于 datasets.ImageFolder 适合您的数据集,因此您无需实现任何东西。在您的情况下不需要自定义实现。当需要自定义行为时,通常会实现一个类。

于 2021-08-19T20:03:09.130 回答