你的问题是什么?
我正在尝试实现一个需要访问整个数据的指标。因此,我没有在 *_step() 方法中更新指标,而是尝试在 *_epoch_end() 方法中收集输出。但是,输出仅包含每个设备获取的数据分区的输出。基本上,如果有 n 个设备,那么每个设备将获得总输出的 1/n。
你的环境是什么?
OS: ubuntu
Packaging: conda
Version [1.0.4
Pytorch: 1.6.0
你的问题是什么?
我正在尝试实现一个需要访问整个数据的指标。因此,我没有在 *_step() 方法中更新指标,而是尝试在 *_epoch_end() 方法中收集输出。但是,输出仅包含每个设备获取的数据分区的输出。基本上,如果有 n 个设备,那么每个设备将获得总输出的 1/n。
你的环境是什么?
OS: ubuntu
Packaging: conda
Version [1.0.4
Pytorch: 1.6.0
使用 DDP 后端时,每个 GPU 都有一个单独的进程运行。它们无法访问彼此的数据,但有一些特殊操作(reduce、all_reduce、gather、all_gather)可以使进程同步。当您在张量上使用此类操作时,进程将等待彼此到达同一点并以某种方式组合它们的值,例如从每个进程中获取总和。
理论上,可以从所有进程中收集所有数据,然后在一个进程中计算指标,但这很慢并且容易出现问题,因此您希望最小化传输的数据。最简单的方法是分段计算指标,然后取平均值。当您使用时,self.log()
调用会自动执行此sync_dist=True
操作。
如果您不想取 GPU 进程的平均值,也可以在每一步更新一些状态变量,并在 epoch 同步状态变量并根据这些值计算您的指标。推荐的方法是创建一个使用 Metrics API 的类,该类最近从 PyTorch Lightning 转移到了TorchMetrics项目。
如果存储一组状态变量还不够,您可以尝试让您的指标收集来自所有进程的所有数据。从Metric基类派生您自己的指标,覆盖update()
和compute()
方法。用于add_state("data", default=[], dist_reduce_fx="cat")
创建一个列表,您可以在其中收集计算指标所需的数据。dist_reduce_fx="cat"
将导致来自不同进程的数据与torch.cat()
. 它在内部使用torch.distributed.all_gather。这里棘手的部分是它假设所有进程都创建相同大小的张量。如果大小不匹配,同步将无限期挂起。
请参阅pytorch-lightningmanual。我认为您正在寻找training_step_end
/ validation_step_end
(假设您使用的是 DP/DDP2)。
...因此,当 Lightning 调用任何 training_step、validation_step、test_step 时,您将只对其中一个部分进行操作。(...) 对于大多数指标,这并不重要。但是,如果您想使用所有批处理部分向您的计算图(如 softmax)添加一些东西,您可以使用 training_step_end 步骤。