gpt4 book ai didi

python - softmax_loss函数: Turn the loop into matrix operation

转载 作者:行者123 更新时间:2023-11-30 09:19:09 25 4
gpt4 key购买 nike

我现在正在学习stanford cs231n类(class)。当完成softmax_loss函数时,我发现用全向量化类型编写并不容易,尤其是处理dw项。下面是我的代码。有人可以优化代码吗?将不胜感激。

def softmax_loss_vectorized(W, X, y, reg):

loss = 0.0
dW = np.zeros_like(W)


num_train = X.shape[0]
num_classes = W.shape[1]

scores = X.dot(W)
scores -= np.max(scores, axis = 1)[:, np.newaxis]
exp_scores = np.exp(scores)
sum_exp_scores = np.sum(exp_scores, axis = 1)
correct_class_score = scores[range(num_train), y]

loss = np.sum(np.log(sum_exp_scores)) - np.sum(correct_class_score)

exp_scores = exp_scores / sum_exp_scores[:,np.newaxis]

# **maybe here can be rewroten into matrix operations**
for i in xrange(num_train):
dW += exp_scores[i] * X[i][:,np.newaxis]
dW[:, y[i]] -= X[i]

loss /= num_train
loss += 0.5 * reg * np.sum( W*W )
dW /= num_train
dW += reg * W


return loss, dW

最佳答案

下面是一个矢量化实现。但我建议您尝试多花一点时间,自己找到解决方案。这个想法是构造一个包含所有 softmax 值的矩阵,并从正确的元素中减去 -1

def softmax_loss_vectorized(W, X, y, reg):
num_train = X.shape[0]

scores = X.dot(W)
scores -= np.max(scores)
correct_scores = scores[np.arange(num_train), y]

# Compute the softmax per correct scores in bulk, and sum over its logs.
exponents = np.exp(scores)
sums_per_row = np.sum(exponents, axis=1)
softmax_array = np.exp(correct_scores) / sums_per_row
information_array = -np.log(softmax_array)
loss = np.mean(information_array)

# Compute the softmax per whole scores matrix, which gives the matrix for X rows coefficients.
# Their linear combination is algebraically dot product X transpose.
all_softmax_matrix = (exponents.T / sums_per_row).T
grad_coeff = np.zeros_like(scores)
grad_coeff[np.arange(num_train), y] = -1
grad_coeff += all_softmax_matrix
dW = np.dot(X.T, grad_coeff) / num_train

# Regularization
loss += 0.5 * reg * np.sum(W * W)
dW += reg * W

return loss, dW

关于python - softmax_loss函数: Turn the loop into matrix operation,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46691293/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com