gpt4 book ai didi

python - 当batch_size与数据量不匹配时,Keras自定义生成器

转载 作者:行者123 更新时间:2023-11-30 09:04:49 25 4
gpt4 key购买 nike

我正在使用 Keras 和 Python 2.7。我正在制作自己的数据生成器来计算火车的批处理。我对基于此模型的 data_generator 有一些疑问 seen here :

class DataGenerator(keras.utils.Sequence):

def __init__(self, list_IDs, ...):
#init

def __len__(self):
return int(np.floor(len(self.list_IDs) / self.batch_size))

def __getitem__(self, index):
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# Find list of IDs
list_IDs_temp = [self.list_IDs[k] for k in indexes]
# Generate data
X, y = self.__data_generation(list_IDs_temp)
return X, y

def on_epoch_end(self):
'Updates indexes after each epoch'
self.indexes = np.arange(len(self.list_IDs))
if self.shuffle == True:
np.random.shuffle(self.indexes)

def __data_generation(self, list_IDs_temp):
#generate data
return X, y

好的,这是我的几个问题:

你能证实我对调用函数顺序的想法吗?这是:

- __init__
- loop for each epoc :
- loop for each batches :
- __len_
- __get_item__ (+data generation)
- on_epoch_end

如果您知道调试生成器的方法,我想知道它,断点和打印不适用于此..

更多,我的情况很糟糕,但我认为每个人都有这个问题:

例如,我有 200 个数据(也可以有 200 个标签),并且我想要批量大小为 64。如果我想得好,__len_ 会给出 200/64 = 3(而不是 3,125)。那么 1 epoch 将通过 3 个批处理完成?其余数据呢?我有一个错误,因为我的数据量不是批量大小的倍数...

第二个例子,我有 200 个数据,我想要一批 256 个?在这种情况下我必须做什么来调整我的发电机?我考虑过检查batch_size是否优于我的数据量,以向CNN提供1批数据,但该批数据不会达到预期的大小,所以我认为它会出错?

感谢您的阅读。我更喜欢放置伪代码,因为我的问题更多的是关于理论而不是编码错误!

最佳答案

  • __len__ :返回批处理数
  • __getitem__ :返回第i批

通常你不会在模型架构中提及批量大小,因为它是训练参数而不是模型参数。因此,在训练时使用不同的批量大小是可以的。

示例

from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten
from keras.utils import to_categorical
import keras

#create model
model = Sequential()
#add model layers
model.add(Conv2D(64, kernel_size=3, activation='relu', input_shape=(10,10,1)))
model.add(Flatten())
model.add(Dense(2, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
class DataGenerator(keras.utils.Sequence):
def __init__(self, X, y, batch_size):
self.X = X
self.y = y
self.batch_size = batch_size

def __len__(self):
l = int(len(self.X) / self.batch_size)
if l*self.batch_size < len(self.X):
l += 1
return l

def __getitem__(self, index):
X = self.X[index*self.batch_size:(index+1)*self.batch_size]
y = self.y[index*self.batch_size:(index+1)*self.batch_size]
return X, y

X = np.random.rand(200,10,10,1)
y = to_categorical(np.random.randint(0,2,200))
model.fit_generator(DataGenerator(X,y,13), epochs=10)

输出:

纪元 1/10
16/16 [================================] - 0s 2ms/步 - 损耗:0.6774 - 加速器:0.6097

如您所见,它在一个时期内运行了 16 个批处理,即 13*15+5=200

关于python - 当batch_size与数据量不匹配时,Keras自定义生成器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55147276/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com