
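I'm writing a simple trainer with PyTorch Lightning, but when I try to run it, 9 out of 10 times it fails with "CUDA error: device-side assert triggered" for some reason. Simply printing a newline right before the modules are applied seems to make it work. Any ideas?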

My code:

import torch
import torch.nn as nn


class Elementwise(nn.ModuleList):
    """
    A simple network container.
    Parameters are a list of modules.
    Inputs are a 3d Tensor whose last dimension is the same length
    as the list.
    Outputs are the result of applying modules to inputs elementwise.
    An optional merge parameter allows the outputs to be reduced to a
    single Tensor.
    """

    def __init__(self, merge=None, *args):
        assert merge in [None, 'first', 'concat', 'sum', 'mlp']
        self.merge = merge
        super(Elementwise, self).__init__(*args)

    def forward(self, inputs):
        # Split along dim 1: one slice per wrapped module
        inputs_ = [feat.squeeze(1) for feat in inputs.split(1, dim=1)]
        for i, j in enumerate(inputs_):
            # `device` is a global defined elsewhere in the script
            inp = j.clone().detach().to(device).long()
            inputs_[i] = inp

        # this does not work
        outputs = [f(x) for f, x in zip(self, inputs_)]
                    
        if self.merge == 'first':
            return outputs[0]
        elif self.merge == 'concat' or self.merge == 'mlp':
            return torch.cat(outputs, 1)
        elif self.merge == 'sum':
            return sum(outputs)
        else:
            return outputs
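To make the first step of `forward` concrete: the input is cut along dim 1 into width-1 slices, and each slice is squeezed so that every wrapped module receives a 1-D tensor. The sketch below assumes a hypothetical 2-D input of shape (batch=2, n_features=3); it only reproduces the slicing, not the modules themselves.

```python
import torch

# Hypothetical input: batch of 2 rows, 3 feature columns (one per wrapped module)
inputs = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])

# Same slicing as the first line of forward(): split dim 1 into width-1
# pieces of shape (2, 1), then drop that singleton dimension
inputs_ = [feat.squeeze(1) for feat in inputs.split(1, dim=1)]

for k, col in enumerate(inputs_):
    print(k, col.tolist(), tuple(col.shape))
# prints:
# 0 [1, 4] (2,)
# 1 [2, 5] (2,)
# 2 [3, 6] (2,)
```

Module `k` in the list is then applied to `inputs_[k]`, which is why the number of feature columns has to match the number of wrapped modules.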

But somehow, this version works like magic:

class Elementwise(nn.ModuleList):
    """
    A simple network container.
    Parameters are a list of modules.
    Inputs are a 3d Tensor whose last dimension is the same length
    as the list.
    Outputs are the result of applying modules to inputs elementwise.
    An optional merge parameter allows the outputs to be reduced to a
    single Tensor.
    """

    def __init__(self, merge=None, *args):
        assert merge in [None, 'first', 'concat', 'sum', 'mlp']
        self.merge = merge
        super(Elementwise, self).__init__(*args)

    def forward(self, inputs):
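        # Split along dim 1: one slice per wrapped module
        inputs_ = [feat.squeeze(1) for feat in inputs.split(1, dim=1)]
        for i, j in enumerate(inputs_):
            # `device` is a global defined elsewhere in the script
            inp = j.clone().detach().to(device).long()
            inputs_[i] = inp

        print("")  # adding this print makes the error go away
        outputs = [f(x) for f, x in zip(self, inputs_)]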
                    
        if self.merge == 'first':
            return outputs[0]
        elif self.merge == 'concat' or self.merge == 'mlp':
            return torch.cat(outputs, 1)
        elif self.merge == 'sum':
            return sum(outputs)
        else:
            return outputs

Any ideas why simply printing to the output makes this error go away?

Edit: the error only occurs when training is abstracted with PyTorch Lightning; with plain PyTorch everything works fine.
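For what it's worth, a device-side assert raised from indexing code is often an out-of-range index passed to `nn.Embedding`. Because CUDA kernels are launched asynchronously, the error may be reported at a later, seemingly unrelated line; running with the environment variable `CUDA_LAUNCH_BLOCKING=1` makes launches synchronous so the traceback points at the op that actually failed. The sketch below assumes (the post doesn't say) that the wrapped modules are `nn.Embedding` layers; the helper `find_out_of_range`, the table sizes, and the batch are all hypothetical.

```python
import torch
import torch.nn as nn


def find_out_of_range(modules, inputs):
    """Hypothetical helper: list the feature columns whose indices fall
    outside [0, num_embeddings) for the matching nn.Embedding.
    Returns tuples of (column, min index, max index, num_embeddings)."""
    problems = []
    for k, (emb, col) in enumerate(zip(modules, inputs.split(1, dim=1))):
        col = col.squeeze(1).long()
        lo, hi = int(col.min()), int(col.max())
        if lo < 0 or hi >= emb.num_embeddings:
            problems.append((k, lo, hi, emb.num_embeddings))
    return problems


# Hypothetical setup: three embedding tables with different vocabulary sizes
embs = nn.ModuleList([nn.Embedding(10, 4), nn.Embedding(5, 4), nn.Embedding(3, 4)])
batch = torch.tensor([[1, 2, 0],
                      [9, 7, 2]])  # 7 is out of range for the second table (size 5)

print(find_out_of_range(embs, batch))  # [(1, 2, 7, 5)]
```

Running the same batch through the model on CPU is another option: there, an out-of-range embedding index raises an ordinary `IndexError` at the offending call instead of an asynchronous device-side assert.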
