0

大家好,我正在使用 pytorch 处理 CIFAR10 数据集。我开发了一个模型,它工作得非常好,但主要问题出现在运行以下代码时:

import time
start_time=time.time()

epochs=5
train_losses=[]
test_losses=[]
train_correct=[]
test_correct=[]

for i in range(epochs):
    tsn_corr=0
    tst_corr=0
    
    for b, (X_train,y_train) in enumerate(train_loader):
        b+=1
        
        y_pred=model(X_train)
        loss=criterion(y_pred,y_train)
        
        
        #Tally the number of correct predictions
        
        predicted= torch.max(y_pred.data, 1)[1]
        batch_corr=(predicted==y_train).sum()
        tsn_corr += batch_corr
        
        #optimize paramters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    
        
        #print interim results
        if b%600 == 0:
            print(f"epochs: {i}, batch: {b}, loss: {loss.item():10.8f}")
        
    loss=loss.detach().numpy()
    train_losses.append(loss)
    train_correct.append(tsn_corr)
    
    #Running the test_batches
    
    with torch.no_grad():
        for b, (X_test,y_test) in enumerate(test_loader):
            b+=1
            
            y_val=model(X_test)
            
            
            
            #TALLY THE NUMBER OF CORRECT PREDICTIONS
            
            predicted=torch.max(y_val.data, 1)[1]
            batch_corr= (predicted==y_test).sum()
            tst_corr += batch_corr
            
    loss=criterion(y_val,y_test)    
    loss=loss.detach().numpy()
    test_losses.append(loss)
    test_correct.append(tst_corr)

运行以下代码时出现以下错误:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-43-48e21e83e9f7> in <module>
     15         b+=1
     16 
---> 17         y_pred=model(X_train)
     18         loss=criterion(y_pred,y_train)
     19 

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _forward_unimplemented(self, *input)
    199         registered hooks while the latter silently ignores them.
    200     """
--> 201     raise NotImplementedError
    202 
    203 

NotImplementedError: 

有人可以告诉我我该怎么做才能修复此代码。除此之外,之前的所有代码都可以正常工作,并且我使用卷积神经网络制作的模型也可以成功运行,这意味着模型没有问题。我想这个细节可能会有所帮助。可能会注意到此代码在 MNIST 数据集上运行良好。我不知道 CIFAR 数据集有什么问题

4

1 回答 1

0

你的模型类需要实现一个 forward 方法。请参阅有关子类化的 PyTorch 示例以查看示例。

于 2021-08-20T21:36:33.603 回答