gpt4 book ai didi

python - 使用 tflearn 构建 CNN

转载 作者:行者123 更新时间:2023-11-28 18:56:08 24 4
gpt4 key购买 nike

我正在尝试使用以下代码使用 tflearn 构建 CNN

train_images, train_labels, test_images, test_labels = load_dataset()
convnet = input_data(shape=[None, 28, 28, 1], name='input')
convnet = conv_2d(convnet, 32, 2, activation='relu')
convnet = max_pool_2d(convnet, 2)
convnet = conv_2d(convnet, 64, 2, activation='relu')
convnet = max_pool_2d(convnet, 2)
convnet = fully_connected(convnet, 1024, activation='relu')
convnet = dropout(convnet, 0.8)
convnet = fully_connected(convnet, 24, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=0.01, loss='categorical_crossentropy')

model = tflearn.DNN(convnet)
model.fit(train_images, train_labels, n_epoch=30,
validation_set=(test_images, test_labels),
snapshot_step=500, show_metric=True, run_id='characterOCR')
model.save('CNN.model')

数据集被 reshape 为以下形状

def load_mnist_images(filename):
with gzip.open(filename, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
data = data.reshape(-1, 28, 28, 1)
return data

这是我的数据集,但我是基于 MNIST 结构构建的,现在出现以下错误:

Traceback (most recent call last):
File "/home/hassan/JPG-PNG-to-MNIST-NN-Format/CNN_network.py", line 56, in <module>
snapshot_step=500, show_metric=True, run_id='characterOCR')
File "/home/hassan/anaconda3/envs/object-detection/lib/python3.7/site-packages/tflearn/models/dnn.py", line 216, in fit
callbacks=callbacks)
File "/home/hassan/anaconda3/envs/object-detection/lib/python3.7/site-packages/tflearn/helpers/trainer.py", line 339, in fit
show_metric)
File "/home/hassan/anaconda3/envs/object-detection/lib/python3.7/site-packages/tflearn/helpers/trainer.py", line 818, in _train
feed_batch)
File "/home/hassan/anaconda3/envs/object-detection/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
run_metadata_ptr)
File "/home/hassan/anaconda3/envs/object-detection/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1149, in _run
str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (64,) for Tensor 'TargetsData/Y:0', which has shape '(?, 24)'

有没有人有解决办法?

最佳答案

对于 tflearn 我没有解决所有问题 我所知道的错误来自于我向完全连接的层形状 24 提供形状为 64 的数据 我不知道在哪里这64来自

但我从 tflearn 切换到 Keras API,现在它可以工作了

如果有人想知道源码请告诉我

关于python - 使用 tflearn 构建 CNN,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58779264/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com