gpt4 book ai didi

python - 神经网络在 1000 个时期后不学习来解决 XOR 问题

转载 作者:行者123 更新时间:2023-12-05 04:57:21 25 4
gpt4 key购买 nike

我正在学习 TensorFlow,并且正在尝试解决 XOR 问题。我创建了一个 3 层神经网络来做到这一点,但在 500 或 1000 个时期之后它根本没有学习。我做错了什么?

我在 colab.research.google 中使用 TensorFlow 2.3.0。

from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.metrics import Accuracy
from tensorflow.keras import Sequential

import numpy as np



x = np.array([[0., 0.],
[1., 1.],
[1., 0.],
[0., 1.]], dtype=np.float32)

y = np.array([[0.],
[0.],
[1.],
[1.]], dtype=np.float32)



model = Sequential()
model.add(Dense(2, activation='sigmoid'))
model.add(Dense(2, activation='sigmoid'))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='SGD', loss='mean_squared_error', metrics='accuracy')
model.fit(x, y, batch_size=1, epochs=1000, verbose=False)

pred = model.predict_on_batch(x)
print(pred)

最佳答案

由于您提到的隐藏层单元为 2,即 Dense(2) 这不足以让模型在给定具有 2 个输入的数组输入的情况下进行学习。我包含了 16 个单元,您可以尝试使用 32、64 等单元进行试验。

对于神经网络中的隐藏层,使用激活函数 ReLu 是最理想的。 (有关这方面的更多详细信息,请参阅 Mark 的评论)。
但是对于这个用例,您可以不提及任何激活函数,但需要更多的时间才能收敛到一个解决方案。

下面是修改后的代码,它用更少的 epoch 预测了正确的输出。

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.metrics import Accuracy
from tensorflow.keras import Sequential

import numpy as np



x = np.array([[0., 0.],
[1., 1.],
[1., 0.],
[0., 1.]], dtype=np.float32)

y = np.array([[0.],
[0.],
[1.],
[1.]], dtype=np.float32)



model = Sequential()
model.add(Dense(16, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='SGD', loss='mean_squared_error', metrics=['accuracy'])
model.fit(x, y, batch_size=1, epochs=500, verbose=False)

pred = model.predict(x).round()
print(pred)

输出:

[[0.]
[0.]
[1.]
[1.]]

关于python - 神经网络在 1000 个时期后不学习来解决 XOR 问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64446561/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com