4

我最近开始使用 PyTorch,我喜欢它的面向对象风格。但是,我想知道在预测模型时最好和建议的工作流程是什么。我想使用我编写的自定义数据集类,用于训练和验证我的模型。这个类是一个地图风格的数据集,因此我实现__getitem__了返回图像和目标的方法:

class CustomDataset:

    def __init__(self, ...):
        ...

    def __getitem__(self, image_id):
        ....
        return (
            torch.tensor(image, dtype=torch.float),
            torch.tensor(target, dtype=torch.long),
       )

但是,当我使用此类进行预测时,我没有任何返回目标。我目前的解决方法是

def __getitem__(self, image_id):
    ....
    if predict:
        return (
            torch.tensor(image, dtype=torch.float),
            np.nan,
       )
   else:
        return (
             torch.tensor(image, dtype=torch.float),
             torch.tensor(target, dtype=torch.long),
       )

但是,我想知道是否有更好的方法来做到这一点。同时,由于感觉有点不自然,我开始想知道使用同一个类进行训练和预测是否是可取的(应该是,但我的解决方案的笨拙让我想知道)。当然,我根本无法返回元组,只能返回第一个元素,但这仍然需要 if-else。

4

3 回答 3

2

PyTorch 的DataSet类非常简单。所以,不要想太多。它只不过是用于访问数据的包装器。

您不必返回元组,甚至不必返回张量。你可以返回任何你想要的数据。通常,它将采用以下样式之一:

  • 对于无监督数据:Sample(Sample, None)
  • 对于监督数据:(Sample, Label)
  • 对于具有多个目标的监督数据,例如对象检测:(Sample, [Label1, Label2, ...])(Sample, Label1, Label2, ...)

训练/测试使用相同的 DataSet 类也很常见。

(sample, None) 因此,在您的情况下,只需像在 torchvision 中所做的那样返回样本或元组并相应地调整您的管道。我不建议使用np.nan它,因为它会使简单的无检查 ( np.nan == None) 失败。另外,我鼓励你继承自torch.data.Dataset.

但是,如果您的管道迫使您使用元组或有其他限制,我建议您重新表述您的问题。

于 2021-05-20T14:51:46.357 回答
0

您必须编写代码来创建与您的数据和问题场景相匹配的数据集;没有两个 Dataset 实现是完全相同的。另一方面,DataLoader 对象的使用几乎相同,无论它与哪个 Dataset 对象相关联。例如:

class MyDataSet(T.utils.data.Dataset):
  # implement custom code to load data here

my_ds = MyDataset("my_train_data.txt")
my_ldr = torch.utils.data.DataLoader(my_ds, 10, True)
for (idx, batch) in enumerate(my_ldr):
  . . .
于 2021-05-23T07:15:04.127 回答
0

我认为如果您想要一个“纯”(= 否如果)解决方案,您可以定义一个“不关心”类,并让您的损失以某种方式忽略它(可能通过内部屏蔽完成,从技术上讲,它是一个“如果",但它是矢量化的)。

例如,请参阅CrossEntropyLoss必须ignore_index处理此类情况。这让我相信不关心类索引是设计的方式。

class CDiscountDataset:
    self.ignore_index = self._preprocess_number_of_classes() + 1

    def __getitem__(self, image_id):
        target_tensor = torch.tensor(target, dtype=torch.long)

        if predict:
            return (
                torch.tensor(image, dtype=torch.float),
                torch.ones_like(target_tensor, dtype=torch.long) * self.ignore_index,
            )
        else:
            return (
                torch.tensor(image, dtype=torch.float),
                target_tensor,
           )

作为旁注,我正在使用Pytorch-lightning,它为训练管道提供了很好的抽象,它隐含地假设只有张量元组从数据加载器返回,这让我相信返回其他类型是“不那么规范”,这加强了我的信念上述“不在乎”的方式是要走的路。

于 2021-05-23T08:10:24.240 回答