我最近开始使用 PyTorch,我喜欢它的面向对象风格。但是,我想知道在预测模型时最好和建议的工作流程是什么。我想使用我编写的自定义数据集类,用于训练和验证我的模型。这个类是一个地图风格的数据集,因此我实现__getitem__
了返回图像和目标的方法:
class CustomDataset:
def __init__(self, ...):
...
def __getitem__(self, image_id):
....
return (
torch.tensor(image, dtype=torch.float),
torch.tensor(target, dtype=torch.long),
)
但是,当我使用此类进行预测时,我没有任何返回目标。我目前的解决方法是
def __getitem__(self, image_id):
....
if predict:
return (
torch.tensor(image, dtype=torch.float),
np.nan,
)
else:
return (
torch.tensor(image, dtype=torch.float),
torch.tensor(target, dtype=torch.long),
)
但是,我想知道是否有更好的方法来做到这一点。同时,由于感觉有点不自然,我开始想知道使用同一个类进行训练和预测是否是可取的(应该是,但我的解决方案的笨拙让我想知道)。当然,我根本无法返回元组,只能返回第一个元素,但这仍然需要 if-else。