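I have 3 files. In the datamodule file, I create the data using PyTorch Lightning. In linear_model, I built a linear regression model based on this page. Finally, I have a train file where I call the model and try to fit the data. But I get this error: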
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/mostafiz/Dropbox/MSc/Thesis/regreesion_EC/src/test_train.py", line 10, in <module>
    train_dataloader=datamodule.DataModuleClass().setup().train_dataloader(),
AttributeError: 'tuple' object has no attribute 'train_dataloader'
Sample datamodule file:
class DataModuleClass(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.sigma = 5
        self.batch_size = 10
        self.prepare_data()

    def prepare_data(self):
        x = np.random.uniform(0, 10, 10)
        e = np.random.normal(0, self.sigma, len(x))
        y = x + e
        X = np.transpose(np.array([x, e]))
        self.x_train_tensor = torch.from_numpy(X).float().to(device)
        self.y_train_tensor = torch.from_numpy(y).float().to(device)
        training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)
        self.training_dataset = training_dataset

    def setup(self):
        data = self.training_dataset
        self.train_data, self.val_data = random_split(data, [8, 2])
        return self.train_data, self.val_data

    def train_dataloader(self):
        return DataLoader(self.train_data)

    def val_dataloader(self):
        return DataLoader(self.val_data)
Sample training file:
from . import datamodule, linear_model

model = linear_model.LinearRegression(input_dim=2, l1_strength=1, l2_strength=1)
trainer = pl.Trainer()
trainer.fit(model,
            train_dataloader=datamodule.DataModuleClass().setup().train_dataloader(),
            val_dataloaders=datamodule.DataModuleClass().setup().val_dataloaders())
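A minimal sketch of what the first traceback seems to describe, using a plain-Python stand-in rather than the real LightningDataModule. `FakeDataModule` and its list-based "loaders" are hypothetical names made up for illustration. The only point is that `setup()` returns a tuple, so anything chained onto its result is looked up on the tuple and not on the datamodule.

```python
# Plain-Python stand-in for the datamodule (hypothetical, no PyTorch needed).
class FakeDataModule:
    def setup(self):
        self.train_data, self.val_data = [1, 2, 3], [4]
        # Returning a tuple means a chained call lands on the tuple.
        return self.train_data, self.val_data

    def train_dataloader(self):
        # A list stands in for a DataLoader here.
        return list(self.train_data)


# Chaining onto the return value of setup() reproduces the error.
result = FakeDataModule().setup()
print(type(result).__name__)  # tuple

try:
    result.train_dataloader()
except AttributeError as err:
    message = str(err)
    print(message)

# Keeping a reference to the object and calling its methods separately works.
dm = FakeDataModule()
dm.setup()
loader = dm.train_dataloader()
print(loader)  # [1, 2, 3]
```

If this reading is right, keeping one datamodule instance in a variable and calling its methods on that instance avoids the chained lookup on the tuple.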
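Please let me know if you need more code or explanation.
Update (based on the comments)
Now, I have removed self.prepare_data() from the __init__() of DataModuleClass(), removed return self.train_data, self.val_data from setup(), and changed the test file to: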
data_module = datamodule.DataModuleClass()
trainer = pl.Trainer()
trainer.fit(model,data_module)
Error:
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/mostafiz/Dropbox/MSc/Thesis/regreesion_EC/src/test_train.py", line 10, in <module>
    train_dataloader=datamodule.DataModuleClass().train_dataloader(),
  File "/home/mostafiz/Dropbox/MSc/Thesis/regreesion_EC/src/datamodule.py", line 54, in train_dataloader
    return DataLoader(self.train_data)
AttributeError: 'DataModuleClass' object has no attribute 'train_data'
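The traceback still shows line 10 of test_train.py calling `DataModuleClass().train_dataloader()` directly on a fresh instance, so this may be a case of `setup()` never having run on that object. Below is a rough plain-Python sketch of that situation. `FakeDataModule` is a hypothetical stand-in, lists replace the real datasets and loaders, and the `stage=None` parameter is included on the assumption that Lightning passes a `stage` argument when it calls the `setup` hook itself.

```python
# Hypothetical stand-in for the datamodule; lists replace tensors and loaders.
class FakeDataModule:
    def prepare_data(self):
        self.training_dataset = list(range(10))

    def setup(self, stage=None):
        # train_data and val_data only exist after this method has run.
        data = self.training_dataset
        self.train_data, self.val_data = data[:8], data[8:]

    def train_dataloader(self):
        return list(self.train_data)


# Calling train_dataloader() on a fresh instance fails: setup() never ran.
fresh = FakeDataModule()
try:
    fresh.train_dataloader()
except AttributeError as err:
    message = str(err)
    print(message)

# Running the hooks in order on the same instance makes the attribute available.
dm = FakeDataModule()
dm.prepare_data()
dm.setup()
batches = dm.train_dataloader()
print(len(batches))  # 8
```

As far as I understand, `trainer.fit(model, data_module)` runs `prepare_data()` and `setup()` on the datamodule it is given before asking for the dataloaders, so the file that actually gets executed would need to contain that call and not the older keyword-argument form.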