-1

如何在洗牌的同时保持每批中的序列不洗牌?

受到此处提出的问题的启发。

4

2 回答 2

1

虽然这不是您问题的直接答案。我想用你自己发布的答案来解决一个问题。在我看来,执行以下操作是一个非常糟糕的主意:

dataloader = random.sample(list(dataloader), len(dataloader))

这首先破坏了创建数据集和数据加载器的整个目的。因为一旦你打电话list(dataloader),你最终会将你的数据集编译成一个张量列表。换句话说,它将调用__getitem__数据集中的每个索引。数据加载器是设计人员逐批加载数据(或更多取决于工作人员的数量),避免一次将整个数据集加载到内存中。

在处理需要从文件系统加载图像的图像时,这一点更为重要。这很关键,我相信你根本不应该这样做。

看看这里,有一个虚拟数据集:

class DS(Dataset):
    def __getitem__(self, _):
        return torch.rand(100)

    def __len__(self):
        return 10000

dl = DataLoader(DS(), batch_size=16)
x = list(dl)

这里x将包含10,000个大小为100的张量,您的计算机可以处理这些张量。现在想象一下有一个由10,000 个 512x512 RGB 图像组成的数据集,你无法在内存中保存那么多!

此外,我什至没有提到的是数据增强。这只有在保留数据加载器(即生成器)时才有可能。因此,使用list(dataloader).


相反,我建议您为每个项目Dataset生成未打乱DataLoader的序列,然后使用shuffle=True. 这感觉比生成一个DataLoader只编译下来要自然得多。按预期使用您的数据集类。它应该是构建每个序列(即数据点)的那个,或者正如@Prune所说的那样,它是一个“单个观察对象”。

于 2021-01-03T21:28:15.030 回答
0

[错误答案 - 使用上面@Ivan 的答案]

  1. 创建数据集
dataset = [1, 2, 3, 4, 5, 6, 7, 8, 9] # Realistically use torch.utils.data.Dataset
  1. 创建一个非洗牌的数据加载器
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
  1. 将数据加载器强制转换为 a并list使用函数randomsample()
import random
dataloader = random.sample(list(dataloader), len(dataloader))

使用自定义批处理采样器或其他东西可能有更好的方法来做到这一点,但这对我来说太混乱了,所以上面的方法似乎效果很好。

于 2021-01-03T20:40:51.300 回答