我正在尝试将 tf.data.Dataset 的标签转换为一个热编码标签。我正在使用这个数据集。我在列中添加了标题(情感、文本),其他一切都是原创的。
这是我用来将标签(正面、负面、中性)编码为一个热点 (3,) 的代码:
def _map_func(text, labels):
labels_enc = []
for label in labels:
if label=='negative':
label = -1
elif label=='neutral':
label = 0
else:
label = 1
label = tf.one_hot(
label, 3, name='label', axis=-1)
labels_enc.append(label)
return text, labels_enc
raw_train_ds = tf.data.experimental.make_csv_dataset(
'./data/sentiment_data/train.csv', BATCH_SIZE, column_names=['sentiment', 'text'],
label_name='sentiment', header=True
)
train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)
train_ds = train_ds.map(_map_func)
我收到错误消息:ValueError: Value [<tf.Tensor 'while/label:0' shape=(3,) dtype=float32>] is not convertible to a tensor with dtype <dtype: 'float32'> and shape (1, 3).
标签的第二个参数_map_func(text, label)
的形状为 (64,) type=string。
如果我正确理解了 tensorflows tf.data.Dataset.map 函数,它会使用转换函数应用的转换创建一个新数据集。但是由于错误指出标签的列不能从具有一个字符串的列转换为具有包含 3 个浮点数的列表的列。有没有办法强制新列的类型接受编码标签?
谢谢您的帮助 :)