gpt4 book ai didi

python - 如何 "one hot encode"Tensorflow 数据集?

转载 作者:太空宇宙 更新时间:2023-11-03 11:59:31 24 4
gpt4 key购买 nike

Newby here...我按如下方式加载了TF数据集:

dataset = tf.data.TFRecordDataset(files)
dataset.map(extract_fn)

数据集包含一个带有一些值的“字符串列”,我想对它们进行“一次性”编码。如果我有索引和深度(目前我只有一个字符串值),我可以在 extract_fn 中逐条记录地执行此操作。但是,是否有一个 TF 函数可以为我做到这一点?即

  • 计算不同值的数量
  • 将每个值映射到一个索引
  • 为此创建一个单热编码列

最佳答案

我认为这可以满足您的需求:

import tensorflow as tf
def one_hot_any(a):
# Save original shape
s = tf.shape(a)
# Find unique values
values, idx = tf.unique(tf.reshape(a, [-1]))
# One-hot encoding
n = tf.size(values)
a_1h_flat = tf.one_hot(idx, n)
# Reshape to original shape
a_1h = tf.reshape(a_1h_flat, tf.concat([s, [n]], axis=0))
return a_1h, values

# Test
x = tf.constant([['a', 'b'], ['a', 'd'], ['c', 'd'], ['b', 'd']])
x_1h, x_vals = one_hot_any(x)
with tf.Session() as sess:
print(*sess.run([x_1h, x_vals]), sep='\n')

输出:

[[[1. 0. 0. 0.]
[0. 1. 0. 0.]]

[[1. 0. 0. 0.]
[0. 0. 1. 0.]]

[[0. 0. 0. 1.]
[0. 0. 1. 0.]]

[[0. 1. 0. 0.]
[0. 0. 1. 0.]]]
[b'a' b'b' b'd' b'c']

但问题是,不同的输入会产生不一致的输出,具有不同的值顺序甚至不同的 one-hot 深度,所以我不确定它是否真的有用。

关于python - 如何 "one hot encode"Tensorflow 数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53560541/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com