2

我想知道使用连接到 MongoDB 的 DataLoader 是否明智,以及如何实现。

背景

我在(本地)MongoDB 中有大约 2000 万个文档。比记忆中的文件多得多。我想在数据上训练一个深度神经网络。到目前为止,我首先将数据导出到文件系统,子文件夹被命名为文档的类。但我觉得这种方法很荒谬。如果数据已经保存在数据库中,为什么要先导出(然后再删除)。

问题一:

我对吗?直接连接到 MongoDB 有意义吗?还是有理由不这样做(例如,数据库通常太慢等)?如果数据库太慢(为什么?),可以以某种方式预取数据吗?

问题2:

如何实现 PyTorch DataLoader?我在网上只找到了很少的代码片段([1][2]),这让我怀疑我的方法。

代码片段

我访问 MongoDB 的一般方式如下。这没什么特别的,我想。

import pymongo
from pymongo import MongoClient

myclient = pymongo.MongoClient("mongodb://localhost:27017/")
mydb = myclient["xyz"]
mycol = mydb["xyz_documents"]

query = {
    # some filters
}

results = mycol.find(query)

# results is now a cursor that can run through all docs
# Assume, for the sake of this example, that each doc contains a class name and some image that I want to train a classifier on
4

1 回答 1

2

介绍

这个有点开放,但让我们试试,如果我在某个地方错了,请纠正我。

到目前为止,我首先将数据导出到文件系统,子文件夹被命名为文档的类。

IMO 这是不明智的,因为:

  • 你本质上是在复制数据
  • 任何时候你想训练一个新的给定的代码和数据库,这个操作必须重复
  • 您可以一次访问多个数据点并将它们缓存在 RAM 中以供以后重用,而无需从硬盘驱动器多次读取(这很重)

我对吗?直接连接到 MongoDB 有意义吗?

鉴于上述情况,可能是的(尤其是在清晰和可移植的实现方面)

还是有理由不这样做(例如,数据库通常会变慢等)?

在这种情况下,AFAIK DB 不应该变慢,因为它会缓存对它的访问,但不幸的是我不是数据库专家。许多加快访问速度的技巧都是开箱即用的数据库实现的。

可以以某种方式预取数据吗?

是的,如果您只想获取数据,您可以一次加载大部分数据(例如1024记录)并从中返回批量数据(例如batch_size=128

执行

如何实现 PyTorch DataLoader?我在网上只找到了很少的代码片段([1] 和 [2]),这让我怀疑我的方法。

我不确定你为什么要这样做。您应该采取的措施torch.utils.data.Dataset如您列出的示例所示。

我将从与此处类似的简单非优化方法开始,因此:

  • 打开与 db 的连接并在__init__使用时保留它(我将创建一个上下文管理器,torch.utils.data.Dataset以便在 epoch 完成后关闭连接)
  • 我不会将结果转换为list(特别是因为明显的原因你不能将它放入 RAM 中),因为它错过了生成器的要点
  • 我会在这个数据集中执行批处理(这里有一个参数batch_size
  • 我不确定__getitem__函数,但它似乎可以一次返回多个数据点,因此我会使用它并且它应该允许我们使用num_workers>0(假设mycol.find(query)每次都以相同的顺序返回数据)

鉴于此,我会做一些类似的事情:

class DatabaseDataset(torch.utils.data.Dataset):
    def __init__(self, query, batch_size, path: str, database: str):
        self.batch_size = batch_size

        client = pymongo.MongoClient(path)
        self.db = client[database]
        self.query = query
        # Or non-approximate method, if the approximate method
        # returns smaller number of items you should be fine
        self.length = self.db.estimated_document_count()

        self.cursor = None

    def __enter__(self):
        # Ensure that this find returns the same order of query every time
        # If not, you might get duplicated data
        # It is rather unlikely (depending on batch size), shouldn't be a problem
        # for 20 million samples anyway
        self.cursor = self.db.find(self.query)
        return self

    def shuffle(self):
        # Find a way to shuffle data so it is returned in different order
        # If that happens out of the box you might be fine without it actually
        pass

    def __exit__(self, *_, **__):
        # Or anything else how to close the connection
        self.cursor.close()

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Read takes long, hence if you can load a batch of documents it should speed things up
        examples = self.cursor[index * batch_size : (index + 1) * batch_size]
        # Do something with this data
        ...
        # Return the whole batch
        return data, labels

现在批处理由处理DatabaseDataset,因此torch.utils.data.DataLoader可以有batch_size=1。您可能需要挤压额外的维度。

使用锁(这MongoDB并不奇怪,但请参见此处num_workers>0应该不是问题。

可能的用法(示意图):

with DatabaseDataset(...) as e:
    dataloader = torch.utils.data.DataLoader(e, batch_size=1)
    for epoch in epochs:
        for batch in dataloader:
            # And all the stuff
            ...
        dataset.shuffle() # after each epoch

记住在这种情况下改组实现!(也可以在上下文管理器中完成洗牌,您可能想要手动关闭连接或类似的东西)。

于 2021-01-13T21:50:15.883 回答