gpt4 book ai didi

python - 如何使用状态为 Keras fit_generator 编写生成器?

转载 作者:太空宇宙 更新时间:2023-11-03 13:08:20 25 4
gpt4 key购买 nike

我正在尝试将大型数据集提供给 keras 模型。数据集不适合内存。它目前存储为一系列 hd5f 文件

我想训练我的模型使用

model.fit_generator(my_gen, steps_per_epoch=30, epochs=10, verbose=1)

但是,在我可以在网上找到的所有示例中,my_gen 仅用于对已加载的数据集执行数据扩充。例如

def generator(features, labels, batch_size):

# Create empty arrays to contain batch of features and labels#

batch_features = np.zeros((batch_size, 64, 64, 3))
batch_labels = np.zeros((batch_size,1))

while True:
for i in range(batch_size):
# choose random index in features
index= random.choice(len(features),1)
batch_features[i] = some_processing(features[index])
batch_labels[i] = labels[index]
yield batch_features, batch_labels

在我的例子中,它需要类似于

def generator(features, labels, batch_size):    
while True:
for i in range(batch_size):
# choose random index in features
index= # SELECT THE NEXT FILE
batch_features[i] = some_processing(features[files[index]])
batch_labels[i] = labels[file[index]]
yield batch_features, batch_labels

如何跟踪上一批中已读取的文件?

最佳答案

From the keras doc

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. [...]

这意味着您可以编写一个继承自 keras.utils.sequence 的类

class ProductSequence(keras.utils.Sequence):
def __init__(self):
pass

def __len__(self):
pass

def __getitem__(self, idx):
pass

__init__ 用于初始化类。__len__ 应该返回每个时期的批处理数。 Keras 将使用它来知道哪个索引可以传递给 __getitem____getitem__ 将根据索引返回批处理数据。可以找到一个简单的例子here

通过这种方法,您可以简单地拥有一个内部类对象,您可以在其中保存已读取的文件。

关于python - 如何使用状态为 Keras fit_generator 编写生成器?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50850477/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com