
python - TensorFlow CSV parsing: iterator shifts by one row per value

Reposted. Author: 行者123. Updated: 2023-12-01 02:20:21

I am following the wide_deep tutorial, but I am having trouble reproducing the example that reads a CSV correctly.

This is my code for generating a dummy CSV:

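import pandas as pd

data = pd.DataFrame({
    'y': [1, 2, 3],
    'x1': [4, 5, 6],
    'x2': [7.0, 8.0, 9.0],
    'x3': ['ten', 'eleven', 'twelve']
})
file_path = 'tmp.csv'
# Write the columns in the order the parser expects (x1, x2, x3, y);
# newer pandas keeps dict insertion order instead of sorting the names.
data.to_csv(file_path, index=False, header=False,
            columns=['x1', 'x2', 'x3', 'y'])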

This is what the CSV looks like:

[Image: screenshot of tmp.csv]

Then I try to read the file with the following:

import tensorflow as tf  # TensorFlow 1.x API

def parse_csv(line):
    _CSV_COLUMNS = ['x1', 'x2', 'x3', 'y']
    # One default per column; each default also fixes the column's dtype.
    defaults = [[0], [0.0], [''], [0]]
    columns = tf.decode_csv(line, record_defaults=defaults)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('y')
    return features, tf.equal(labels, 3)

dataset = tf.data.TextLineDataset(file_path)
dataset = dataset.map(parse_csv)

iterator = dataset.make_one_shot_iterator()

# v.eval() needs a default session, e.g. tf.InteractiveSession().
for i in range(3):
    features, labels = iterator.get_next()
    for k, v in features.items():
        print(k, v.eval())
    print('-' * 50)

The output is as follows:

x1 4
x2 8.0
x3 b'twelve'
--------------------------------------------------
<error message: OutOfRangeError (see above for traceback): End of sequence>

Why is the first block not 4, 7.0, 'ten'?

Best answer

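The problem you are facing is that every call to v.eval() advances the iterator for all components, so each value you print comes from a different row. From the docs: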

Note that evaluating any of next1, next2, or next3 will advance the iterator for all components. A typical consumer of an iterator will include all components in a single expression.

One way to get what you want:

Code:

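# Assumes `dataset` from the question and an active session,
# e.g. sess = tf.InteractiveSession().
iterator = dataset.make_one_shot_iterator()
# Build the get_next() op once, outside the loop.
features, labels = iterator.get_next()

for i in range(3):
    # A single sess.run() fetches every feature from the same row.
    for k, v in sess.run(features).items():
        print(k, v)
    print('-' * 50)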
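
The difference between the two loops can be illustrated without TensorFlow. The following is only a toy model, assuming that each evaluation of a single component pulls a fresh row; ToyIterator, eval_component and run_all are hypothetical names invented for this sketch and are not part of any TensorFlow API.

```python
# Toy model of a one-shot iterator: every evaluation pulls a fresh row.
# This is an analogy in plain Python, not TensorFlow code.
ROWS = [
    {"x1": 4, "x2": 7.0, "x3": "ten"},
    {"x1": 5, "x2": 8.0, "x3": "eleven"},
    {"x1": 6, "x2": 9.0, "x3": "twelve"},
]


class ToyIterator:
    """Hypothetical stand-in for the iterator from make_one_shot_iterator()."""

    def __init__(self, rows):
        self._rows = iter(rows)

    def eval_component(self, key):
        # Mimics v.eval(): advance to the next row, return one field of it.
        return next(self._rows)[key]

    def run_all(self):
        # Mimics sess.run(features): advance once, return the whole row.
        return next(self._rows)


# Pattern from the question: one evaluation per component.
it = ToyIterator(ROWS)
shifted = {key: it.eval_component(key) for key in ("x1", "x2", "x3")}
print(shifted)  # each value comes from a different row

# Pattern from the answer: all components in a single evaluation.
it = ToyIterator(ROWS)
aligned = [it.run_all() for _ in range(3)]
print(aligned[0])  # all values come from the first row
```

In this model the first pattern uses up all three rows while printing a single block, which is why the next pass in the question's loop hits the end of the sequence.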

Test code:

import pandas as pd
import tensorflow as tf  # TensorFlow 1.x API

sess = tf.InteractiveSession()

data = pd.DataFrame({
    'y': [1, 2, 3],
    'x1': [4, 5, 6],
    'x2': [7.0, 8.0, 9.0],
    'x3': ['ten', 'eleven', 'twelve']
})
file_path = 'tmp.csv'
# Write the columns in the order parse_csv expects (x1, x2, x3, y);
# newer pandas keeps dict insertion order instead of sorting the names.
data.to_csv(file_path, index=False, header=False,
            columns=['x1', 'x2', 'x3', 'y'])

def parse_csv(line):
    _CSV_COLUMNS = ['x1', 'x2', 'x3', 'y']
    # One default per column; each default also fixes the column's dtype.
    defaults = [[0], [0.0], [''], [0]]
    columns = tf.decode_csv(line, record_defaults=defaults)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('y')
    return features, tf.equal(labels, 3)

dataset = tf.data.TextLineDataset(file_path)
dataset = dataset.map(parse_csv)

iterator = dataset.make_one_shot_iterator()
# Build the get_next() op once; each sess.run() then yields one full row.
features, labels = iterator.get_next()

for i in range(3):
    for k, v in sess.run(features).items():
        print(k, v)
    print('-' * 50)

Results:

x1 4
x2 7.0
x3 b'ten'
--------------------------------------------------
x1 5
x2 8.0
x3 b'eleven'
--------------------------------------------------
x1 6
x2 9.0
x3 b'twelve'
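
To cross-check these rows without a TensorFlow session, the per-line logic of parse_csv can be mimicked in plain Python. This is a rough sketch under stated assumptions: parse_line, CASTS and CSV_TEXT are hypothetical names, the CSV text is typed inline on the assumption that the columns are written in the order x1, x2, x3, y, and the standard csv module stands in for tf.decode_csv.

```python
import csv
import io

# Assumed contents of tmp.csv, in the column order the parser expects
# (x1, x2, x3, y). Written inline so the sketch needs no files.
CSV_TEXT = "4,7.0,ten,1\n5,8.0,eleven,2\n6,9.0,twelve,3\n"

CSV_COLUMNS = ["x1", "x2", "x3", "y"]
# Plays the role of the dtypes implied by record_defaults.
CASTS = [int, float, str, int]


def parse_line(fields):
    """Hypothetical plain-Python counterpart of parse_csv for one row."""
    values = [cast(field) for cast, field in zip(CASTS, fields)]
    features = dict(zip(CSV_COLUMNS, values))
    label = features.pop("y")
    # Same label rule as the original: True only when y equals 3.
    return features, label == 3


parsed = [parse_line(row) for row in csv.reader(io.StringIO(CSV_TEXT))]
for features, label in parsed:
    for key, value in features.items():
        print(key, value)
    print("-" * 50)
```

Unlike the TensorFlow version, the strings here are ordinary Python str values rather than bytes, so x3 prints without the b prefix.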

Regarding python - TensorFlow CSV parsing: iterator shifts by one row per value, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/48029704/
