gpt4 book ai didi

machine-learning - 如何将 LMDB 文件加载到 TensorFlow 中?

转载 作者:行者123 更新时间:2023-11-30 08:27:58 25 4
gpt4 key购买 nike

我有大量 (1 TB) 数据集,分为约 3,000 个 CSV 文件。我的计划是将其转换为一个大型 LMDB 文件,以便可以快速读取它以训练神经网络。但是,我无法找到任何有关如何将 LMDB 文件加载到 TensorFlow 中的文档。有谁知道如何做到这一点?我知道 TensorFlow 可以读取 CSV 文件,但我认为那会太慢。

最佳答案

根据this在 TensorFlow 中读取数据的方法有多种。

最简单的一种是通过 占位符提供数据。使用 占位符时 - 洗牌和批处理的责任由您承担。

如果您想将洗牌和批处理委托(delegate)给框架,那么您需要创建一个输入管道。问题是这样的 - 如何将 lmdb 数据注入(inject)符号输入管道。一种可能的解决方案是使用 tf.py_func 操作。这是一个例子:

def create_input_pipeline(lmdb_env, keys, num_epochs=10, batch_size=64):
key_producer = tf.train.string_input_producer(keys,
num_epochs=num_epochs,
shuffle=True)
single_key = key_producer.dequeue()

def get_bytes_from_lmdb(key):
with lmdb_env.begin() as txn:
lmdb_val = txn.get(key)
example = get_example_from_val(lmdb_val) # A single example (numpy array)
label = get_label_from_val(lmdb_val) # The label, could be a scalar
return example, label

single_example, single_label = tf.py_func(get_bytes_from_lmdb,
[single_key], [tf.float32, tf.float32])
# if you know the shapes of the tensors you can set them here:
# single_example.set_shape([224,224,3])

batch_examples, batch_labels = tf.train.batch([single_example, single_label],
batch_size)
return batch_examples, batch_labels

tf.py_func 操作会在 TensorFlow 图内插入对常规 Python 代码的调用,我们需要指定输入以及输出的数量和类型。 tf.train.string_input_ Producer 使用给定的键创建一个打乱的队列。 tf.train.batch 操作创建另一个包含批量数据的队列。训练时,batch_examplesbatch_labels 的每次评估都会使另一个批处理从该队列中出列。

因为我们创建了队列,所以在开始训练之前我们需要小心并运行 QueueRunner 对象。这样做是这样的(来自 TensorFlow 文档):

# Create the graph, etc.
init_op = tf.initialize_all_variables()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
while not coord.should_stop():
# Run training steps or whatever
sess.run(train_op)

except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
# When done, ask the threads to stop.
coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

关于machine-learning - 如何将 LMDB 文件加载到 TensorFlow 中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37337523/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com