我在下面有以下数据加载器:
def load_dataset(size_batch, size):
data_path = "/home/bledc/dataset/test_set/crops_BSD"
transformations = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.ToTensor()
])
train_dataset = datasets.ImageFolder(
root=data_path,
transform=transformations
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=size_batch,
shuffle=True,
num_workers=0,
drop_last=True
)
return train_loader
我在训练循环中使用以下内容迭代它:
data_loader = load_dataset(batch_size, width)
for data in data_loader:
model.zero_grad()
optimizer.zero_grad()
img, _ = data
img = img.to(device)
有人可以向我解释将 load_dataset() 函数写入类有什么好处吗?原因是我一直在使用上面的模板从我在网上找到的代码中加载数据,但似乎大多数代码库使用class LoaderName(Dataset)
后跟定义初始条件和 super()。
谢谢你。