3

如果我们使用DatasetDataloader类的组合(如下所示),我必须使用或将数据显式加载到GPU上。有没有办法指示数据加载器自动/隐式地执行它?.to().cuda()

理解/重现场景的代码:

from torch.utils.data import Dataset, DataLoader
import numpy as np

class DemoData(Dataset):
    def __init__(self, limit):
        super(DemoData, self).__init__()
        self.data = np.arange(limit)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return (self.data[idx], self.data[idx]*100)

demo = DemoData(100)

loader = DataLoader(demo, batch_size=50, shuffle=True)

for i, (i1, i2) in enumerate(loader):
    print('Batch Index: {}'.format(i))
    print('Shape of data item 1: {}; shape of data item 2: {}'.format(i1.shape, i2.shape))
    # i1, i2 = i1.to('cuda:0'), i2.to('cuda:0')
    print('Device of data item 1: {}; device of data item 2: {}\n'.format(i1.device, i2.device))

这将输出以下内容;注意 - 没有明确的设备传输指令,数据被加载到CPU上:

Batch Index: 0
Shape of data item 1: torch.Size([50]); shape of data item 2: torch.Size([50])
Device of data item 1: cpu; device of data item 2: cpu

Batch Index: 1
Shape of data item 1: torch.Size([50]); shape of data item 2: torch.Size([50])
Device of data item 1: cpu; device of data item 2: cpu

一个可能的解决方案是在这个 PyTorch GitHub 存储库中。问题在发布此问题时仍处于打开状态),但是当数据加载器必须返回多个数据项时,我无法使其工作!

4

1 回答 1

2

您可以修改collate_fn以一次处理多个项目:

from torch.utils.data.dataloader import default_collate

device = torch.device('cuda:0')  # or whatever device/cpu you like

# the new collate function is quite generic
loader = DataLoader(demo, batch_size=50, shuffle=True, 
                    collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))

请注意,如果您希望数据加载器有多个工作人员,则需要添加

torch.multiprocessing.set_start_method('spawn')

在你之后if __name__ == '__main__'(见这个问题)。

话虽如此,似乎使用pin_memory=Truein yourDataLoader会更有效率。你试过这个选项吗?
有关详细信息,请参阅内存固定


更新(2021 年 2 月 8 日)
这篇文章让我了解了我在训练期间花费的“数据到模型”时间。我比较了三种选择:

  1. DataLoader在 CPU 上工作,并且只有在检索到批次后,数据才会移动到 GPU。
  2. 与 (1) 相同,但带有pin_memory=Truein DataLoader
  3. 提出的collate_fn用于将数据移动到 GPU 的方法。

从我有限的实验来看,第二个选项似乎表现最好(但不是很大)。
第三个选项需要对数据加载器进程大惊小怪start_method,而且似乎在每个 epoch 开始时都会产生开销。

于 2021-02-02T15:17:52.753 回答