我正在尝试在 PyTorch 中对已存储到特定尺寸的图像训练深度学习模型。我想使用小批量训练我的模型,但小批量大小并不能很好地划分每个桶中的示例数量。
我在上一篇文章中看到的一个解决方案是用额外的空白填充图像(在训练开始时即时或一次全部填充),但我不想这样做。相反,我想在训练期间允许批量大小灵活。
具体来说,如果N
是存储桶中的图像数量并且B
是批量大小,那么对于该存储桶,我希望得到N // B
批次(如果是B
分N
),N // B + 1
否则是批次。最后一批可以有少于B
示例。
例如,假设我有索引 [0, 1, ..., 19],包括在内,我想使用 3 的批量大小。
索引 [0, 9] 对应存储桶 0 中的图像(形状 (C, W1, H1))
索引 [10, 19] 对应存储桶 1 中的图像(形状 (C, W2, H2))
(所有图像的通道深度相同)。那么可接受的索引分区将是
batches = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18],
[19]
]
我更愿意分别处理索引为 9 和 19 的图像,因为它们具有不同的尺寸。
通过查看 PyTorch 的文档,我找到了BatchSampler
生成小批量索引列表的类。我创建了一个Sampler
模拟上述索引分区的自定义类。如果有帮助,这是我的实现:
class CustomSampler(Sampler):
def __init__(self, dataset, batch_size):
self.batch_size = batch_size
self.buckets = self._get_buckets(dataset)
self.num_examples = len(dataset)
def __iter__(self):
batch = []
# Process buckets in random order
dims = random.sample(list(self.buckets), len(self.buckets))
for dim in dims:
# Process images in buckets in random order
bucket = self.buckets[dim]
bucket = random.sample(bucket, len(bucket))
for idx in bucket:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
# Yield half-full batch before moving to next bucket
if len(batch) > 0:
yield batch
batch = []
def __len__(self):
return self.num_examples
def _get_buckets(self, dataset):
buckets = defaultdict(list)
for i in range(len(dataset)):
img, _ = dataset[i]
dims = img.shape
buckets[dims].append(i)
return buckets
但是,当我使用自定义Sampler
类时,会生成以下错误:
Traceback (most recent call last):
File "sampler.py", line 143, in <module>
for i, batch in enumerate(dataloader):
File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 263, in __next__
indices = next(self.sample_iter) # may raise StopIteration
File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 139, in __iter__
batch.append(int(idx))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'
该类DataLoader
似乎期望传递索引,而不是索引列表。
我不应该Sampler
为此任务使用自定义类吗?我还考虑过自定义collate_fn
传递给DataLoader
,但使用这种方法我不相信我可以控制允许哪些索引在同一个小批量中。任何指导将不胜感激。