介绍
这个有点开放,但让我们试试,如果我在某个地方错了,请纠正我。
到目前为止,我首先将数据导出到文件系统,子文件夹被命名为文档的类。
IMO 这是不明智的,因为:
- 你本质上是在复制数据
- 任何时候你想训练一个新的给定的代码和数据库,这个操作必须重复
- 您可以一次访问多个数据点并将它们缓存在 RAM 中以供以后重用,而无需从硬盘驱动器多次读取(这很重)
我对吗?直接连接到 MongoDB 有意义吗?
鉴于上述情况,可能是的(尤其是在清晰和可移植的实现方面)
还是有理由不这样做(例如,数据库通常会变慢等)?
在这种情况下,AFAIK DB 不应该变慢,因为它会缓存对它的访问,但不幸的是我不是数据库专家。许多加快访问速度的技巧都是开箱即用的数据库实现的。
可以以某种方式预取数据吗?
是的,如果您只想获取数据,您可以一次加载大部分数据(例如1024
记录)并从中返回批量数据(例如batch_size=128
)
执行
如何实现 PyTorch DataLoader?我在网上只找到了很少的代码片段([1] 和 [2]),这让我怀疑我的方法。
我不确定你为什么要这样做。您应该采取的措施torch.utils.data.Dataset
如您列出的示例所示。
我将从与此处类似的简单非优化方法开始,因此:
- 打开与 db 的连接并在
__init__
使用时保留它(我将创建一个上下文管理器,torch.utils.data.Dataset
以便在 epoch 完成后关闭连接)
- 我不会将结果转换为
list
(特别是因为明显的原因你不能将它放入 RAM 中),因为它错过了生成器的要点
- 我会在这个数据集中执行批处理(这里有一个参数
batch_size
)。
- 我不确定
__getitem__
函数,但它似乎可以一次返回多个数据点,因此我会使用它并且它应该允许我们使用num_workers>0
(假设mycol.find(query)
每次都以相同的顺序返回数据)
鉴于此,我会做一些类似的事情:
class DatabaseDataset(torch.utils.data.Dataset):
def __init__(self, query, batch_size, path: str, database: str):
self.batch_size = batch_size
client = pymongo.MongoClient(path)
self.db = client[database]
self.query = query
# Or non-approximate method, if the approximate method
# returns smaller number of items you should be fine
self.length = self.db.estimated_document_count()
self.cursor = None
def __enter__(self):
# Ensure that this find returns the same order of query every time
# If not, you might get duplicated data
# It is rather unlikely (depending on batch size), shouldn't be a problem
# for 20 million samples anyway
self.cursor = self.db.find(self.query)
return self
def shuffle(self):
# Find a way to shuffle data so it is returned in different order
# If that happens out of the box you might be fine without it actually
pass
def __exit__(self, *_, **__):
# Or anything else how to close the connection
self.cursor.close()
def __len__(self):
return len(self.examples)
def __getitem__(self, index):
# Read takes long, hence if you can load a batch of documents it should speed things up
examples = self.cursor[index * batch_size : (index + 1) * batch_size]
# Do something with this data
...
# Return the whole batch
return data, labels
现在批处理由处理DatabaseDataset
,因此torch.utils.data.DataLoader
可以有batch_size=1
。您可能需要挤压额外的维度。
使用锁(这MongoDB
并不奇怪,但请参见此处)num_workers>0
应该不是问题。
可能的用法(示意图):
with DatabaseDataset(...) as e:
dataloader = torch.utils.data.DataLoader(e, batch_size=1)
for epoch in epochs:
for batch in dataloader:
# And all the stuff
...
dataset.shuffle() # after each epoch
记住在这种情况下改组实现!(也可以在上下文管理器中完成洗牌,您可能想要手动关闭连接或类似的东西)。