gpt4 book ai didi

tensorflow - 发电机在错误的时间调用(keras)

转载 作者:行者123 更新时间:2023-12-04 01:00:48 25 4
gpt4 key购买 nike

我在 keras 2.0.2 中使用 fit_generator(),批量大小为 10,步长为 320,因为我有 3209 个样本用于训练。在第一个 epoch 开始之前,生成器被调用了 11 次,显示:

Train -- get ind: 0 to 10
...
Train -- get ind: 100 to 110

然后,在第一批(1/320)之后,它打印出 Train -- get ind: 110 to 120,但我认为应该是 Train -- get ind: 0 到 10。我对 train_generator() 函数的实现是否不正确?或者为什么我会遇到这个问题?

这是我的生成器代码:

EPOCH = 10
x_train_img = img[:train_size] # shape: (3209,512,512)
x_test_img = img[train_size:] # shape: (357,512,512)

def train_generator():
global x_train_img

last_ind = 0

while 1:
x_train = x_train_img[last_ind:last_ind+BATCH_SIZE]
print('Train -- get ind: ',last_ind," to ",last_ind+BATCH_SIZE)
last_ind = last_ind+BATCH_SIZE
x_train = x_train.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 512, 512, 1))
yield (x_train, x_train)
if last_ind >= x_train_img.shape[0]:
last_ind = 0

def test_generator():
...

train_steps = x_train_img.shape[0]//BATCH_SIZE #320
test_steps = x_test_img.shape[0]//BATCH_SIZE #35

autoencoder.fit_generator(train_generator(),
steps_per_epoch=train_steps,
epochs=EPOCH,
validation_data=test_generator(),
validation_steps=test_steps,
callbacks=[csv_logger] )

更好?生成器的写法:

def train_generator():
global x_train_img

while 1:
for i in range(0, x_train_img.shape[0], BATCH_SIZE):
x_train = x_train_img[i:i+BATCH_SIZE]
print('Train -- get ind: ',i," to ",i+BATCH_SIZE)
x_train = x_train.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 512, 512, 1))
yield (x_train, x_train)

最佳答案

默认情况下,fit_generator() 使用 max_queue_size=10。所以你观察到的是:

  1. 在纪元开始之前,您的生成器会产生 10 个批处理来填充队列。这是样本 0 到 100。
  2. 然后,epoch 开始,从队列中弹出一批进行模型拟合。
  3. 生成器生成一个新批处理来填充队列中的空白空间。这是样本 100 到 110。
  4. 然后,更新进度条。进度 1/320 打印在屏幕上。
  5. 再次执行步骤2和3,打印get ind: 110 to 120

所以这个模型拟合过程没有任何问题。生成的第一个批处理确实是第一个用于拟合模型的批处理。只是它背后隐藏着一个队列,在第一次模型更新发生之前,生成器被多次调用以填充队列。

关于tensorflow - 发电机在错误的时间调用(keras),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45303518/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com