我需要创建一个 DataLoader,其中 collator 函数需要进行非平凡的计算,实际上是一个双层循环,它显着减慢了训练过程。例如,考虑这个玩具代码,我尝试使用 numba 对 collate 函数进行 JIT:
import torch
import torch.utils.data
import numba as nb
class Dataset(torch.utils.data.Dataset):
def __init__(self):
self.A = np.zeros((100000, 300))
self.B = np.ones((100000, 300))
def __getitem__(self, index):
return self.A[index], self.B[index]
def __len__(self):
return self.A.shape[0]
@nb.njit(cache=True)
def _collate_fn(batch):
batch_data = np.zeros((len(batch), 300))
for i in range(len(batch)):
batch_data[i] = batch[i][0] + batch[i][1]
return batch_data
然后我按如下方式创建 DataLoader:
train_dataset = Dataset()
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=256,
num_workers=6,
collate_fn=_collate_fn,
shuffle=True)
然而,这只是卡住了,但如果我删除_collate_fn
. 我无法理解这里发生了什么。我不必坚持 numba 并且可以使用任何可以帮助我克服 Python 中循环效率低下的东西。TIA 和快乐 12,021。