gpt4 book ai didi

python - 使用 tf.data 的 One-hot 编码会混淆列

转载 作者:行者123 更新时间:2023-12-01 06:27:26 24 4
gpt4 key购买 nike

最少工作示例

考虑以下 CSV 文件 ( example.csv )

animal,size,weight,category
lion,large,200,mammal
ostrich,large,150,bird
sparrow,small,0.1,bird
whale,large,3000,mammal
bat,small,0.2,mammal
snake,small,1,reptile
condor,medium,12,bird

目标是将所有分类值转换为 one-hot 编码。 standard在 Tensorflow 2.0 中执行此操作的方法是使用 tf.data 。按照该示例,处理上述数据集的代码是

import collections
import tensorflow as tf

# Load the dataset.
dataset = tf.data.experimental.make_csv_dataset(
'example.csv',
batch_size=5,
num_epochs=1,
shuffle=False)

# Specify the vocabulary for each category.
categories = collections.OrderedDict()
categories['animal'] = ['lion', 'ostrich', 'sparrow', 'whale', 'bat', 'snake', 'condor']
categories['size'] = ['large', 'medium', 'small']
categories['category'] = ['mammal', 'reptile', 'bird']

# Define the categorical feature columns.
categorical_columns = []
for feature, vocab in categories.items():
cat_col = tf.feature_column.categorical_column_with_vocabulary_list(
key=feature, vocabulary_list=vocab)
categorical_columns.append(tf.feature_column.indicator_column(cat_col))

# Retrieve the first batch and apply the one-hot encoding to it.
iterator = iter(dataset)
first_batch = next(iterator)
categorical_layer = tf.keras.layers.DenseFeatures(categorical_columns)

print(categorical_layer(first_batch).numpy())

问题

运行上面的代码,得到

[[1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1.]
[0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1.]]

看起来像最后两列 sizecategory尽管 categories 已被翻转是一个有序字典以及实际数据集中列的预先存在的顺序。就好像tf.feature_column.categorical_column_with_vocabulary_list()对列进行了一些不必要的字母排序。

出现上述情况的原因是什么?这真的是本着tf.data的精神进行one-hot编码的最佳方式吗? ?

最佳答案

排序在哪里?

tf.feature_column.categorical_column_with_vocabulary_list() 处未进行排序。如果打印categorical_columns,您将看到这些列仍然按照您将它们添加到feature_column 的顺序排列:

[
IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='animal', vocabulary_list=('lion', 'ostrich', 'sparrow', 'whale', 'bat', 'snake', 'condor'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),
IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='size', vocabulary_list=('large', 'medium', 'small'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),
IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='category', vocabulary_list=('mammal', 'reptile', 'bird'), dtype=tf.string, default_value=-1, num_oov_buckets=0))
]

排序发生在 tf.keras.layers.DenseFeatures对象。

在代码中,您可以看到排序发生在哪里 here (我通过跟踪从 tf.keras.layers.DenseFeatures 类到 tensorflow.python.feature_column.dense_features.DenseFeatures 类到 tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer 类到 _normalize_feature_columns 函数的类继承发现了这一点)。

为什么要排序?

那么为什么要排序呢? Elsewhere在包含 _normalize_feature_columns 函数(对数据进行排序的函数)的同一文件中,有一个类似的排序函数,并带有以下注释:

# Sort the columns so the default collection name is deterministic even if the
# user passes columns from an unsorted collection, such as dict.values().

我认为这个解释也适用于为什么在使用 tf.keras.layers.DenseFeatures 类时对列进行排序。您的列和数据是一致的,但 TensorFlow 并不假设输入是一致的,因此它会对输入进行排序以确保顺序一致。

关于python - 使用 tf.data 的 One-hot 编码会混淆列,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60064351/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com