gpt4 book ai didi

python - Keras fit_generator() 与扩展序列的生成器返回的样本数多于总数

转载 作者:行者123 更新时间:2023-11-30 09:42:36 25 4
gpt4 key购买 nike

我正在使用 Keras 训练神经网络。由于数据集的大小,我需要使用生成器和 fit_generator() 方法。我正在关注本教程:

https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

但是,我准备了一个小示例来检查每个时期馈送到网络的样本,并且该数字似乎高于样本数。

class DataGenerator(keras.utils.Sequence):
'Generates data for Keras'
def __init__(self, files, batch_size=2, dim=(160, 160), n_channels=3,
n_classes=2, shuffle=False):
'Initialization'
self.dim = dim
self.files = files
self.batch_size = batch_size
self.n_channels = n_channels
self.n_classes = n_classes
self.shuffle = shuffle
self.on_epoch_end()

def __len__(self):
'Denotes the number of batches per epoch'
print ("Number of batches per epoch")
print(int(np.floor(len(self.files) / self.batch_size)))
return int(np.floor(len(self.files) / self.batch_size))

def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

# Find list of IDs
files_temp = [self.files[k] for k in indexes]

# Generate data
X, y = self.__data_generation(files_temp)

return X, y


def on_epoch_end(self):
'Updates indexes after each epoch'
self.indexes = np.arange(len(self.files))
if self.shuffle == True:
np.random.shuffle(self.indexes)


def __data_generation(self, files_temp):
'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
# Initialization
X = np.empty((self.batch_size, *self.dim, self.n_channels))
y = np.empty((self.batch_size), dtype=int)

# Generate data
for i, ID in enumerate(files_temp):
# Store sample
X[i,] = read_image(ID)

# Store class
y[i] = get_label(ID)

return X, keras.utils.to_categorical(y, num_classes=self.n_classes)


...

params = {'dim': (160, 160),
'batch_size': 2,
'n_classes': 2,
'n_channels': 3,
'shuffle': True}


gen_train = DataGenerator(files, **params)
model.fit_generator(gen_train, steps_per_epoch=ceil(num_samples_train)/batch_size, validation_data=None,
epochs = 1, verbose=1,
callbacks = [tensorboard])

哪里read_imageget_label是我获取数据的方法。这些方法包括用于加载图像的 print() ,我得到的比我预期的要多。例如:

样本数量 = 10批量大小=2

每个时期的步骤将等于 5,这就是 keras 进度条显示的内容,但我得到了更多图像(我知道这是因为方法内部的打印)。

我尝试调试,发现__getitem__函数被调用超过 5 次!前五次的索引将在 0 到 4 之间(如预期),但随后我将获得重复的索引并加载更多数据。

知道为什么会发生这种情况吗?我已经调试到keras中的data_utils.py,但找不到索引传递到__getitem__的确切位置。 getitem 内的所有内容似乎都工作正常。

最佳答案

这是正常的,对于 steps_per_epoch = 5,您的 __getitem__ 将在每个周期被调用 5 次。因此,当然,拥有多个 epoch 意味着它会被调用更多次,而不仅仅是 5 次。

另请注意,涉及并行性,Keras 会自动在另一个线程/进程中运行您的序列(取决于配置),因此它们可能会在预期序列之外被调用。这也很正常。

关于python - Keras fit_generator() 与扩展序列的生成器返回的样本数多于总数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56856652/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com