
python - ImageDataGenerator.flow: NumpyArrayIterator is set to use the data format convention "channels_last", and the fix then causes a fitting error

Reposted. Author: 太空宇宙. Updated: 2023-11-04 05:02:05

My data has shape (60000, 1, 28, 28). When I try to get batches from it as follows:

from keras.preprocessing import image
gen = image.ImageDataGenerator()
train_batches = gen.flow(x_train, y_train, batch_size=64)

I get the error:

ValueError: NumpyArrayIterator is set to use the data format convention "channels_last" (channels on axis 3), i.e. expected either 1, 3 or 4 channels on axis 3. However, it was passed an array with shape (60000, 1, 28, 28) (28 channels).
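The error message spells out the check: under channels_last, axis 3 is treated as the channel axis and must hold 1, 3 or 4 channels, so an array shaped (60000, 1, 28, 28) is read as having 28 channels. The sketch below re-creates that check in plain NumPy as worded in the message; check_channels is a hypothetical helper for illustration, not the Keras implementation.

```python
import numpy as np

def check_channels(x, data_format="channels_last"):
    """Hypothetical re-creation of the check described in the error message:
    the channel axis is 3 for channels_last and 1 for channels_first,
    and only 1, 3 or 4 channels are accepted."""
    if x.ndim != 4:
        raise ValueError("expected a 4-D array (samples, ...)")
    channel_axis = 3 if data_format == "channels_last" else 1
    n_channels = x.shape[channel_axis]
    if n_channels not in (1, 3, 4):
        raise ValueError(
            "expected 1, 3 or 4 channels on axis {}, got shape {} ({} channels)".format(
                channel_axis, x.shape, n_channels
            )
        )
    return channel_axis

# Small stand-in for the (60000, 1, 28, 28) array from the question.
x = np.zeros((10, 1, 28, 28), dtype="float32")
print(check_channels(x, "channels_first"))  # axis 1 holds the single channel
try:
    check_channels(x, "channels_last")  # axis 3 has size 28, so this is rejected
except ValueError as e:
    print(e)
```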

To get rid of this error, I do this:

train_batches = gen.flow(np.swapaxes(x_train,1,3), y_train, batch_size=64)
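np.swapaxes(x_train, 1, 3) does produce a channels-last shape, but it exchanges the channel axis with the width axis, so the result is laid out as (samples, width, height, channels). With square 28x28 images the shape looks correct while every digit is transposed. np.moveaxis, or np.transpose with the axis order (0, 2, 3, 1), moves only the channel axis. A small sketch on a toy array, assuming channels_first input:

```python
import numpy as np

# Toy stand-in for x_train: 2 single-channel 3x4 images in (N, C, H, W) layout.
x = np.arange(2 * 1 * 3 * 4).reshape(2, 1, 3, 4)

swapped = np.swapaxes(x, 1, 3)  # (N, W, H, C): channels last, but height and width exchanged
moved = np.moveaxis(x, 1, -1)   # (N, H, W, C): channels last, orientation preserved

print(swapped.shape)
print(moved.shape)

# moveaxis is the same as an explicit transpose to (N, H, W, C).
assert np.array_equal(moved, np.transpose(x, (0, 2, 3, 1)))
# swapaxes yields each image transposed (rows and columns exchanged).
assert np.array_equal(swapped[..., 0], np.transpose(x[:, 0], (0, 2, 1)))
```

Either way the arrays end up channels-last, so a model whose first layer expects (1, 28, 28) will still reject them.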

While this does eliminate the channels error, it produces the following error instead:

ValueError: Error when checking input: expected lambda_13_input to have shape (None, 1, 28, 28) but got array with shape (64, 28, 28, 1)

when running:

lin_model.fit_generator(train_batches, train_batches.n, nb_epoch=1, 
validation_data= test_batches, nb_val_samples=test_batches.n)
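Separately from the shape error, this call uses Keras 1 argument names, which the traceback shows being routed through Keras's legacy interface. Under the Keras 2 signature shown in the traceback, the second positional argument is steps_per_epoch, a number of batches, so passing train_batches.n (a number of samples) requests far more batches per epoch than intended; nb_val_samples is presumably renamed to validation_steps in the same way. The sketch below converts sample counts into batch counts; keras2_fit_generator_kwargs is a hypothetical helper, and the 10000 validation samples in the demo are an assumption (only the 60000 comes from the question).

```python
import math

def keras2_fit_generator_kwargs(n_train, n_val, batch_size, epochs=1):
    """Hypothetical helper: turn sample counts into the batch counts that
    Keras 2 fit_generator expects for steps_per_epoch and validation_steps."""
    return {
        "steps_per_epoch": math.ceil(n_train / batch_size),  # last partial batch included
        "epochs": epochs,
        "validation_steps": math.ceil(n_val / batch_size),
    }

kwargs = keras2_fit_generator_kwargs(60000, 10000, batch_size=64)
print(kwargs)
# Intended use (requires Keras 2 and the generators from the question):
# lin_model.fit_generator(train_batches, validation_data=test_batches, **kwargs)
```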

I made sure to add the dimension-ordering specifier to my code:

import keras.backend as k
k.image_dim_ordering() == 'th'

The full traceback is:

ValueError                                Traceback (most recent call last)
<ipython-input-138-f8ea3b9faad4> in <module>()
----> 1 training_routine(lin_model)

<ipython-input-136-8b3171cd58ae> in training_routine(model)
2 model.optimizer.lr = 0.001
3 model.fit_generator(train_batches, train_batches.n, nb_epoch=1,
----> 4 validation_data= test_batches, nb_val_samples=test_batches.n)
5 model.optimizer.lr = 0.1
6 model.fit_generator(train_batches, train_batches.n, nb_epoch=1,

/home/matar/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.pyc in wrapper(*args, **kwargs)
85 warnings.warn('Update your `' + object_name +
86 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87 return func(*args, **kwargs)
88 wrapper._original_function = func
89 return wrapper

/home/matar/anaconda2/lib/python2.7/site-packages/keras/models.pyc in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, initial_epoch)
1115 workers=workers,
1116 use_multiprocessing=use_multiprocessing,
-> 1117 initial_epoch=initial_epoch)
1118
1119 @interfaces.legacy_generator_methods_support

/home/matar/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.pyc in wrapper(*args, **kwargs)
85 warnings.warn('Update your `' + object_name +
86 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87 return func(*args, **kwargs)
88 wrapper._original_function = func
89 return wrapper

/home/matar/anaconda2/lib/python2.7/site-packages/keras/engine/training.pyc in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, initial_epoch)
1838 outs = self.train_on_batch(x, y,
1839 sample_weight=sample_weight,
-> 1840 class_weight=class_weight)
1841
1842 if not isinstance(outs, list):

/home/matar/anaconda2/lib/python2.7/site-packages/keras/engine/training.pyc in train_on_batch(self, x, y, sample_weight, class_weight)
1557 sample_weight=sample_weight,
1558 class_weight=class_weight,
-> 1559 check_batch_axis=True)
1560 if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
1561 ins = x + y + sample_weights + [1.]

/home/matar/anaconda2/lib/python2.7/site-packages/keras/engine/training.pyc in _standardize_user_data(self, x, y, sample_weight, class_weight, check_batch_axis, batch_size)
1232 self._feed_input_shapes,
1233 check_batch_axis=False,
-> 1234 exception_prefix='input')
1235 y = _standardize_input_data(y, self._feed_output_names,
1236 output_shapes,

/home/matar/anaconda2/lib/python2.7/site-packages/keras/engine/training.pyc in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
138 ' to have shape ' + str(shapes[i]) +
139 ' but got array with shape ' +
--> 140 str(array.shape))
141 return arrays
142

ValueError: Error when checking input: expected lambda_13_input to have shape (None, 1, 28, 28) but got array with shape (64, 28, 28, 1)

Best answer

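In keras.json, change "image_data_format": "channels_last" to "image_data_format": "channels_first". The file is normally located at ~/.keras/keras.json; if it is not there, running find ~ -name keras.json in a terminal will locate it.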
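Note that k.image_dim_ordering() == 'th' in the question only compares the current setting with 'th' and discards the result; it changes nothing. The keras.json edit can also be scripted. The sketch below uses only the standard library to rewrite the image_data_format entry; patch_keras_json is a hypothetical helper, and the demo writes to a temporary file so it does not touch a real configuration.

```python
import json
import os
import tempfile

def patch_keras_json(config_path, data_format="channels_first"):
    """Hypothetical helper: set image_data_format in a keras.json file in place
    and return the updated configuration dict."""
    if data_format not in ("channels_first", "channels_last"):
        raise ValueError("data_format must be 'channels_first' or 'channels_last'")
    with open(config_path) as f:
        config = json.load(f)
    config["image_data_format"] = data_format
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)
    return config

# Demo on a temporary file; for real use, pass
# os.path.expanduser("~/.keras/keras.json") instead.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "keras.json")
    with open(path, "w") as f:
        json.dump({"backend": "theano", "image_data_format": "channels_last"}, f)
    print(patch_keras_json(path)["image_data_format"])
```

Keras reads keras.json when it is imported, so the new value only takes effect after restarting the Python process or notebook kernel. Within a running session, keras.backend.set_image_data_format('channels_first') should have a similar effect, provided it is called before the ImageDataGenerator and the model are created.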

It is set to channels_last by default to suit the TensorFlow backend, but since Theano is the backend used here, it should be changed accordingly.
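Whichever convention is chosen, the data arrays, the generator and the model's first layer must all agree. The second error arises because the model was built to expect (None, 1, 28, 28) while the generator, after the swapaxes workaround, yields batches shaped (64, 28, 28, 1). The sketch below derives the per-sample input_shape from a data format string; mnist_input_shape is a hypothetical helper, not a Keras function. With channels_last, the alternative to editing keras.json is to move the channel axis of the data and build the model with input_shape=(28, 28, 1); passing data_format="channels_first" to ImageDataGenerator is another option that affects only that generator.

```python
def mnist_input_shape(data_format, height=28, width=28, channels=1):
    """Hypothetical helper: per-sample input_shape a first layer should declare
    so that it matches arrays laid out in the given data format."""
    if data_format == "channels_first":
        return (channels, height, width)
    if data_format == "channels_last":
        return (height, width, channels)
    raise ValueError("unknown data_format: %r" % (data_format,))

print(mnist_input_shape("channels_first"))  # matches the (1, 28, 28) the model expects
print(mnist_input_shape("channels_last"))   # matches the (28, 28, 1) batches from the generator
```

In a Keras session the argument could come from keras.backend.image_data_format(), so that the model follows whatever keras.json specifies.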

A similar question about this error can be found on Stack Overflow: https://stackoverflow.com/questions/45467720/
