
python - How can I improve my model so that it can handle more characters that are not in the dataset?

Reposted. Author: 行者123. Updated: 2023-12-01 07:58:06

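In my previous post, linked here, I was told that I have to modify my model to make it better. Quoting the comment from the only person who answered my question (thank you again, sir):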

The accuracy of prediction is a metric of how good your neural network architecture is and it also depends on your train/validation data. You will have to tune your neural network in such a way that you generalize well by adjusting the hyper parameters such as number of layers, type of layers, learning rate, optimizer etc. ...

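I would like to know how to go about doing the things mentioned, or at least to be pointed in the right direction. Honestly, I am lost in both theory and practice.

The only thing I have managed to do is raise the number of epochs above 100. I have also cleaned up the images to be recognized as much as I could.

For now, this is how I create my model. It is based only on the TensorFlow 2.0 tutorial.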

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Load and prepare the MNIST dataset. Convert the samples from integers to floating-point numbers:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

def createModel():
    # Build the tf.keras.Sequential model by stacking layers.
    # Choose an optimizer and loss function used for training:
    model = tf.keras.models.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model

model = createModel()
model.fit(x_train, y_train, epochs=102, validation_data=(x_test, y_test))
model.evaluate(x_test, y_test)

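It gives me a validation accuracy of about 0.9800, but it performs terribly on images of handwritten characters that I extracted from documents. I would also like to extend it so that it can read other selected characters, but I suppose that is another question for another day.

Thanks!

Best answer

You can start with several convolution/max-pooling layers, which perform feature extraction by scanning across the image. After that, you can use a fully connected network with softmax, as you did before.

You can create a model with a CNN like this: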
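To make "feature extraction by scanning the image" concrete, here is a minimal NumPy sketch of a single convolution followed by ReLU and max pooling. It is an illustration only, not how Keras implements these layers: the helper names `conv2d_valid` and `max_pool2d`, the 6x6 toy image, and the hand-picked edge kernel are all assumptions made for this example. In a real CNN the kernel values are learned during training.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and sum the products."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def max_pool2d(feature_map, size=2):
    """Keep the largest value in each non-overlapping size x size window."""
    h = feature_map.shape[0] // size * size
    w = feature_map.shape[1] // size * size
    trimmed = feature_map[:h, :w]
    return trimmed.reshape(h // size, size, w // size, size).max(axis=(1, 3))

# Hypothetical 6x6 "image": dark left half, bright right half.
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Hand-picked kernel that responds to a dark-to-bright vertical edge.
kernel = np.array([[-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0]])

features = np.maximum(conv2d_valid(image, kernel), 0.0)  # ReLU activation
pooled = max_pool2d(features, size=2)
print(features.shape, pooled.shape)
```

The feature map is large only where the window straddles the edge, and pooling keeps that response while shrinking the map. Stacking several such layers lets later layers combine simple features into stroke and shape detectors.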

from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from keras.models import Sequential

# Create the model
model = Sequential()

# 1st convolution / max pool. Inputs need a channel axis: shape (n, 28, 28, 1).
model.add(Conv2D(40, kernel_size=5, padding="same", input_shape=(28, 28, 1), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

# 2nd convolution / max pool
model.add(Conv2D(200, kernel_size=3, padding="same", activation='relu'))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(1, 1)))

# 3rd convolution / max pool
model.add(Conv2D(512, kernel_size=3, padding="valid", activation='relu'))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(1, 1)))

# Flatten the feature maps into a 1D vector, then add a hidden dense layer
model.add(Flatten())
model.add(Dense(units=100, activation='relu'))

# Add dropout to reduce overfitting
model.add(Dropout(0.5))

# Final fully connected layer: one softmax output per digit class
model.add(Dense(10, activation="softmax"))

# categorical_crossentropy expects one-hot labels; with integer labels such as
# y_train in the question, use 'sparse_categorical_crossentropy' instead.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# summary() prints the table itself, so no print() wrapper is needed
model.summary()

This produces the following model:

Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 40)        1040
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 40)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 200)       72200
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 12, 12, 200)       0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 10, 10, 512)       922112
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 8, 8, 512)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 32768)             0
_________________________________________________________________
dense_1 (Dense)              (None, 100)               3276900
_________________________________________________________________
dropout_1 (Dropout)          (None, 100)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1010
=================================================================
Total params: 4,273,262
Trainable params: 4,273,262
Non-trainable params: 0
_________________________________________________________________
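The output shapes and parameter counts in this summary can be checked by hand. The sketch below redoes the arithmetic in plain Python, without Keras. The helper names `conv_out`, `conv_params`, and `dense_params` are hypothetical and introduced only for this illustration. It assumes square kernels and the Keras default of "valid" padding for the pooling layers.

```python
def conv_out(size, kernel, stride=1, padding="valid"):
    """Spatial output size of a convolution or pooling window along one axis."""
    if padding == "same":
        return -(-size // stride)  # ceil(size / stride)
    return (size - kernel) // stride + 1

def conv_params(kernel, in_channels, filters):
    """Weights plus one bias per filter for a square 2D convolution."""
    return kernel * kernel * in_channels * filters + filters

def dense_params(inputs, units):
    """Weights plus one bias per unit for a fully connected layer."""
    return inputs * units + units

size = 28
size = conv_out(size, 5, padding="same")  # conv2d_1        -> 28
size = conv_out(size, 2, stride=2)        # max_pooling2d_1 -> 14
size = conv_out(size, 3, padding="same")  # conv2d_2        -> 14
size = conv_out(size, 3, stride=1)        # max_pooling2d_2 -> 12
size = conv_out(size, 3)                  # conv2d_3        -> 10
size = conv_out(size, 3, stride=1)        # max_pooling2d_3 -> 8
flat = size * size * 512                  # flatten_1       -> 32768

total = (conv_params(5, 1, 40) + conv_params(3, 40, 200) + conv_params(3, 200, 512)
         + dense_params(flat, 100) + dense_params(100, 10))
print(size, flat, total)  # 8 32768 4273262
```

Most of the parameters (3,276,900 of 4,273,262) sit in the first dense layer, because the second and third pooling layers use a stride of 1 and barely shrink the feature map. If the model turns out too large or overfits, a stride of 2 in those pooling layers is one thing to try.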

Regarding "python - How can I improve my model so that it can handle more characters that are not in the dataset?", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/55849555/
