gpt4 book ai didi

python - 如何读取一个 hdf5 数据文件中的批处理进行训练?

转载 作者:太空狗 更新时间:2023-10-29 22:04:03 24 4
gpt4 key购买 nike

我有一个大小为 (21760, 1, 33, 33) 的 hdf5 训练数据集。 21760 是训练样本的总数。我想使用大小为 128 的小批量训练数据来训练网络。

我想问:

如何使用 tensorflow 每次从整个数据集中提供 128 mini-batch 训练数据?

最佳答案

如果你的数据集太大以至于无法像keveman建议的那样导入内存,你可以直接使用h5py对象:

import h5py
import tensorflow as tf

data = h5py.File('myfile.h5py', 'r')
data_size = data['data_set'].shape[0]
batch_size = 128
sess = tf.Session()
train_op = # tf.something_useful()
input = # tf.placeholder or something
for i in range(0, data_size, batch_size):
current_data = data['data_set'][position:position+batch_size]
sess.run(train_op, feed_dict={input: current_data})

如果您愿意,您还可以运行大量迭代并随机选择一个批处理:

import random
for i in range(iterations):
pos = random.randint(0, int(data_size/batch_size)-1) * batch_size
current_data = data['data_set'][pos:pos+batch_size]
sess.run(train_op, feed_dict={inputs=current_data})

或依次:

for i in range(iterations):
pos = (i % int(data_size / batch_size)) * batch_size
current_data = data['data_set'][pos:pos+batch_size]
sess.run(train_op, feed_dict={inputs=current_data})

您可能想要编写一些更复杂的代码来随机遍历所有数据,但要跟踪已使用的批处理,因此您不会比其他批处理更频繁地使用任何批处理。完整运行训练集后,再次启用所有批处理并重复。

关于python - 如何读取一个 hdf5 数据文件中的批处理进行训练?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38225770/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com