gpt4 book ai didi

具有多处理功能的 Tensorflow2.x 自定义数据生成器

转载 作者:行者123 更新时间:2023-12-02 02:35:37 27 4
gpt4 key购买 nike

我刚刚升级到tensorflow 2.3。我想制作自己的数据生成器用于训练。使用tensorflow 1.x,我这样做了:

def get_data_generator(test_flag):
item_list = load_item_list(test_flag)
print('data loaded')
while True:
X = []
Y = []
for _ in range(BATCH_SIZE):
x, y = get_random_augmented_sample(item_list)
X.append(x)
Y.append(y)
yield np.asarray(X), np.asarray(Y)

data_generator_train = get_data_generator(False)
data_generator_test = get_data_generator(True)
model.fit_generator(data_generator_train, validation_data=data_generator_test,
epochs=10000, verbose=2,
use_multiprocessing=True,
workers=8,
validation_steps=100,
steps_per_epoch=500,
)

这段代码在tensorflow 1.x 上运行良好。系统中创建了8个进程。处理器和显卡加载完美。 “数据已加载”打印了 8 次。

使用tensorflow 2.3我收到警告:

WARNING: tensorflow: multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.

“数据加载”被打印一次(应该是8次)。 GPU 没有得到充分利用。每个 epoch 都会有内存泄漏,因此训练会在几个 epoch 后停止。 use_multiprocessing 标志没有帮助。

如何在tensorflow(keras) 2.x中制作一个可以轻松跨多个CPU进程并行化的生成器/迭代器?死锁和数据顺序并不重要。

最佳答案

使用 tf.data 管道,您可以在多个位置进行并行化。根据数据的存储和读取方式,您可以并行读取。您还可以并行化增强,并且可以在训练时预取数据,因此您的 GPU(或其他硬件)永远不会渴望数据。

在下面的代码中,我演示了如何并行化增强并添加预取。

import numpy as np
import tensorflow as tf

x_shape = (32, 32, 3)
y_shape = () # A single item (not array).
classes = 10

# This is tf.data.experimental.AUTOTUNE in older tensorflow.
AUTOTUNE = tf.data.AUTOTUNE

def generator_fn(n_samples):
"""Return a function that takes no arguments and returns a generator."""
def generator():
for i in range(n_samples):
# Synthesize an image and a class label.
x = np.random.random_sample(x_shape).astype(np.float32)
y = np.random.randint(0, classes, size=y_shape, dtype=np.int32)
yield x, y
return generator

def augment(x, y):
return x * tf.random.normal(shape=x_shape), y

samples = 10
batch_size = 5
epochs = 2

# Create dataset.
gen = generator_fn(n_samples=samples)
dataset = tf.data.Dataset.from_generator(
generator=gen,
output_types=(np.float32, np.int32),
output_shapes=(x_shape, y_shape)
)
# Parallelize the augmentation.
dataset = dataset.map(
augment,
num_parallel_calls=AUTOTUNE,
# Order does not matter.
deterministic=False
)
dataset = dataset.batch(batch_size, drop_remainder=True)
# Prefetch some batches.
dataset = dataset.prefetch(AUTOTUNE)

# Prepare model.
model = tf.keras.applications.VGG16(weights=None, input_shape=x_shape, classes=classes)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Train. Do not specify batch size because the dataset takes care of that.
model.fit(dataset, epochs=epochs)

关于具有多处理功能的 Tensorflow2.x 自定义数据生成器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64356769/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com