gpt4 book ai didi

python - 使用 ScikitLearn 的神经网络实现时出现的问题

转载 作者:行者123 更新时间:2023-11-28 18:45:29 25 4
gpt4 key购买 nike

我正在尝试使用 Scikit Learn 提供的神经网络实现来实现图像处理。我有近 10,000 张“JPG”格式的彩色图像,我将这些图像转换为“PNG”格式并删除了颜色信息。新图像都是黑色或白色图像。将这些图像转换为矢量格式后,这些图像矢量构成了神经网络的输入。

对于每个图像,还有一个输出,它构成了神经网络的输出。

输入文件只有 0 和 1 的值,没有任何其他值。每个图像的输出对应于一个连续的向量,该向量介于 0 和 1 之间,长度为 22。即每个图像的输出是一个长度为 22 的向量。

为了开始处理,我开始时只有 100 张图像及其相应的输出,但出现了以下错误:

ValueError: Array contains NaN or infinity

我还要指出,第一次迭代在这里完成,我在第二次迭代时遇到了这个错误。

为了尝试一些不同的东西,我将我的输入和输出缩减为每张 10 张图像。使用同一段代码(很快就会出现),我能够完成 7 次迭代(我已将迭代次数设置为 20),然后收到相同的错误。

然后我将迭代次数更改为 5,只是为了检查它是否有效。进行此更改后,我收到以下错误:

ValueError: bad input shape (10, 22)

我还尝试在我的输入和输出上使用 np.reval() 但这又给了我 NaN 或 Infinity 错误。

这是我在整个过程中使用的代码:

import numpy as np
import csv
import matplotlib.pyplot as plt
from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.cross_validation import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline


def ReadCsv(fileName):
in_file = open(fileName, 'rUb')
reader = csv.reader(in_file, delimiter=',', quotechar='"')
data = [[]]
for row in reader:
data.append(row)

data.pop(0)
return data

X_train = np.asarray(ReadCsv('100Images.csv'), 'float32')
Y_train = np.asarray(ReadCsv('100Images_Y_new.csv'), 'float32')
X_test = np.asarray(ReadCsv('ImagesForTest.csv'), 'float32')
Y_test = np.asarray(ReadCsv('ImagesForTest_Y_new.csv'), 'float32')

logistic = linear_model.LogisticRegression()
rbm = BernoulliRBM(random_state=0, verbose=True)

classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])

rbm.learning_rate = 0.06
rbm.n_iter = 5

rbm.n_components = 100
logistic.C = 6000.0

classifier.fit(X_train, Y_train)

print()
print("Logistic regression using RBM features:\n%s\n" % (
metrics.classification_report(
Y_test,
classifier.predict(X_test))))

如果能就此问题提供任何形式的帮助,我将不胜感激。

TIA。

最佳答案

将学习率更改为较小的值可能会解决此问题。 (即 rbm.learning_rate)

至少这解决了我之前遇到的问题。

关于python - 使用 ScikitLearn 的神经网络实现时出现的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/20849281/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com