1

我正在尝试将 tf.data.Dataset 的标签转换为一个热编码标签。我正在使用这个数据集。我在列中添加了标题(情感、文本),其他一切都是原创的。

这是我用来将标签(正面、负面、中性)编码为一个热点 (3,) 的代码:

def _map_func(text, labels):
   labels_enc = []
   for label in labels:
      if label=='negative':
         label = -1
      elif label=='neutral':
         label = 0
      else: 
         label = 1

      label = tf.one_hot(
         label, 3, name='label', axis=-1)

      labels_enc.append(label)

   return text, labels_enc

raw_train_ds = tf.data.experimental.make_csv_dataset(
   './data/sentiment_data/train.csv', BATCH_SIZE, column_names=['sentiment', 'text'],
   label_name='sentiment', header=True
)

train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)

train_ds = train_ds.map(_map_func)

我收到错误消息:ValueError: Value [<tf.Tensor 'while/label:0' shape=(3,) dtype=float32>] is not convertible to a tensor with dtype <dtype: 'float32'> and shape (1, 3).

标签的第二个参数_map_func(text, label)的形状为 (64,) type=string。

如果我正确理解了 tensorflows tf.data.Dataset.map 函数,它会使用转换函数应用的转换创建一个新数据集。但是由于错误指出标签的列不能从具有一个字符串的列转换为具有包含 3 个浮点数的列表的列。有没有办法强制新列的类型接受编码标签?

谢谢您的帮助 :)

4

2 回答 2

2

映射函数适用于每个元素,因此您无需创建列表并循环遍历批处理项。仅对一个样本进行尝试:

def _map_func(text, label):
    if label=='negative':
        label = -1
    elif label=='neutral':
        label = 0
    else: 
        label = 1

    label = tf.one_hot(label, 3, name='label', axis=-1)

   return text, label
于 2020-12-16T17:56:03.510 回答
1

我通过使用 TensorFlow TensorArray 解决了这个问题,如下所示:

def _map_func(text, labels):
    i=0
    labels_enc = tf.TensorArray(tf.float32, size=0, dynamic_size=True,
        clear_after_read=False)
    for label in labels:
        if label=='negative':
            label = tf.constant(-1)
        elif label=='neutral':
            label = tf.constant(0)
        else: 
            label = tf.constant(1)

        label = tf.one_hot(
            label, 3, name='label', axis=-1)

        labels_enc.write(i, label)
            i = i+1

    return text, labels_enc.concat()
于 2020-12-16T21:14:58.803 回答