gpt4 book ai didi

python - 数据集不适合 LSTM 训练的内存

转载 作者:行者123 更新时间:2023-12-02 00:47:40 27 4
gpt4 key购买 nike

我正在尝试创建在大型音乐数据集上训练的模型。 midi 文件被转换为 numpy 数组。由于 LSTM 需要顺序数据,因此在转换为 LSTM 的序列时数据集大小变得如此巨大。

我根据基调和持续时间将 midi 音符转换为索引,因此我得到了 6 个 C4 键类。同样,我从 C3 到 B5,总共 288 节课以及休息时间的课。

单个 midi 的转换格式如下所示。

midi = [0,23,54,180,23,45,34,.....];

为了训练 LSTM,x 和 y 变为

x = [[0,23,54..45],[23,54,..,34],...];

y=[[34],[76],...]

x 和 y 中的值进一步转换为单热编码。因此,对于 60 个中小型文件来说,数据量变得很大,但我有 1700 个文件。如何使用这么多文件训练模型。我检查了 ImageGenerator 但它要求数据位于单独的类目录中。如何实现?

最佳答案

您应该在训练过程中即时生成训练数据。基于tf documentation ,您可以编写自己的生成器用作训练数据,或继承自 Sequence .

第一个选项应该是这样的

def create_data_generator(your_files):
raw_midi_data = process_files(your_files)
seq_size = 32

def _my_generator():
i = 0
while True:
x = raw_midi_data[i:i + seq_size]
y = raw_midi_data[i + seq_size]
i = (i + 1) % (len(raw_midi_data) - seq_size)
yield x, y

return _my_generator()

然后调用它(假设 tf >= 2.0)

generator = create_data_generator(your_files)
model.fit(x=generator, ...)

如果您使用的是“旧”Keras(在 tensorflow 2.0 之前)Keras 团队本身不推荐,您应该改用 fit_generator:

model.fit_generator(generator, ...)

使用此解决方案,您只需将数据存储在内存中一次,不会因重叠序列而导致重复。

关于python - 数据集不适合 LSTM 训练的内存,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60486804/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com