0

我的源代码结构如下:

定义字段

# Sequential text field whose tokenizer is selected with the 'spacy' option.
TEXT = data.Field(sequential=True, tokenize='spacy')

# Label field; it is paired with the 'component' attribute further down.
LABEL = data.LabelField()

更新:构建词汇表

# Build the text vocabulary from `train`, limited to max_size=25000, and
# attach the `vectors` object that is defined elsewhere.
TEXT.build_vocab(train, max_size=25000, vectors=vectors)

# Build the label vocabulary from the same `train` dataset.
LABEL.build_vocab(train)

创建示例和数据集实例

# (attribute name, Field) pairs handed to Dataset.
# NOTE(review): the traceback in this post shows a tuple stored as a value of
# dataset.fields. Presumably that happens when the fromdict-style mapping
# {key: (attribute name, Field)} is passed to a Dataset instead of pairs like
# these -- confirm against how `train` was constructed.
fields = [
    ('name', TEXT),
    ('component', LABEL),
]

# Example.fromdict is given {input key: (attribute name, Field)}. Derive that
# mapping from `fields` so the two specifications cannot drift apart.
fromdict_fields = {attr: (attr, field) for attr, field in fields}

example = Example.fromdict(
    {"name": "exmaple", "component": "Undefined"},
    fromdict_fields,
)

# Pass the collected list to Dataset rather than rebuilding `[example]`
# inline, so any examples added here later are not silently ignored.
examples = [example]

instance = Dataset(examples=examples, fields=fields)

创建桶迭代器实例

# Iterate over `instance` in batches of up to 64 examples.
instance_iterator = data.BucketIterator(
    instance,
    batch_size=64,
    # Sort by token count: ordering by the token list itself compares the
    # tokens lexicographically and does not group examples of similar length.
    sort_key=lambda x: len(x.name),
    sort=True,
    train=True,
    device=device,
)

更新:我也尝试过使用 data.BucketIterator.splits

BATCH_SIZE = 64

# Create one iterator per dataset, in the same order as the input tuple.
train_iterator, instance_iterator = data.BucketIterator.splits(
    (train, instance),
    # Sort by token count: ordering by the token list itself compares the
    # tokens lexicographically and does not group examples of similar length.
    sort_key=lambda x: len(x.name),
    sort=True,
    batch_size=BATCH_SIZE,
    device=device,
)

遍历桶迭代器

# Print the `name` attribute of every batch the iterator yields.
for minibatch in instance_iterator:
    print(minibatch.name)

当使用自定义数据集遍历桶迭代器时,会出现以下错误:

/usr/local/lib/python3.6/dist-packages/torchtext/data/iterator.py in __iter__(self)
    160                     else:
    161                         minibatch.sort(key=self.sort_key, reverse=True)
--> 162                 yield Batch(minibatch, self.dataset, self.device)
    163             if not self.repeat:
    164                 return

/usr/local/lib/python3.6/dist-packages/torchtext/data/batch.py in __init__(self, data, dataset, device)
     26             self.dataset = dataset
     27             self.fields = dataset.fields.keys()  # copy field names
---> 28             self.input_fields = [k for k, v in dataset.fields.items() if
     29                                  v is not None and not v.is_target]
     30             self.target_fields = [k for k, v in dataset.fields.items() if

/usr/local/lib/python3.6/dist-packages/torchtext/data/batch.py in <listcomp>(.0)
     27             self.fields = dataset.fields.keys()  # copy field names
     28             self.input_fields = [k for k, v in dataset.fields.items() if
---> 29                                  v is not None and not v.is_target]
     30             self.target_fields = [k for k, v in dataset.fields.items() if
     31                                   v is not None and v.is_target]

AttributeError: 'tuple' object has no attribute 'is_target'
4

0 回答 0