
dataset - tensorflow : convert PrefetchDataset to BatchDataset

Reposted. Author: 行者123. Last updated: 2023-12-05 07:00:42


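Using the latest TensorFlow version, 2.3.1, I am trying to follow this basic text classification example: https://www.tensorflow.org/tutorials/keras/text_classification. Instead of creating the dataset from a directory as the example does, I am using a CSV file: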

import tensorflow as tf

SELECT_COLUMNS = ['SentimentText', 'Sentiment']
LABEL_COLUMN = 'Sentiment'
LABELS = [0, 1]

def get_dataset(file_path, **kwargs):
    # Builds a batched dataset of (features dict, label) pairs from the CSV.
    dataset = tf.data.experimental.make_csv_dataset(
        file_path,
        batch_size=3,  # Artificially small to make examples easier to show.
        label_name=LABEL_COLUMN,
        na_value="?",
        num_epochs=1,
        ignore_errors=True,
        **kwargs)
    return dataset

all_data = get_dataset(data_path, select_columns=SELECT_COLUMNS)

As a result I get:

type(all_data)
tensorflow.python.data.ops.dataset_ops.PrefetchDataset
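The class name says little about what each element contains. As a hedged sketch (the temporary CSV and its two rows are made up for illustration and are not the asker's data), inspecting `element_spec` shows the structure that `make_csv_dataset` produces when `label_name` is set: a pair of (mapping of feature columns, label) per batch.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical two-row CSV standing in for the real file (illustration only).
csv_text = "SentimentText,Sentiment\ngood film,1\nbad film,0\n"
csv_path = os.path.join(tempfile.mkdtemp(), "sample.csv")
with open(csv_path, "w") as f:
    f.write(csv_text)

sample_ds = tf.data.experimental.make_csv_dataset(
    csv_path,
    batch_size=2,
    label_name="Sentiment",
    select_columns=["SentimentText", "Sentiment"],
    num_epochs=1,
    shuffle=False)

# element_spec describes one element of the dataset: (features, label).
features_spec, label_spec = sample_ds.element_spec
print(type(features_spec))          # a dict-like mapping, not a single tensor
print(list(features_spec.keys()))   # the remaining feature column names
print(label_spec)                   # spec of the label tensor
```

Here the features part is a mapping keyed by column name, even though only one feature column is left after the label is split off.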

The example loads its data from a directory:

batch_size = 32
seed = 42

raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    'aclImdb/train',
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)

and gets a different type of dataset:

type(raw_train_ds)
tensorflow.python.data.ops.dataset_ops.BatchDataset

Now, when I try to standardize and vectorize the data using the functions from the example:

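import re
import string

# Import path used by TF 2.3, the version mentioned above.
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

def custom_standardization(input_data):
    # Lowercase, replace HTML line breaks with spaces, then strip punctuation.
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
    return tf.strings.regex_replace(stripped_html,
                                    '[%s]' % re.escape(string.punctuation),
                                    '')

max_features = 10000
sequence_length = 250

vectorize_layer = TextVectorization(
    standardize=custom_standardization,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length)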
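To make the standardization step concrete, here is a small sketch that applies the same function to one made-up review string (the sample text is an assumption chosen for illustration). It lowercases the text, turns the `<br />` tag into a space, and removes punctuation.

```python
import re
import string

import tensorflow as tf

def custom_standardization(input_data):
    # Same steps as in the listing above: lowercase, drop <br />, strip punctuation.
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
    return tf.strings.regex_replace(stripped_html,
                                    '[%s]' % re.escape(string.punctuation),
                                    '')

sample = tf.constant(["Great movie!<br />Loved it."])  # made-up review text
cleaned = custom_standardization(sample)
print(cleaned.numpy())  # [b'great movie loved it']
```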

I get an error when applying them to my dataset:

# Make a text-only dataset (without labels), then call adapt
train_text = all_data.map(lambda x, y: x)
vectorize_layer.adapt(train_text)

---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-20-1f1fc445912d> in <module>
1 # Make a text-only dataset (without labels), then call adapt
2 train_text = all_data.map(lambda x, y: x)
----> 3 vectorize_layer.adapt(train_text)

/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/layers/preprocessing/text_vectorization.py in adapt(self, data, reset_state)
378 shape = dataset_ops.get_legacy_output_shapes(data)
379 if not isinstance(shape, tensor_shape.TensorShape):
--> 380 raise ValueError("The dataset passed to 'adapt' must contain a single "
381 "tensor value.")
382 if shape.rank == 0:

ValueError: The dataset passed to 'adapt' must contain a single tensor value.

How can I convert a PrefetchDataset to a BatchDataset?

Best answer

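You can use the tf.stack method to pack the features into a single array. The function below comes from the "Custom training: walkthrough" tutorial in the TensorFlow documentation.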

def pack_features_vector(features, labels):
    features = tf.stack(list(features.values()), axis=1)
    return features, labels

all_data = get_dataset(data_path, select_columns=SELECT_COLUMNS)

train_dataset = all_data.map(pack_features_vector)

train_text = train_dataset.map(lambda x, y: x)

vectorize_layer.adapt(train_text)
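The sketch below illustrates why this works, under stated assumptions: a small in-memory dataset (hypothetical rows) stands in for the CSV dataset and has the same (features mapping, label) structure. It shows the shape produced by the `tf.stack` approach and, as an alternative that is not part of the answer above, selects the single text column by name. Either way the dataset passed to `adapt` contains one tensor per element instead of a mapping.

```python
import tensorflow as tf

try:
    from tensorflow.keras.layers import TextVectorization  # newer TF releases
except ImportError:
    # Location used by TF 2.3, the version mentioned in the question.
    from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

# Hypothetical stand-in for the CSV dataset: (dict of feature columns, label).
features = {"SentimentText": ["good film", "bad film", "great plot", "dull plot"]}
labels = [1, 0, 1, 0]
all_data = tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)

def pack_features_vector(features, labels):
    # Stack every feature column into one tensor of shape (batch, num_columns).
    features = tf.stack(list(features.values()), axis=1)
    return features, labels

# Approach from the answer: each element becomes a single (batch, 1) string tensor.
packed_text = all_data.map(pack_features_vector).map(lambda x, y: x)
print(packed_text.element_spec)

# Alternative (an assumption, not from the answer): pick the text column by name,
# which yields a (batch,) string tensor.
column_text = all_data.map(lambda feats, label: feats["SentimentText"])
print(column_text.element_spec)

# adapt() now receives a dataset whose elements are single tensors.
demo_layer = TextVectorization(max_tokens=100, output_mode="int",
                               output_sequence_length=4)
demo_layer.adapt(column_text)
vocabulary = demo_layer.get_vocabulary()
print(vocabulary)
```

Note that `tf.stack` requires all stacked columns to share one dtype, so the packing approach fits a dataset like this one where only text columns remain after the label is split off. Selecting the column by name avoids that constraint when other columns are present.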

For "dataset - tensorflow : convert PrefetchDataset to BatchDataset", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/64068620/
