1

我需要创建一个 DataLoader,其中 collat​​or 函数需要进行非平凡的计算,实际上是一个双层循环,它显着减慢了训练过程。例如,考虑这个玩具代码,我尝试使用 numba 对 collat​​e 函数进行 JIT:

import torch
import torch.utils.data

import numba as nb


class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.A = np.zeros((100000, 300))
        self.B = np.ones((100000, 300))
    
    def __getitem__(self, index):
        return self.A[index], self.B[index]
    
    def __len__(self):
        return self.A.shape[0]

@nb.njit(cache=True)
def _collate_fn(batch):
    batch_data = np.zeros((len(batch), 300))
    for i in range(len(batch)):
        batch_data[i] = batch[i][0] + batch[i][1]

    return batch_data

然后我按如下方式创建 DataLoader:

train_dataset = Dataset()

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    num_workers=6,
    collate_fn=_collate_fn,
    shuffle=True)

然而,这只是卡住了,但如果我删除_collate_fn. 我无法理解这里发生了什么。我不必坚持 numba 并且可以使用任何可以帮助我克服 Python 中循环效率低下的东西。TIA 和快乐 12,021。

4

0 回答 0