0

我需要有关 transform 函数的帮助。我想做的是:打乱句子中标记(单词)的顺序,并只把打乱后的结果输入给编码器;而解码器的输入则使用同一句子中未打乱的原始标记(单词)。我似乎搞不清楚 for 循环该怎么写,请帮忙。

'''

这是打乱句子中单词顺序的函数。

def scramble(text):
    """Return *text* with its words put into a random order.

    Words are the whitespace-separated pieces produced by ``str.split()``.
    They are shuffled in place with the ``random`` module and re-joined with
    single spaces, so runs of whitespace collapse to one space and leading or
    trailing whitespace is dropped.
    """
    order = text.split()
    random.shuffle(order)
    joined = ' '.join(order)
    return joined

此函数将文本/句子切分为标记(单词)。

def tokenize(text):
    """Split *text* into tokens and strip unwanted characters from each.

    The text is split on hyphens, newlines and spaces. Every resulting piece
    then has all matches of the module-level ``REMOVE_CHARS`` regex pattern
    removed. Empty strings are kept in the result, e.g. those produced by
    consecutive separators or by pieces made only of removed characters.
    """
    pieces = re.split("[-\n ]", text)
    cleaned = []
    for piece in pieces:
        cleaned.append(re.sub(REMOVE_CHARS, '', piece))
    return cleaned

此函数以一定概率在单个单词(标记)中引入一处随机拼写错误(替换、删除、插入或交换字符)。

def add_speling_erors(tokn, error_rate, chars=None):
    """Simulate an artificial spelling mistake in a single token.

    With probability ``error_rate`` exactly one of four edits is applied,
    each with probability ``error_rate / 4``:

    * replace one character with a random character,
    * delete one character,
    * insert a random character,
    * swap two adjacent characters.

    Tokens shorter than 3 characters are returned unchanged.

    Args:
        tokn: The token (string) to corrupt.
        error_rate: Probability in ``[0.0, 1.0)`` that the token is altered.
        chars: Sequence of characters to draw random characters from. When
            omitted, the module-level ``CHARS`` is used (looked up only when a
            branch actually needs a random character).

    Returns:
        The possibly corrupted token.

    Raises:
        AssertionError: If ``error_rate`` is outside ``[0.0, 1.0)``.
    """
    assert(0.0 <= error_rate < 1.0)
    if len(tokn) < 3:
        return tokn
    rand = np.random.rand()
    # Split [0, 4 * prob) into four equal, contiguous bands, one per kind of
    # mistake. Cumulative upper bounds in an elif chain leave no gaps at the
    # band edges; anything at or above the last bound means "no mistake".
    prob = error_rate / 4.0
    if rand < prob:
        # Replace a character with a random character.
        idx = np.random.randint(len(tokn))
        tokn = (tokn[:idx]
                + np.random.choice(CHARS if chars is None else chars)
                + tokn[idx + 1:])
    elif rand < prob * 2:
        # Delete a character.
        idx = np.random.randint(len(tokn))
        tokn = tokn[:idx] + tokn[idx + 1:]
    elif rand < prob * 3:
        # Add a random character before position idx.
        idx = np.random.randint(len(tokn))
        tokn = (tokn[:idx]
                + np.random.choice(CHARS if chars is None else chars)
                + tokn[idx:])
    elif rand < prob * 4:
        # Transpose 2 adjacent characters; idx + 1 must stay in range.
        idx = np.random.randint(len(tokn) - 1)
        tokn = tokn[:idx] + tokn[idx + 1] + tokn[idx] + tokn[idx + 2:]
    return tokn

此函数将输入的标记转换为编码器输入、解码器输入和目标序列。

    def transform(tokens, maxlen, error_rate=0.3, shuffle=True,
                  scramble_words=True):
        """Build encoder, decoder and target sequences from input tokens.

        Each input item yields one training triple:

        * encoder: the item with its whitespace-separated words shuffled
          (when ``scramble_words`` is true) and possibly one artificial
          spelling error injected, right-padded with ``EOS``;
        * decoder: ``SOS`` followed by the original, unscrambled and
          error-free item, right-padded with ``EOS``;
        * target: the decoder sequence without its first element,
          right-padded with ``EOS``.

        Padding uses ``EOS * (maxlen - len(seq))``. It reaches exactly
        ``maxlen`` only if ``EOS`` is a single character (assumed, not
        visible here — TODO confirm). The assert at the end of each
        iteration requires the three sequences to end up the same length.

        Args:
            tokens: Sequence of strings (single words or whole sentences).
                The caller's sequence is not modified; shuffling works on a
                copy.
            maxlen: Target length the sequences are padded to.
            error_rate: Probability passed to ``add_speling_erors``.
            shuffle: If true, the order of the items is shuffled first.
            scramble_words: If true, the words of each item are shuffled with
                ``scramble`` for the encoder input only.

        Returns:
            Tuple ``(encoder_tokens, decoder_tokens, target_tokens)`` of
            equally long lists of strings.
        """
        # Work on a copy so the caller's sequence is not shuffled in place.
        tokens = list(tokens)
        if shuffle:
            print('Shuffling data.')
            np.random.shuffle(tokens)

        encoder_tokens = []
        decoder_tokens = []
        target_tokens = []
        for token in tokens:
            # Only the encoder sees the scrambled word order. The decoder and
            # target below are built from the original, unscrambled token.
            source = scramble(token) if scramble_words else token
            encoder = add_speling_erors(source, error_rate=error_rate)
            encoder += EOS * (maxlen - len(encoder))  # Padded to maxlen.
            encoder_tokens.append(encoder)

            decoder = SOS + token
            decoder += EOS * (maxlen - len(decoder))
            decoder_tokens.append(decoder)

            # Target is the decoder input shifted left by one position.
            target = decoder[1:]
            target += EOS * (maxlen - len(target))
            target_tokens.append(target)

            assert(len(encoder) == len(decoder) == len(target))
        return encoder_tokens, decoder_tokens, target_tokens

'''
4

0 回答 0