3

我有这个 LightningDataModule:

class MTmetricDataModule(pl.LightningDataModule):
 def __init__(self, df):
  super().__init__()
  self.reference = df['reference'].astype(str)
  self.translation = df['translation'].astype(str)
  self.z_score = df['avg-score']
  self.z_score = np.array(self.z_score)


def setup(self, stage=None):
 self.reference_train, self.reference_test, self.translation_train, self.translation_test, 
 self.z_score_train, self.z_score_test = train_test_split(self.reference, self.translation, 
 self.z_score, test_size=0.2)
 self.reference_test, self.reference_dev, self.translation_test, self.translation_dev, 
 self.z_score_test, self.z_score_dev = train_test_split(self.reference_test, 
 self.translation_test, self.z_score_test, test_size=0.1)


 self.df_train = pd.DataFrame()
 self.df_train['reference'] = self.reference_train
 self.df_train['translation'] = self.translation_train
 self.df_train['z_score'] = self.z_score_train
 self.train = self.df_train.to_dict("records")

 self.df_dev = pd.DataFrame()
 self.df_dev['reference'] = self.reference_dev
 self.df_dev['translation'] = self.translation_dev
 self.df_dev['z_score'] = self.z_score_dev
 self.dev = self.df_dev.to_dict("records")

 self.df_test = pd.DataFrame()
 self.df_test['reference'] = self.reference_test
 self.df_test['translation'] = self.translation_test
 self.df_test['z_score'] = self.z_score_test
 self.test = self.df_test.to_dict("records")


def train_dataloader(self) -> DataLoader:
 return DataLoader(
   dataset=self.train,
   batch_size=batch_size
 )

def val_dataloader(self) -> DataLoader:
 return DataLoader(
   dataset=self.dev,
   batch_size=1
 )

def test_dataloader(self) -> DataLoader:
 return DataLoader(
   dataset=self.test,
   batch_size=1
 )

然后我只是输入我的 LightningModule

data = MTmetricDataModule(df)
model = MTmetric()
trainer = Trainer(gpus=1, progress_bar_refresh_rate=20, max_epochs=2) 
trainer.fit(model, data)

但我目前收到错误“RuntimeError:输入、输出和索引必须在当前设备上”。

我目前正在使用 GPU 实例在 Colab 中运行它,但似乎没有任何东西可以让它工作。

谁知道怎么修它?

谢谢

4

2 回答 2

0

使用.to(device)可能有用。device = cuda()或者cpu()

Pytorch_Forums 上的这个解决方案会很有帮助。

于 2022-01-12T11:59:57.407 回答
0

这意味着您的一些数据在 gpu 上,而一些正在使用 cpu。请在同一设备上传输整个数据并再次运行。目前,您的数据加载器已加载到 cpu 上,而您使用 GPU 的“培训师”可能是错误的。

于 2022-01-12T14:06:05.127 回答