我如何知道 PyTorch 中图层的输入节点或图层名称?假设我有一个torch.cat,我怎么知道它从哪里获取输入的张量或层的名称?
对于来自https://rosenfelder.ai/multi-input-neural-network-pytorch/的代码
class LitClassifier(pl.LightningModule):
def __init__(
self, lr: float = 1e-3, num_workers: int = 4, batch_size: int = 32,
):
super().__init__()
self.lr = lr
self.num_workers = num_workers
self.batch_size = batch_size
self.conv1 = conv_block(3, 16)
self.conv2 = conv_block(16, 32)
self.conv3 = conv_block(32, 64)
self.ln1 = nn.Linear(64 * 26 * 26, 16)
self.relu = nn.ReLU()
self.batchnorm = nn.BatchNorm1d(16)
self.dropout = nn.Dropout2d(0.5)
self.ln2 = nn.Linear(16, 5)
self.ln4 = nn.Linear(5, 10)
self.ln5 = nn.Linear(10, 10)
self.ln6 = nn.Linear(10, 5)
self.ln7 = nn.Linear(10, 1)
def forward(self, img, tab):
img = self.conv1(img)
img = self.conv2(img)
img = self.conv3(img)
img = img.reshape(img.shape[0], -1)
img = self.ln1(img)
img = self.relu(img)
img = self.batchnorm(img)
img = self.dropout(img)
img = self.ln2(img)
img = self.relu(img)
tab = self.ln4(tab)
tab = self.relu(tab)
tab = self.ln5(tab)
tab = self.relu(tab)
tab = self.ln6(tab)
tab = self.relu(tab)
x = torch.cat((img, tab), dim=1)
x = self.relu(x)
return self.ln7(x)
因此,如果我想知道 torch.cat 从哪一层接收输入。
对于我们拥有的 keras model.get_layer(id=idx).input.name
,PyTorch 是否也有类似的东西?