gpt4 book ai didi

python - 如何使用小批量代替 SGD

转载 作者:行者123 更新时间:2023-11-30 09:00:35 30 4
gpt4 key购买 nike

这是一个用 python 快速实现单层神经网络的方法:

import numpy as np

# simulate data
np.random.seed(94106)
X = np.random.random((200, 3)) # 100 3d vectors
# first col is set to 1
X[:, 0] = 1
def simu_out(x):
return np.sum(np.power(x, 2))
y = np.apply_along_axis(simu_out, 1, X)
# code 1 if above average
y = (y > np.mean(y)).astype("float64")*2 - 1
# split into training and testing sets
Xtr = X[:100]
Xte = X[100:]
ytr = y[:100]
yte = y[100:]
w = np.random.random(3)

# 1 layer network. Final layer has one node
# initial weights,
def epoch():
err_sum = 0
global w
for i in range(len(ytr)):
learn_rate = .1
s_l1 = Xtr[i].T.dot(w) # signal at layer 1, pre-activation
x_l1 = np.tanh(s_l1) # output at layer 1, activation
err = x_l1 - ytr[i]
err_sum += err
# see here: https://youtu.be/Ih5Mr93E-2c?t=51m8s
delta_l1 = 2 * err * (1 - x_l1**2)
dw = Xtr[i] * delta_l1
w -= learn_rate * dw
print("Mean error: %f" % (err_sum / len(ytr)))
epoch()
for i in range(1000):
epoch()

def predict(X):
global w
return np.sign(np.tanh(X.dot(w)))

# > 80% accuracy!!
np.mean(predict(Xte) == yte)

它使用随机梯度下降进行优化。我在想如何在这里应用小批量梯度下降?

最佳答案

与“经典”SGD 和小批量梯度下降的区别在于,您使用多个样本(所谓的小批量)来计算 w 的更新。这样做的优点是,当您遵循平滑梯度时,您在解决方案方向上采取的步骤噪音较小。

为此,您需要一个内部循环来计算更新dw,在其中迭代小批量。例如(快速而肮脏的代码):

def epoch(): 
err_sum = 0
learn_rate = 0.1
global w
for i in range(int(ceil(len(ytr) / batch_size))):
batch = Xtr[i:i+batch_size]
target = ytr[i:i+batch_size]
dw = np.zeros_like(w)
for j in range(batch_size):
s_l1 = batch[j].T.dot(w)
x_l1 = np.tanh(s_l1)
err = x_l1 - target[j]
err_sum += err
delta_l1 = 2 * err * (1 - x_l1**2)
dw += batch[j] * delta_l1
w -= learn_rate * (dw / batch_size)
print("Mean error: %f" % (err_sum / len(ytr)))

在测试中准确率为 87%。

现在,还有一件事:您总是从头到尾检查训练集。您绝对应该在每次迭代中打乱数据。始终按照相同的顺序进行确实会影响您的表现,尤其是当您首先拥有 A 类的所有样本,然后是 B 类的所有样本。这也可以使您的训练循环进行。因此,只需按随机顺序浏览该集合即可,例如与

order = np.random.permutation(len(ytr))

并用 epoch() 函数中的 order[i] 替换所有出现的 i

还有一个更笼统的评论:全局变量通常被认为是糟糕的设计,因为您无法控制哪个代码段修改您的变量。而是将 w 作为参数传递。学习率和批量大小也是如此。

关于python - 如何使用小批量代替 SGD,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40710169/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com