我创建了一个简单的数据集(dataset)。我需要通过 batch.src 和 batch.trg 来获取批处理数据,因为我基于 torch 架构的模型的迭代器是以 batch.src 和 batch.trg 的形式取样本的。我不想重写这个 torch 架构。
from torchtext.legacy.data import BucketIterator
import pandas as pd
import torch
from spacy.lang.ru import Russian
from torchtext.legacy import data
from torchtext.legacy.data import Dataset
import pandas
# Toy parallel corpus: two Russian (source, target) sentence pairs used to
# exercise the torchtext pipeline below.
transcript = 'Привет что делаешь'
transcript_1 = 'Пока'
features = 'Где был'
features_1 = 'Как дела'

# Pair sources with targets, then split the pairs back into the two
# columns the rest of the script expects ('src' -> 'trg').
pairs = [(features, transcript), (features_1, transcript_1)]
d = {'src': [s for s, _ in pairs], 'trg': [t for _, t in pairs]}
df = pd.DataFrame(data=d, columns=['src', 'trg'])
# Cached spaCy pipeline: the original rebuilt Russian() on EVERY call,
# which is expensive — the tokenizer is stateless, so one instance suffices.
_nlp_ru = None


def tokenize_ru(transcrip):
    """Tokenize text with spaCy's Russian tokenizer.

    Parameters
    ----------
    transcrip : Any
        Input text; coerced to ``str`` before tokenization.

    Returns
    -------
    list[str]
        The token texts, in order of appearance.
    """
    global _nlp_ru
    if _nlp_ru is None:  # lazy init keeps module import cheap
        _nlp_ru = Russian()
    doc = _nlp_ru(str(transcrip))
    return [token.text for token in doc]
# Special tokens shared by the source and target vocabularies.
BOS_WORD = '<s>'
EOS_WORD = '</s>'
BLANK_WORD = "<blank>"

SRC = data.Field(tokenize=tokenize_ru, init_token=BOS_WORD,
                 eos_token=EOS_WORD, pad_token=BLANK_WORD)
TGT = data.Field(tokenize=tokenize_ru, init_token=BOS_WORD,
                 eos_token=EOS_WORD, pad_token=BLANK_WORD)

# This dict form (column -> (attribute, Field)) is accepted both by
# Example.fromdict and by Dataset's `fields` argument.
data_fields = {'src': ('src', SRC), 'trg': ('trg', TGT)}

# BUG FIX: Dataset expects a list of Example objects, not a DataFrame —
# passing `df` directly is what raised the error. Convert each row of the
# DataFrame into an Example first.
examples = [data.Example.fromdict(row, data_fields)
            for row in df.to_dict(orient='records')]
train_ds_for_iter = Dataset(examples=examples, fields=data_fields)

SRC.build_vocab(train_ds_for_iter, min_freq=1)
TGT.build_vocab(train_ds_for_iter, min_freq=1)

# A plain Dataset (unlike TabularDataset) defines no sort_key, so supply
# one explicitly — BucketIterator needs it to bucket by length.
train_iter = BucketIterator(train_ds_for_iter, batch_size=1,
                            sort_key=lambda ex: len(ex.src))

for batch in train_iter:
    # batch.src / batch.trg are tensors of token indices — the attribute
    # names the existing torch model's iterator expects.
    print(batch.src)
运行这段代码时得到如下错误: