gpt4 book ai didi

python - 计算均方误差返回 y_true 和 y_pred 有不同数量的输出 (1!=10)

转载 作者:行者123 更新时间:2023-12-05 00:57:10 27 4
gpt4 key购买 nike

我对深度学习真的很陌生。我想做一个任务:根据测试数据评估模型并计算预测混凝土强度和实际混凝土强度之间的均方误差。您可以使用 Scikit-learn 中的 mean_squared_error 函数。

这是我的代码:

import pandas as pd
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Dense
from sklearn.model_selection import train_test_split

concrete_data = pd.read_csv('https://cocl.us/concrete_data')

n_cols = concrete_data.shape[1]
model = Sequential()
model.add(Dense(units=10, activation='relu', input_shape=(n_cols-1,)))

model.compile(loss='mean_squared_error',
optimizer='adam')


y = concrete_data.Cement
x = concrete_data.drop('Cement', axis=1)
xTrain, xTest, yTrain, yTest = train_test_split(x, y, test_size = 0.3)

model.fit(xTrain, yTrain, epochs=50)

现在为了评估均方误差,我写了这个:

from sklearn.metrics import mean_squared_error
predicted_y = model.predict(xTest)
mean_squared_error(yTest, predicted_y)

我得到了这个错误:

y_true and y_pred have different number of output (1!=10)

我的 predict_y 形状是:(309, 10)

我用谷歌搜索了它,我真的找不到解决这个问题的答案。我不知道我的代码有什么问题。

最佳答案

您的 y_test 数据形状是 (N, 1),但是因为您在输出层放置了 10 个神经元,所以您的模型会做出 10 种不同的预测,这就是错误。

您需要将输出层中的神经元数量更改为 1,或者添加一个只有 1 个神经元的新输出层。

下面的代码可能对你有用。

import pandas as pd
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Dense
from sklearn.model_selection import train_test_split

concrete_data = pd.read_csv('https://cocl.us/concrete_data')

n_cols = concrete_data.shape[1]
model = Sequential()
model.add(Dense(units=10, activation='relu', input_shape=(n_cols-1,)))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error',
optimizer='adam')


y = concrete_data.Cement
x = concrete_data.drop('Cement', axis=1)
xTrain, xTest, yTrain, yTest = train_test_split(x, y, test_size = 0.3)

model.fit(xTrain, yTrain, epochs=50)

关于python - 计算均方误差返回 y_true 和 y_pred 有不同数量的输出 (1!=10),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60988136/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com