0

我对 Pytorch 比较陌生,并且一直在 MNIST 数据集上训练一个 AutoEncoder 模型。在训练模型之前,我有三个数据加载器用于训练集、验证集和测试集。

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

# get minibatch
x_train, _ = next(iter(train_loader)) 
x_val, _ = next(iter(valid_loader))
x_test, _ = next(iter(test_loader))

这三个小批量具有以下大小:

torch.Size([128, 784])
torch.Size([128, 784])
torch.Size([128, 784])

但是,当我运行训练循环(在验证阶段)时,验证数据的形状不匹配,并且出现以下错误:

ValueError:不推荐使用与输入大小 (torch.Size([128, 784])) 不同的目标大小 (torch.Size([96, 784]))。

简单模型看起来像

class AE(nn.Module):
def __init__(self,latent_dim):
    super(AE, self).__init__()
    ### Encoder layers
    self.fc_enc1 = nn.Linear(784, 32)
    self.fc_enc2 = nn.Linear(32, 16)
    self.fc_enc3 = nn.Linear(16, latent_dim)
    
    ### Decoder layers
    self.fc_dec1 = nn.Linear(latent_dim, 16)
    self.fc_dec2 = nn.Linear(16,32)
    self.fc_dec3 = nn.Linear(32,784)

def encode(self, x):       
    z = F.relu(self.fc_enc1(x))
    z = F.relu(self.fc_enc2(z))
    z = F.relu(self.fc_enc3(z))
    
    return z

def decode(self, z):    
    xHat = F.relu(self.fc_dec1(z))
    xHat = F.relu(self.fc_dec2(xHat))
    xHat = F.sigmoid(self.fc_dec3(xHat))

    return xHat

def forward(self, x):
    ### Autoencoder returns the reconstruction and latent representation
    z = self.encode(x)
    
    ### decode z
    xHat = self.decode(z)
    return xHat, z 

训练循环如下所示:

AEmodel = AE(latent_dim).to(device)
optimizer = optim.Adam(AEmodel.parameters(), lr=lr)
loss_function = nn.BCELoss()

for epoch in range(1, epochs + 1):
AEmodel.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
    data = data.float().to(device)
    optimizer.zero_grad()
    xHat, z = AEmodel(data)
    loss = loss_function(xHat, data)
    loss.backward()
    train_loss += loss.item()
    optimizer.step()

AEmodel.eval()
valid_loss = 0
with torch.no_grad():
    for i, (data, _) in enumerate(valid_loader):
        data = data.float().to(device)
        valid_loss += loss_function(xHat, data).item()

错误发生在上述代码的最后一行。我一直无法弄清楚重塑出现在哪里,这会导致一些不匹配。我是不是瞎了眼,没有看到明显的错误??

4

1 回答 1

0

有意与否,在您的最后一个循环xHat中是恒定的,因为它没有被重新计算。它仍然是xHat您的火车循环中的最后一个。所以你比较xHatdata但它们不是来自同一个数据加载器(它们分别来自train_loadervalid_loader),因此没有任何东西强制它们具有相同的形状。如果您的验证循环,我相信您希望xHat在每次迭代中重新计算。

但是,这不是完整的解释,因为您粘贴的形状都不是 (96, 784)。我很确定它实际上是xHat你最后一个循环中的形状,请在那里添加一个打印语句来确认它。由于它是您的最后一批train_loader,因此它的大小不一定等于您的批量大小。当数据集的大小不能被批量大小整除时,就会发生这种情况。请查看datloader 文档,尤其是drop_last选项。

因此,要么您添加drop_last=True到您的数据加载器(这将使代码运行良好,但我不确定您是否希望它像那样工作),或者您xHat在验证循环的每次迭代中重新计算以获得有意义的自动编码器验证循环。

于 2021-04-07T23:34:48.613 回答