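I have an image dataset in which every image carries an additional attribute, "channel_no". Each image should be processed by a different nn layer depending on its channel_no: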
images with channel_no=1 have to be processed with layer1
images with channel_no=2 have to be processed with layer2
images with channel_no=3 have to be processed with layer3
etc...
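The problem is that when a batch contains several images, forward() receives a single torch tensor holding all images in the batch, and each image can have a different channel_no. So it is not clear how to process each image separately.
Here is the code for the case where the batch contains only one image: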
import torch
import torch.nn as nn
import torch.nn.functional as F

# hidden_sizes and output_size are assumed to be defined elsewhere in the script
class Net(nn.Module):
    def __init__(self, weight):  # weight is unused in this snippet
        super(Net, self).__init__()
        # one candidate layer per channel_no, all with the same shape
        self.layer1 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer3 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.outp = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x, channel_no):
        channel_no = channel_no[0]  # extract channel_no from the batch list
        x = x.view(-1, hidden_sizes[0])
        if channel_no == 1:
            x = F.relu(self.layer1(x))
        elif channel_no == 2:
            x = F.relu(self.layer2(x))
        elif channel_no == 3:
            x = F.relu(self.layer3(x))
        x = torch.sigmoid(self.outp(x))
        return x
Is it possible to process each image separately with a batch size > 1?
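For reference, one possible direction is to route the rows of the batch with boolean masks, so that each layer only sees the images whose channel_no matches it. This is a minimal sketch, not a definitive solution. It assumes channel_no arrives as a 1-D sequence or tensor of integers with one entry per image. The sizes in hidden_sizes and output_size are made-up values, and the name BatchedNet and the num_channels parameter are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
hidden_sizes = [784, 128]
output_size = 10


class BatchedNet(nn.Module):
    def __init__(self, num_channels=3):
        super().__init__()
        # One layer per channel_no; list index 0 holds the layer for channel_no == 1.
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_sizes[0], hidden_sizes[1]) for _ in range(num_channels)]
        )
        self.outp = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x, channel_no):
        x = x.view(-1, hidden_sizes[0])
        # Assumption: one integer channel_no per image in the batch.
        channel_no = torch.as_tensor(channel_no, device=x.device)
        # Buffer for the per-image hidden activations, filled group by group.
        hidden = x.new_zeros(x.size(0), hidden_sizes[1])
        for idx, layer in enumerate(self.layers, start=1):
            mask = channel_no == idx  # rows of the batch that belong to this layer
            if mask.any():
                hidden[mask] = F.relu(layer(x[mask]))
        return torch.sigmoid(self.outp(hidden))
```

With this sketch, a call such as BatchedNet()(torch.randn(4, 784), torch.tensor([1, 3, 2, 1])) should return a tensor of shape (4, 10). Rows whose channel_no matches no layer keep a zero hidden vector, which may or may not be the desired behavior.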