gpt4 book ai didi

python - 如何正确组合 tf.data.Dataset 和 tf.estimator.DNNRegressor

转载 作者:行者123 更新时间:2023-11-30 09:15:06 25 4
gpt4 key购买 nike

我目前正在学习使用tensorflow,但在入门时遇到困难。我想使用最新的 API,即估算器和数据集。但如果我运行下面提供的代码,我会收到错误。

在tensorflow页面上https://www.tensorflow.org/api_docs/python/tf/estimator/DNNRegressor我发现,“该函数应该构造并返回以下内容之一:* tf.data.Dataset 对象:Dataset 对象的输出必须是具有与下面相同的约束的元组(特征、标签)。”

我以为我的代码可以提供这一点,但似乎有问题,我没有主意。

import tensorflow as tf
def input_evaluation_set():
data = [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]
labels = []
for d in data:
labels.append(1)
return tf.data.Dataset.from_tensor_slices((tf.constant(data), tf.constant(labels)))

point = tf.feature_column.numeric_column('points')
estimator = tf.estimator.DNNRegressor(feature_columns = [point],hidden_units = [100,100,100])

estimator.train(input_fn = input_evaluation_set)

我希望在具有 3 个隐藏层、100 个神经元的深度神经网络上运行训练,以逼近“常数 1”函数;相反,我收到错误“ValueError:功能应该是“张量”的字典。给定类型:类,“tensorflow.python.framework.ops.Tensor”

最佳答案

代码中的主要问题是您将数据集中的数据作为简单张量发送。但数据集中输入的数据应该是字典,其键名与特征列中使用的键名相同。除此之外,我在输入数据中添加了额外的维度。以下代码将起作用。

import tensorflow as tf
import numpy as np

### DEFINE NEW MAP FUNCTION
def map_fn(d, l):
return {'points': d}, l

def input_evaluation_set():
data = [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]
labels = []
for d in data:
labels.append(1)

### CHANGE STARTS HERE
data = np.array(data)
data = np.expand_dims(data, axis=-1)
labels = np.array(labels)
labels = np.expand_dims(labels, axis=-1)
### CHANGE ENDS HERE

dataset = tf.data.Dataset.from_tensor_slices((tf.constant(data), tf.constant(labels)))

### CREATE DICTIONARY PAIR IN INPUT DATA
dataset = dataset.map(map_fn)
return dataset

point = tf.feature_column.numeric_column('points')
estimator = tf.estimator.DNNRegressor(feature_columns = [point],hidden_units = [100,100,100])

estimator.train(input_fn = input_evaluation_set)

关于python - 如何正确组合 tf.data.Dataset 和 tf.estimator.DNNRegressor,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57823210/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com