2

我的代码是:

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=2, shuffle=True, num_workers=4,
collate_fn=utils.collate_fn)
# For Training
images,targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images,targets)   # Returns losses and detections
# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)           # Returns predictions

我得到错误:

“collat​​e_fn=utils.collat​​e_fn”显示错误“名称 'utils' 未定义”。“添加火炬后,模块'torch.utils'没有属性'collat​​e_fn'错误。

4

1 回答 1

1

好的,所以我阅读了教程,似乎它希望您使用此存储库中的帮助文件:https ://github.com/pytorch/vision/tree/master/references/detection 。

其中utils.py包含 collat​​e_fn 函数。因此,您似乎没有下载/复制此存储库以将其集成到您的项目中,对吧?

要解决这个错误,您可以复制 utils.py 中的 collat​​e_fn

def collate_fn(batch):
    return tuple(zip(*batch))

并将其粘贴到您的项目中。但是由于本教程可能还希望您使用 utils.py 的其他 util 功能,因此您可能需要下载此目录并将其放入您的项目目录中,以便您可以访问它。

于 2021-05-14T11:28:27.573 回答