我有如下 LightningDataModule:
class MTmetricDataModule(pl.LightningDataModule):
    """Split a DataFrame of (reference, translation, score) rows into train/dev/test.

    Each example is a dict ``{"reference": str, "translation": str,
    "z_score": float32}``. With the default collate function a batch is a dict
    whose ``reference``/``translation`` entries are lists of Python strings and
    whose ``z_score`` entry is a float32 tensor.

    Lightning only moves *tensors* in a batch to the GPU. If the model
    tokenizes the strings itself, the resulting tensors stay on the CPU.
    Feeding those to a GPU embedding layer raises
    ``RuntimeError: Input, output and indices must be on the current device``.
    There are two ways to fix this:

    * move the tokenized tensors to ``self.device`` inside the LightningModule, or
    * pass a ``collate_fn`` here that tokenizes each batch into tensors, so
      Lightning transfers them automatically.

    Args:
        df: DataFrame with ``reference``, ``translation`` and ``avg-score`` columns.
        batch_size: Training batch size. Replaces the module-level global
            ``batch_size`` the class previously relied on; dev/test always use 1.
        collate_fn: Optional collate function forwarded to every DataLoader
            (``None`` means PyTorch's default collation).
        seed: ``random_state`` for both splits, so repeated ``setup`` calls
            produce the same partition. ``None`` gives a random split.
        test_size: Fraction of all rows held out from training.
        dev_size: Fraction of the held-out rows moved into the dev set.
    """

    def __init__(self, df, batch_size=32, collate_fn=None, seed=42,
                 test_size=0.2, dev_size=0.1):
        super().__init__()
        self.reference = df['reference'].astype(str)
        self.translation = df['translation'].astype(str)
        # float32 so default collation yields a FloatTensor target rather than
        # a DoubleTensor, which would mismatch float32 predictions in the loss.
        self.z_score = np.asarray(df['avg-score'], dtype=np.float32)
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.seed = seed
        self.test_size = test_size
        self.dev_size = dev_size
        self._is_split = False

    @staticmethod
    def _to_records(reference, translation, z_score):
        """Zip three row-aligned columns into a list of per-example dicts."""
        return [
            {'reference': ref, 'translation': trans, 'z_score': score}
            for ref, trans, score in zip(reference, translation, z_score)
        ]

    def setup(self, stage=None):
        """Split the data once into train / dev / test record lists.

        ``test_size`` of the rows is held out. ``dev_size`` of those held-out
        rows becomes dev, and the remainder is test. With the defaults the
        split is 80% / 2% / 18%.
        """
        if getattr(self, '_is_split', False):
            # Lightning may call setup() more than once (e.g. for fit and then
            # test); re-splitting could move training rows into the test set.
            return

        # Parentheses are required: a multi-line unpacking target without them
        # is parsed as two separate statements.
        (self.reference_train, self.reference_test,
         self.translation_train, self.translation_test,
         self.z_score_train, self.z_score_test) = train_test_split(
            self.reference, self.translation, self.z_score,
            test_size=self.test_size, random_state=self.seed)
        (self.reference_test, self.reference_dev,
         self.translation_test, self.translation_dev,
         self.z_score_test, self.z_score_dev) = train_test_split(
            self.reference_test, self.translation_test, self.z_score_test,
            test_size=self.dev_size, random_state=self.seed)

        columns = ['reference', 'translation', 'z_score']
        self.train = self._to_records(
            self.reference_train, self.translation_train, self.z_score_train)
        self.dev = self._to_records(
            self.reference_dev, self.translation_dev, self.z_score_dev)
        self.test = self._to_records(
            self.reference_test, self.translation_test, self.z_score_test)
        # DataFrame views kept for callers that inspected df_train/df_dev/df_test.
        self.df_train = pd.DataFrame(self.train, columns=columns)
        self.df_dev = pd.DataFrame(self.dev, columns=columns)
        self.df_test = pd.DataFrame(self.test, columns=columns)
        self._is_split = True

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader using the configured batch size."""
        return DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self) -> DataLoader:
        """Dev-set loader, one example per batch."""
        return DataLoader(
            dataset=self.dev,
            batch_size=1,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self) -> DataLoader:
        """Test-set loader, one example per batch."""
        return DataLoader(
            dataset=self.test,
            batch_size=1,
            collate_fn=self.collate_fn,
        )
然后我把它和我的 LightningModule 一起传给 Trainer:
# Build the data module and model, then train for two epochs on one GPU.
data = MTmetricDataModule(df)
model = MTmetric()
trainer = Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)
# NOTE(review): Lightning moves only tensors in each batch to the GPU; any
# tensors MTmetric creates itself (e.g. tokenizer output) presumably need
# `.to(self.device)` — confirm against the model's forward().
trainer.fit(model, datamodule=data)
但我收到错误 `RuntimeError: Input, output and indices must be on the current device`(输入、输出和索引必须位于当前设备上)。
我在 Colab 的 GPU 实例上运行这段代码,但尝试了各种方法都无法解决。
有人知道该如何修复吗?
谢谢