0


在我的数据集中,我torchvision.dataset以前制作的每个类有 6 个类和 23 张图片,ImageFolder并且效果很好。

dataset = vision_dataset.ImageFolder(root = DATA_ROOT,
                                     transform = vision_trans.Compose([
                                                    vision_trans.Resize(256),
                                                    vision_trans.CenterCrop(256),
                                                    vision_trans.ToTensor()
                                     ]))

dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,
                                         shuffle = False, num_workers = 2, )

但我想获得具有相同类别的批量图像。

...
tensor([2, 2, 2, 2, 2])
tensor([2, 2])
tensor([3, 3, 3, 3, 3])
...

这就是我想要的标签(批处理数据的类)形式
,但实际上 DataLoader 会这样工作

...
tensor([2, 2, 2, 2, 2])
tensor([2, 2, 3, 3, 3])
tensor([3, 3, 3, 3, 3])
...

如何获取每个标签的批次数据?

4

1 回答 1

0

您不能方便地使用ImageFolder. 您应该为每个类创建一个数据集,并从您需要的数据集中加载您的批次。

更具体地说,假设您的文件夹结构是ImageFolder所需的,您需要创建一个小型数据集类:

class ImageSubFolder(torch.utils.data.Dataset):
    def __init__(self, root_dir, label):
        # Path toward the label-sorted subfolders of your dataset
        # Assuming images are named smthg like /path/to/label/xxxx.npy
        self._path = root_dir + label+ "{:04d}"

    def __len__(self):
        return count_files_in_directory(self._path)

    def __getitem__(self, index):
        return (np.load(self._path.format(index), label)

这只是为了展示类的逻辑,我相信你仍然需要实现一些功能(你可以按照本教程进行操作)。“其余要实现的功能留给读者作为练习”。无论如何,对于此类,您只需要创建 6 个实例(每个类一个):

loaders = {}
for label in ("dog", "cat", "plane", "tree", "mug", "car"):
    dataset = SubFolderDataset(DATA_ROOT, label)
    loaders[label] = torch.utils.data.DataLoader(dataset = dataset, batch_size = SHOT_K,shuffle = False, num_workers = 2, )

现在你有一个包含数据加载器的字典,它只加载给定类的样本。

于 2021-02-04T10:08:23.733 回答