gpt4 book ai didi

Python 如何在神经网络中绘制误差

转载 作者:行者123 更新时间:2023-11-30 09:08:10 25 4
gpt4 key购买 nike

我正在通过博客 http://iamtrask.github.io/2015/07/27/python-network-part2/ 学习神经网络。我想使用 matplotlib 绘制如下图所示的误差曲面。

enter image description here

我怎样才能做到这一点?我尝试将数据存储在列表中,但我的解决方案不起作用。来自 trask 博客:

Let's try to plot what the error plane looks like for the network/dataset above. So, how do we compute the error for a given set of weights? Lines 31,32,and 35 show us that. If we take that logic and plot the overall error (a single scalar representing the network error over the entire dataset) for every possible set of weights (from -10 to 10 for x and y), it looks something like this.

import numpy as np
import matplotlib as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter

# 2 layer neural network

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*.

    ``np.exp`` is applied element-wise, so *x* may be a scalar or a
    NumPy array; the result has the same shape as the input.
    """
    output = 1 / (1 + np.exp(-x))
    return output


def sigmoid_output_to_derivative(output):
    """Return ``output * (1 - output)``.

    When *output* is a value already produced by the sigmoid function,
    this product is the sigmoid's slope at that point, so the derivative
    is obtained without re-evaluating the exponential.
    """
    return output * (1 - output)


# Training inputs: four samples with two features each.
X = np.array([
    [0, 1],
    [0, 1],
    [1, 0],
    [1, 0],
])

# Target output for each sample, as a (4, 1) column vector.
y = np.array([[0, 0, 1, 1]]).T

# Fixed seed so the initial weights are the same on every run.
np.random.seed(1)

# One weight per input feature, drawn uniformly from [-1, 1).
synapse_0 = 2*np.random.random((2, 1)) - 1

# One (layer_0, layer_1, layer_1_error) record per training iteration.
data = []

for _ in range(1000):
    # Forward pass through the single layer.
    layer_0 = X
    layer_1 = sigmoid(np.dot(layer_0, synapse_0))

    # Signed residual of every sample.
    layer_1_error = layer_1 - y

    # Scale the residual by the sigmoid slope, then project it back onto
    # the inputs to get the update for each weight.
    layer_1_delta = layer_1_error * sigmoid_output_to_derivative(layer_1)
    synapse_0_deriative = np.dot(layer_0.T, layer_1_delta)

    synapse_0 -= synapse_0_deriative

    # The three arrays have different shapes ((4, 2), (4, 1), (4, 1)), so
    # they cannot be stacked into one ndarray; keep them as a tuple.
    data.append((layer_0, layer_1, layer_1_error))

# Residuals left after the final iteration.
print("Error: {}".format(layer_1_error))

fig = plt.figure()
# add_subplot is the supported way to request a 3D axes; passing
# ``projection`` to fig.gca() is rejected by current matplotlib.
ax = fig.add_subplot(111, projection='3d')

# x,y,z,c = data

print(data)


# surf = ax.plot_surface(x,y,z, cmap=cm.coolwarm,
# linewidth=0, antialiased=False)

编辑:

我尝试:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import random

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*.

    ``np.exp`` is applied element-wise, so *x* may be a scalar or a
    NumPy array; the result has the same shape as the input.
    """
    output = 1 / (1 + np.exp(-x))
    return output


def sigmoid_output_to_derivative(output):
    """Return ``output * (1 - output)``.

    When *output* is a value already produced by the sigmoid function,
    this product is the sigmoid's slope at that point.
    """
    return output * (1 - output)

# Training inputs: four samples with two features each.
X = np.array([
    [0, 1],
    [0, 1],
    [1, 0],
    [1, 0],
])

# Target output for each sample, as a (4, 1) column vector.
y = np.array([[0, 0, 1, 1]]).T

# Fixed seed so the initial weights are the same on every run.
np.random.seed(1)

# One weight per input feature, drawn uniformly from [-1, 1).
synapse_0 = 2*np.random.random((2, 1)) - 1

# Declared as a collector for the errors, but nothing below fills it.
errors_sum = np.array([])

for _ in range(12):
    # Forward pass through the single layer.
    layer_0 = X
    layer_1 = sigmoid(np.dot(layer_0, synapse_0))

    # Signed residual of every sample.
    layer_1_error = layer_1 - y

    # Scale the residual by the sigmoid slope, then project it back onto
    # the inputs to get the update for each weight.
    layer_1_delta = layer_1_error * sigmoid_output_to_derivative(layer_1)
    synapse_0_deriative = np.dot(layer_0.T, layer_1_delta)

    synapse_0 -= synapse_0_deriative

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# NOTE(review): layer_1_error here is only the (4, 1) residual array of the
# last iteration, while the x/y arguments are two 20-element ranges, so
# this does not plot one error value per weight pair -- confirm intent.
surf = ax.plot_surface(range(-10, 10), range(-10, 10), layer_1_error, linewidth=0, antialiased=False)
plt.show()

结果:

enter image description here

我不知道如何收集 for 循环中的所有 layer_1_error

最佳答案

要绘制误差相对于突触权重的曲面,需要遍历权重的各种取值,并对每一组权重计算平均误差。下面是实现这一做法的代码草图:

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*.

    ``np.exp`` is applied element-wise, so *x* may be a scalar or a
    NumPy array; the result has the same shape as the input.
    """
    output = 1.0 / (1.0 + np.exp(-x))
    return output

def sigmoid_output_to_derivative(output):
    """Return ``output * (1 - output)``.

    When *output* is a value already produced by the sigmoid function,
    this product is the sigmoid's slope at that point.
    """
    return output * (1 - output)

# Inputs: four samples with two features each.
X = np.array([
    [0, 1],
    [0, 1],
    [1, 0],
    [1, 0],
])
# Target output for each sample, as a (4, 1) column vector.
y = np.array([[0, 0, 1, 1]]).T

# Weight vector; both entries are overwritten inside the sweep below, so
# the uninitialised contents of np.empty are never read.
synapse_0 = np.empty((2, 1))

# the error aggregation starts here
# 20 evenly spaced candidate values in [-10, 10] for each of the two weights.
# The builtin float is used because the np.float alias no longer exists in
# current NumPy.
x_range = np.linspace(-10, 10, 20, dtype=float)
y_range = np.linspace(-10, 10, 20, dtype=float)
errors = []
for _x in x_range:
    synapse_0[0] = _x
    for _y in y_range:
        synapse_0[1] = _y

        # apply the model to the input
        layer_0 = X
        layer_1 = sigmoid(np.dot(layer_0, synapse_0))

        # The square root of a squared residual is its absolute value, so
        # this is the mean absolute error over the four samples (not RMSE).
        error = np.mean(np.sqrt((layer_1 - y) ** 2))
        errors.append(error)

# in order to plot we need to transform x, y and z into 2D arrays.
# errors was filled with _x as the outer loop, so the reshaped rows follow
# x_range and the columns follow y_range; indexing='ij' gives the mesh the
# same orientation.
error_surface = np.reshape(np.array(errors), (x_range.shape[0], y_range.shape[0]))
_X, _Y = np.meshgrid(x_range, y_range, indexing='ij')

# plot
fig = plt.figure()
# add_subplot is the supported way to request a 3D axes; passing
# ``projection`` to fig.gca() is rejected by current matplotlib.
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(_X, _Y, error_surface, cmap=cm.YlOrBr_r, edgecolor='gray', linewidth=0.004, antialiased=False)
plt.show()

结果图如下: enter image description here

关于Python 如何在神经网络中绘制错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47083947/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com