在我的数据集中,我torchvision.dataset
以前制作的每个类有 6 个类和 23 张图片,ImageFolder
并且效果很好。
dataset = vision_dataset.ImageFolder(root = DATA_ROOT,
transform = vision_trans.Compose([
vision_trans.Resize(256),
vision_trans.CenterCrop(256),
vision_trans.ToTensor()
]))
dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,
shuffle = False, num_workers = 2, )
但我想获得具有相同类别的批量图像。
...
tensor([2, 2, 2, 2, 2])
tensor([2, 2])
tensor([3, 3, 3, 3, 3])
...
这就是我想要的标签(批处理数据的类)形式
,但实际上 DataLoader 会这样工作
...
tensor([2, 2, 2, 2, 2])
tensor([2, 2, 3, 3, 3])
tensor([3, 3, 3, 3, 3])
...
如何获取每个标签的批次数据?