gpt4 book ai didi

python - 为什么我的损失趋于下降,而我的准确度却趋于零?

转载 作者:太空宇宙 更新时间:2023-11-04 07:26:29 25 4
gpt4 key购买 nike

我正在尝试使用 Tensorflow/Keras 练习我的机器学习技能,但我在拟合模型方面遇到了问题。让我解释一下我做了什么以及我在哪里。

我正在使用来自 Kaggle 的 Costa Rican Household Poverty Level Prediction Challenge 的数据集

因为我只是想熟悉 Tensorflow 工作流程,所以我通过删除一些有很多缺失数据的列来清理数据集,然后用它们的平均值填充其他列。所以我的数据集中没有缺失值。

接下来,我使用来自 TF 的 make_csv_dataset 加载了新的、清理过的 csv。

batch_size = 32

train_dataset = tf.data.experimental.make_csv_dataset(
'clean_train.csv',
batch_size,
column_names=column_names,
label_name=label_name,
num_epochs=1)

我设置了一个函数来返回我的编译模型,如下所示:

f1_macro = tfa.metrics.F1Score(num_classes=4, average='macro')

def get_compiled_model():
model = tf.keras.Sequential([
tf.keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(137,)), # input shape required
tf.keras.layers.Dense(256, activation=tf.nn.relu),
tf.keras.layers.Dense(4, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=[f1_macro, 'accuracy'])
return model
model = get_compiled_model()
model.fit(train_dataset, epochs=15)

结果如下

This is my output

我笔记本的链接是Here

我应该提一下,我的实现很大程度上基于 Tensorflow 的鸢尾花数据 walkthrough

谢谢!

最佳答案

一段时间后,我找到了您代码中的问题,它们按重要性排序。 (第一个最重要)

  1. 您正在进行多类分类(而不是二分类)。因此,您的损失应该是 categorical_crossentropy

  2. 您不是一次性编码您的标签。使用 binary_crossentropy 并将标签作为数字 ID 绝对不是前进的方向。相反,您应该对标签进行 onehot 编码并像解决多类分类问题一样解决这个问题。以下是您的操作方法。

def pack_features_vector(features, labels):
"""Pack the features into a single array."""
features = tf.stack(list(features.values()), axis=1)
return features, tf.one_hot(tf.cast(labels-1, tf.int32), depth=4)
  1. 规范化您的数据。如果你看看你的训练数据。它们没有标准化。他们的值(value)观无处不在。因此,您应该考虑通过执行以下操作来规范化数据。这仅用于演示目的。你应该阅读 Scalers在 scikit 中学习并选择最适合你的。
x = train_df[feature_names].values #returns a numpy array
min_max_scaler = preprocessing.StandardScaler()
x_scaled = min_max_scaler.fit_transform(x)
train_df = pd.DataFrame(x_scaled)

这些问题应该让您的模型直截了当。

关于python - 为什么我的损失趋于下降,而我的准确度却趋于零?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59063887/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com