gpt4 book ai didi

python - TensorFlow解码_csv形状错误

转载 作者:行者123 更新时间:2023-11-30 22:18:04 25 4
gpt4 key购买 nike

我使用 tf.data.TextLineDataset 读取 *.csv 文件并在其上应用 map:

dataset = tf.data.TextLineDataset(os.path.join(data_dir, subset, 'label.txt'))
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)

解析函数parse_record_fn如下所示:

def parse_record(raw_record, is_training):
default_record = ["./", -1]
filename, label = tf.decode_csv([raw_record], default_record)
# do something
return image, label

但是在解析函数中的tf.decode_csv处引发了ValueError:

ValueError:形状必须为排名 1,但对于“DecodeCSV”(操作:“DecodeCSV”),其排名为 0,输入形状为:[1]、[]、[]。

我的*.csv 文件示例:

/data/1.png, 5
/data/2.png, 7

问题:

  1. 哪里出了问题?
  2. shapes: [1], [], [] 是什么意思?

重现

可以在此代码中重现此错误:

import tensorflow as tf
import os

def parse_record(raw_record, is_training):
default_record = ["./", -1]
filename, label = tf.decode_csv([raw_record], default_record)

# do something

return image, label

with tf.Session() as sess:
csv_path = './labels.txt'


dataset = tf.data.TextLineDataset(csv_path)

dataset = dataset.map(lambda value: parse_record(value, True))


sess.run(dataset)

最佳答案

查看 tf.decode_csv 的文档,它说明了默认记录:

record_defaults: A list of Tensor objects with specific types. Acceptable types are float32, float64, int32, int64, string. One tensor per column of the input record, with either a scalar default value for that column or empty if the column is required.

我相信您遇到的错误源于您定义张量default_record的方式。您的 default_record 当然是张量对象(或可转换为张量的对象)的列表,但我认为错误消息表明它们应该是 1 级张量,而不是像您的情况那样是 0 级张量.

您可以通过将默认记录设置为排名 1 张量来解决该问题。请参阅以下玩具示例:

import tensorflow as tf

my_line = 'filename.png, 10'
default_record_1 = [['./'], [-1]] # do this!
default_record_2 = ['./', -1] # this is what you do now

decoded_1 = tf.decode_csv(my_line, default_record_1)
with tf.Session() as sess:
d = sess.run(decoded_1)
print(d)

# This will cause an error
decoded_2 = tf.decode_csv(my_line, default_record_2)

最后一行产生的错误很熟悉:

ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV_1' (op: 'DecodeCSV') with input shapes: [], [], [].

在消息中,输入形状,三个括号[],指的是输入参数的形状recordsrecord_defaults,和 tf.decode_csvfield_delim。在您的情况下,这些形状中的第一个是 [1],因为您输入了 [raw_record]。我同意此案例的消息信息量不大......

关于python - TensorFlow解码_csv形状错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49473963/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com