
I have a dataset that looks like this:

index  tag   feature1  feature2  target
1      tag1  1.4342    88.4554   0.5365
2      tag1  2.5656    54.5466   0.1263
3      tag2  5.4561    845.556   0.8613
4      tag3  6.5546    8.52545   0.7864
5      tag3  8.4566    945.456   0.4646

The number of entries per tag is not always the same.

My goal is to load only the data with a specific tag (or tags), so that I get a minibatch of only tag1 entries and then a minibatch of only tag2 entries if I set batch_size=1, or for example a minibatch of tag1 and tag2 entries together if I set batch_size=2.

So far my code completely ignores the tag column and just picks batches at random.

I build my dataset like this:

# features is a matrix with all the features columns through all rows
# target is a vector with the target column through all rows
featuresTrain, targetTrain = projutils.get_data(train=True, config=config)
train = torch.utils.data.TensorDataset(featuresTrain, targetTrain)
train_loader = make_loader(train, batch_size=config.batch_size)

My loader (generically) looks like this:

def make_loader(dataset, batch_size):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         pin_memory=True,
                                         num_workers=8)
    return loader

Then I train like this:

for epoch in range(config.epochs):
    for _, (features, target) in enumerate(train_loader):
        loss = train_batch(features, target, model, optimizer, criterion)

And train_batch:

def train_batch(features, target, model, optimizer, criterion):
    features, target = features.to(device), target.to(device)

    # Forward pass ➡
    outputs = model(features)
    loss = criterion(outputs, target)
    return loss

1 Answer


Here's a simple dataset that, as best I can tell, roughly implements the behavior you're looking for.

import torch
from torch.utils import data

class CustomDataset(data.Dataset):
    def __init__(self, featuresTrain, targetsTrain, tagsTrain, sample_equally=False):
        # self.tags should be a tensor in k-hot encoding form, so a 2D tensor
        self.tags = tagsTrain
        self.x = featuresTrain
        self.y = targetsTrain
        self.unique_tagsets = None
        self.sample_equally = sample_equally

        # self.active_tags is a 1D k-hot encoding vector
        self.active_tags = self.get_random_tag_set()

    def get_random_tag_set(self):
        # gets all unique sets of tags and returns one at random
        if self.unique_tagsets is None:
            self.unique_tagsets = self.tags.unique(dim=0)
        if self.sample_equally:
            # every unique tag set is equally likely
            rand_idx = torch.randint(len(self.unique_tagsets), (1,)).item()
            return self.unique_tagsets[rand_idx]
        else:
            # tag sets are weighted by how often they occur in the data
            rand_idx = torch.randint(len(self.tags), (1,)).item()
            return self.tags[rand_idx]

    def set_tags(self, tags):
        # specifies the set of tags that must be present for a datum to be selected
        self.active_tags = tags

    def __getitem__(self, index):
        # get all indices of elements whose tag set matches self.active_tags
        indices = torch.where((self.tags == self.active_tags).all(dim=1))[0]

        # select from among the elements that have the active tag set
        idx = indices[index % len(indices)]

        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.y)

This dataset picks a set of tags at random. Then, every time __getitem__() is called, it uses the given index to select from among the data elements that have that tag set. You can call set_tags() or get_random_tag_set() followed by set_tags() after each minibatch, or however often you want to change the tag set, or you can specify the tag set manually yourself. The dataset inherits from torch.utils.data.Dataset, so you should be able to use it with a torch.utils.data.DataLoader without modification.
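As a rough usage sketch of that per-minibatch loop (the toy tensors and sizes are made up, and the class is repeated in condensed form here only so the snippet runs on its own):

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Condensed stand-in for the CustomDataset above, just enough to run the loop.
class TagDataset(Dataset):
    def __init__(self, x, y, tags):
        self.x, self.y, self.tags = x, y, tags
        self.unique_tagsets = tags.unique(dim=0)
        self.active_tags = self.get_random_tag_set()

    def get_random_tag_set(self):
        # pick one of the unique tag sets at random
        idx = torch.randint(len(self.unique_tagsets), (1,)).item()
        return self.unique_tagsets[idx]

    def set_tags(self, tags):
        self.active_tags = tags

    def __getitem__(self, index):
        # restrict selection to rows whose tag set matches active_tags
        indices = torch.where((self.tags == self.active_tags).all(dim=1))[0]
        idx = indices[index % len(indices)]
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.y)

# toy data: 6 examples, 2 features, k-hot tags over 3 possible tags
x = torch.randn(6, 2)
y = torch.randn(6, 1)
tags = torch.tensor([[1, 0, 0], [1, 1, 0], [0, 0, 1],
                     [0, 1, 0], [0, 0, 0], [1, 1, 1]], dtype=torch.float)

dataset = TagDataset(x, y, tags)
loader = DataLoader(dataset, batch_size=2, shuffle=True)  # num_workers=0 on purpose

for features, target in loader:
    # ... train_batch(features, target, model, optimizer, criterion) ...
    # switch to a new random tag set before the next minibatch
    dataset.set_tags(dataset.get_random_tag_set())
```

Note that this relies on the default num_workers=0: with worker processes the dataset object is copied into each worker, so calling set_tags() in the main process would not reach those copies.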

You can control whether each tag set is sampled equally likely, or in proportion to how often it occurs in the data, with sample_equally.

In short, this dataset is a bit rough around the edges, but it should let you sample batches in which every element has the same tag set. Its main shortcoming is that each element will likely be sampled more than once per batch.

As for the initial encoding, assume you start with a list tags where each data example has a list of tags, i.e. a list of lists, each sublist containing that example's tags. The following code converts this to k-hot encoding:

import torch

def to_k_hot(tags):
    all_tags = []
    for ex in tags:
        for tag in ex:
            all_tags.append(tag)
    unique_tags = sorted(set(all_tags))  # remove duplicates and fix an order

    tagsTrain = torch.zeros(len(tags), len(unique_tags))
    for i in range(len(tags)):            # index through all examples
        for j in range(len(unique_tags)): # index through all unique_tags
            if unique_tags[j] in tags[i]:
                tagsTrain[i, j] = 1

    return tagsTrain

For example, say your dataset has the following tags:

tags = [ ['tag1'],
         ['tag1','tag2'],
         ['tag3'],
         ['tag2'],
         [],
         ['tag1','tag2','tag3'] ]

Calling to_k_hot(tags) then returns:

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 1., 1.]])
Answered 2021-02-16T18:17:26.933