
python - TensorFlow Python: reading 2 files

Reposted. Author: 太空宇宙. Updated: 2023-11-04 02:29:37

I am trying to run the following (shortened) code:

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():

        # Run some code.... (Reading some data from file 1)

        coord_dev = tf.train.Coordinator()
        threads_dev = tf.train.start_queue_runners(sess=sess, coord=coord_dev)

        try:
            while not coord_dev.should_stop():

                # Run some other code.... (Reading data from file 2)
                pass

        except tf.errors.OutOfRangeError:
            print('Reached end of file 2')
        finally:
            coord_dev.request_stop()
            coord_dev.join(threads_dev)

except tf.errors.OutOfRangeError:
    print('Reached end of file 1')
finally:
    coord.request_stop()
    coord.join(threads)

What should happen above is:

  • File 1 is a csv file containing the training data for my neural network.
  • File 2 contains the dev set data.

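While iterating over file 1 during training, I occasionally also want to compute the cost and accuracy on the dev set data (from file 2). But when the inner loop finishes reading file 2, it obviously triggers the exception

"tf.errors.OutOfRangeError"

This causes my code to leave the outer loop as well. The inner loop's exception is simply handled as the outer loop's exception too. But after file 2 has been read, I want my code to continue training on file 1 in the outer loop.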
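The control flow the question is after can be sketched in plain Python, without TensorFlow. This is only an illustration under assumptions: `EndOfData`, `make_reader`, and `train_with_periodic_eval` are hypothetical names standing in for `tf.errors.OutOfRangeError` and the two file readers. The point it shows is that when the dev reader is rebuilt for each evaluation pass, its end-of-data signal is handled by the inner `try` and the outer training loop carries on.

```python
class EndOfData(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def make_reader(rows):
    """Return a read() function that raises EndOfData once rows run out."""
    it = iter(rows)

    def read():
        try:
            return next(it)
        except StopIteration:
            raise EndOfData from None

    return read


def train_with_periodic_eval(train_rows, dev_rows, eval_every):
    """Walk the training rows; every `eval_every` steps, walk all dev rows."""
    log = []
    read_train = make_reader(train_rows)
    step = 0
    try:
        while True:
            log.append(("train", read_train()))
            step += 1
            if step % eval_every == 0:
                # Fresh dev reader for each evaluation pass.
                read_dev = make_reader(dev_rows)
                try:
                    while True:
                        log.append(("dev", read_dev()))
                except EndOfData:
                    pass  # end of dev data: handled here, training continues
    except EndOfData:
        log.append(("done", None))  # end of training data
    return log


log = train_with_periodic_eval([1, 2, 3, 4], ["a", "b"], eval_every=2)
```

In this sketch the training rows 3 and 4 are still processed after the first dev pass ends, which is the behaviour the question wants from the TensorFlow version.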

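(I have removed some details, such as the num_epochs to train, to keep the code readable.)

Does anyone have a suggestion on how to solve this? I am fairly new to this.

Thanks in advance!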

Best answer

Solved.

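Apparently, using queue_runners is not the right way to do this. The TensorFlow documentation indicates that the Dataset API should be used instead, which took me a while to understand. The code below does what I was trying to do earlier. I am sharing it here in case someone else needs it too.

I have put some additional training code under www.github.com/loheden/tf_examples/dataset api. It took me some effort to find complete examples.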
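The idea behind an initializable iterator can be pictured with a small plain-Python analogue. This is a rough sketch, not the TensorFlow API: `InitializableReader`, `initialize`, and `drain` are hypothetical names, and in-memory lists stand in for the csv files. One reader object is pointed at a new list of sources each time it is initialized, so the same read loop can serve the training data and the dev data in turn.

```python
import itertools


class InitializableReader:
    """Hypothetical plain-Python analogue of an initializable iterator."""

    def __init__(self):
        self._rows = iter(())  # empty until initialize() is called

    def initialize(self, sources):
        # Chain every source into one stream, much like feeding a list of filenames.
        self._rows = itertools.chain.from_iterable(sources)

    def drain(self):
        # Read until the current sources are exhausted.
        return list(self._rows)


reader = InitializableReader()

reader.initialize([[1, 2], [3]])       # stand-in for two training files
train_rows = reader.drain()
after_end = reader.drain()             # nothing left until re-initialized

reader.initialize([["d1", "d2"]])      # stand-in for one dev file
dev_rows = reader.drain()
```

Running out of data only ends the current pass; re-initializing the same reader starts a new one.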

# READING DATA FROM train and validation (dev set) CSV FILES by using INITIALIZABLE ITERATORS

# All csv files have same # columns. First column is assumed to be train example ID, the next 5 columns are feature
# columns, and the last column is the label column

# ASSUMPTIONS: (Otherwise, decode_csv function needs update)
# 1) The first column is NOT a feature. (It is most probably a training example ID or similar)
# 2) The last column is always the label. And there is ONLY 1 column that represents the label.
# If more than 1 column represents the label, see the next example down below

import tensorflow as tf  # this example uses the TensorFlow 1.x API

feature_names = ['f1','f2','f3','f4','f5']
record_defaults = [[""], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]


def decode_csv(line):
    parsed_line = tf.decode_csv(line, record_defaults)
    label = parsed_line[-1]  # label is the last element of the list
    del parsed_line[-1]  # delete the last element from the list
    del parsed_line[0]  # also delete the first element, because it is assumed NOT to be a feature
    features = tf.stack(parsed_line)  # Stack features so that you can later vectorize forward prop., etc.
    #label = tf.stack(label)  # NOT needed. Only if more than 1 column makes the label...
    batch_to_return = features, label
    return batch_to_return

filenames = tf.placeholder(tf.string, shape=[None])
dataset5 = tf.data.Dataset.from_tensor_slices(filenames)
dataset5 = dataset5.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv))
dataset5 = dataset5.shuffle(buffer_size=1000)
dataset5 = dataset5.batch(7)
iterator5 = dataset5.make_initializable_iterator()
next_element5 = iterator5.get_next()

# Initialize `iterator` with training data.
training_filenames = ["train_data1.csv",
                      "train_data2.csv"]

# Initialize `iterator` with validation data.
validation_filenames = ["dev_data1.csv"]

with tf.Session() as sess:
    # Train 2 epochs. Then validate train set. Then validate dev set.
    for _ in range(2):
        sess.run(iterator5.initializer, feed_dict={filenames: training_filenames})
        while True:
            try:
                features, labels = sess.run(next_element5)
                # Train...
                print("(train) features: ")
                print(features)
                print("(train) labels: ")
                print(labels)
            except tf.errors.OutOfRangeError:
                print("Out of range error triggered (looped through training set 1 time)")
                break

    # Validate (cost, accuracy) on train set
    print("\nDone with the first iterator\n")

    sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames})
    while True:
        try:
            features, labels = sess.run(next_element5)
            # Validate (cost, accuracy) on dev set
            print("(dev) features: ")
            print(features)
            print("(dev) labels: ")
            print(labels)
        except tf.errors.OutOfRangeError:
            print("Out of range error triggered (looped through dev set 1 time only)")
            break
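To make the column layout assumed by decode_csv concrete, here is a minimal sketch using only Python's standard `csv` module rather than TensorFlow. The function name `split_row` and the sample line are made up for illustration. It mirrors the stated assumptions: the first column is an example ID, the last column is the single label, and everything in between is a feature.

```python
import csv
import io


def split_row(line):
    """Split one csv line into (example_id, features, label)."""
    fields = next(csv.reader(io.StringIO(line)))
    example_id = fields[0]                       # first column: ID, not a feature
    features = [float(x) for x in fields[1:-1]]  # middle columns: features
    label = float(fields[-1])                    # last column: the label
    return example_id, features, label


# Hypothetical sample row: an ID, five feature values, one label.
example_id, features, label = split_row("ex1,0.1,0.2,0.3,0.4,0.5,1")
```

If more than one column made up the label, the two slices would need to change accordingly, just as the comments above note for decode_csv.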

Regarding python - TensorFlow Python: reading 2 files, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/49525056/
