gpt4 book ai didi

theano - Keras:如果数据大小不能被batch_size整除怎么办?

转载 作者:行者123 更新时间:2023-12-04 04:29:42 25 4
gpt4 key购买 nike

我是 Keras 的新手,刚刚开始研究一些示例。我正在处理以下问题:我有 4032 个样本,其中大约 650 个用于拟合或基本上是训练状态,然后使用其余样本来测试模型。问题是我不断收到以下错误:

Exception: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size.

我明白为什么我会收到这个错误,我的问题是,如果我的数据大小不能被 batch_size 整除怎么办? ?我以前用 Deeplearning4j LSTM 工作过,没必要处理这个问题。有没有办法解决这个问题?

谢谢

最佳答案

最简单的解决方案是使用 fit_generator 而不是 fit。我写了一个简单的数据加载器类,可以继承它来做更复杂的事情。它看起来像这样,将 get_next_batch_data 重新定义为您的数据包括增强等内容。

class BatchedLoader():
def __init__(self):
self.possible_indices = [0,1,2,...N] #(say N = 33)
self.cur_it = 0
self.cur_epoch = 0

def get_batch_indices(self):
batch_indices = self.possible_indices [cur_it : cur_it + batchsize]
# If len(batch_indices) < batchsize, the you've reached the end
# In that case, reset cur_it to 0 and increase cur_epoch and shuffle possible_indices if wanted
# And add remaining K = batchsize - len(batch_indices) to batch_indices


def get_next_batch_data(self):
# batch_indices = self.get_batch_indices()
# The data points corresponding to those indices will be your next batch data

关于theano - Keras:如果数据大小不能被batch_size整除怎么办?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37974340/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com