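I'm working on multiple-choice QA, using the official Hugging Face Transformers notebook written for the SWAG dataset.
I want to use it for other multiple-choice datasets, so I made a few dataset-specific modifications. All the code is in the notebook.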
The SWAG dataset has the following columns, including "label":
train: Dataset({
    features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],
    num_rows: 73546
})
The dataset I want to use has the following columns, with "answerKey" as the target:
train: Dataset({
    features: ['id', 'question_stem', 'choices', 'answerKey'],
    num_rows: 4957
})
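One likely mismatch: the data collator in the notebook only looks for a "label" or "labels" key, while this dataset stores the target in "answerKey". Below is a minimal sketch of converting it to an integer "label" column. It assumes that "choices" is a dict whose "label" list holds the option letters (e.g. ["A", "B", "C", "D"]) and that "answerKey" is one of those letters. The function name `add_integer_label` and the toy example are hypothetical, not taken from the notebook or the real data.

```python
from typing import Any, Dict


def add_integer_label(example: Dict[str, Any]) -> Dict[str, Any]:
    """Map the letter in "answerKey" to the index of the matching choice.

    Assumption: example["choices"]["label"] is a list of option letters
    such as ["A", "B", "C", "D"], and example["answerKey"] is one of them.
    """
    letters = example["choices"]["label"]
    return {"label": letters.index(example["answerKey"])}


# Toy example in the assumed format (made up, not from the real dataset).
sample = {
    "id": "toy-1",
    "question_stem": "Which of these is a liquid at room temperature?",
    "choices": {"text": ["iron", "water", "salt", "wood"], "label": ["A", "B", "C", "D"]},
    "answerKey": "B",
}
print(add_integer_label(sample))  # {'label': 1}
```

With the `datasets` library this would be applied as `dataset.map(add_integer_label)`, which adds the returned "label" column next to the existing ones. The name "label" likely matters too: with the default `remove_unused_columns=True`, `Trainer` drops dataset columns that the model's `forward` does not accept but keeps "label", so "answerKey" would be discarded before it ever reaches the collator.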
The error is raised in the data collator:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch

@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        print(features[0].keys())
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        # Flatten: one dict per (example, choice) pair so all choices are padded together
        flattened_features = [[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features]
        flattened_features = sum(flattened_features, [])
        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Add back labels
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch
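The flatten/un-flatten step in this collator is easiest to see on toy data. The sketch below is pure Python with no tokenizer, and the token ids are made up: each feature holds one list per choice, and the same comprehension turns `batch_size` features into `batch_size * num_choices` single-choice dicts, which `tokenizer.pad` can then pad to a common length.

```python
# Two toy features (batch_size=2), each with num_choices=3 tokenized choices.
# The ids are invented for illustration; the label has already been popped.
features = [
    {"input_ids": [[1, 2], [3], [4, 5, 6]], "attention_mask": [[1, 1], [1], [1, 1, 1]]},
    {"input_ids": [[7], [8, 9], [10]], "attention_mask": [[1], [1, 1], [1]]},
]

batch_size = len(features)
num_choices = len(features[0]["input_ids"])

# Same comprehension as in the collator: one dict per (feature, choice) pair.
flattened = [
    [{k: v[i] for k, v in feature.items()} for i in range(num_choices)]
    for feature in features
]
# Concatenate the per-feature lists into a single flat list.
flattened = sum(flattened, [])

print(len(flattened))  # 6
print(flattened[3])    # {'input_ids': [7], 'attention_mask': [1]}
```

After padding, `.view(batch_size, num_choices, -1)` restores a `(2, 3, seq_len)` shape. Because the label is popped before flattening, it has to be present in every feature under "label" or "labels" when the collator is called.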
The error points to these lines:
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature.pop(label_name) for feature in features]
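These two lines fall back to "labels" whenever the first feature has no "label" key. As a result, a feature that carries neither key (for example, one that only has "answerKey", or no target at all after unused columns are removed) fails with exactly `KeyError: 'labels'`. Below is a small pure-Python reproduction, using a hypothetical helper `pop_labels` that mirrors this lookup but reports which keys actually reached the collator.

```python
from typing import Any, Dict, List


def pop_labels(features: List[Dict[str, Any]]) -> List[Any]:
    """Hypothetical variant of the collator's label lookup with a clearer error."""
    keys = features[0].keys()
    label_name = "label" if "label" in keys else "labels"
    if label_name not in keys:
        # Same situation that triggers KeyError: 'labels' in the collator.
        raise KeyError(
            f"No 'label' or 'labels' key in features; available keys: {sorted(keys)}"
        )
    return [feature.pop(label_name) for feature in features]


# A feature with the target stored under "answerKey" instead of "label".
bad = [{"input_ids": [[1], [2]], "answerKey": "B"}]
try:
    pop_labels(bad)
except KeyError as err:
    print(err)

good = [{"input_ids": [[1], [2]], "label": 1}]
print(pop_labels(good))  # [1]
print(good[0])           # {'input_ids': [[1], [2]]}
```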
The error is raised when calling trainer.train():
KeyError Traceback (most recent call last)
<ipython-input-64-3435b262f1ae> in <module>()
----> 1 trainer.train()
5 frames
<ipython-input-60-d1262e974b03> in <listcomp>(.0)
18 print(features[0].keys())
19 label_name = "label" if "label" in features[0].keys() else "labels"
---> 20 labels = [feature.pop(label_name) for feature in features]
21 batch_size = len(features)
22 num_choices = len(features[0]["input_ids"])
KeyError: 'labels'
I don't understand what is causing the error. I suspect it is related to the target key, but I haven't been able to fix it. Any ideas?
Thanks,