这是我在 PyTorch 中的代码片段,当我使用 num_workers > 0 时,我的 jupiter notebook 卡住了,我在这个问题上花了很多时间却没有任何答案。我没有 GPU,我只使用 CPU。
class IndexedDataset(Dataset):
def __init__(self,data,targets, test=False):
self.dataset = data
if not test:
self.labels = targets.numpy()
self.mask = np.concatenate((np.zeros(NUM_LABELED), np.ones(NUM_UNLABELED)))
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
image = self.dataset[idx]
return image, self.labels[idx]
def display(self, idx):
plt.imshow(self.dataset[idx], cmap='gray')
plt.show()
train_set = IndexedDataset(train_data, train_target, test = False)
test_set = IndexedDataset(test_data, test_target, test = True)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=2)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=2)
任何帮助,不胜感激。