gpt4 book ai didi

python - 需要帮助向量化一些 Python 代码

转载 作者:行者123 更新时间:2023-11-30 23:29:22 24 4
gpt4 key购买 nike

我有一些代码需要矢量化帮助。我想将以下内容转换为矢量形式,我该怎么做?我想摆脱内部循环 - 显然,这是可能的。X 是 NxD 矩阵。 y 是 1xD 向量。

def foo(X, y, mylambda, N, D, epsilon): 
...
for j in xrange(D):
aj = 0
cj = 0
for i in xrange(N):
aj += 2 * (X[i,j] ** 2)
cj += 2 * (X[i,j] * (y[i] - w.transpose()*X[i].transpose() + w[j]*X[i,j]))

...

如果我在函数上调用 numpy.vectorize(),它会在运行时抛出错误。

完整代码:

import scipy
import scipy.io
import numpy
from numpy import linalg
from scipy import *

def data(N, d, k, sigma, seed=12231):
random.seed(seed)
X = randn(N, d)
wg = zeros(1 + d)
wg[1:k + 1] = 10 * sign(randn(k))
eps = randn(N) * sigma
y = X.dot(wg[1:]) + wg[0] + eps
return (y, X)


def foo(X, y, mylambda, n, D, epsilon):
identityMatrix = numpy.matrix(numpy.identity(D))

w = (X.transpose() * X + mylambda * identityMatrix).getI() * X.transpose() * y
newweight = (X.transpose() * X + mylambda * identityMatrix).getI() * X.transpose() * y

iterate = 1
iteration = 0

while iterate > 0 and iteration < 10000:
iteration += 1
iterate = 0
maxerror = 0
for j in xrange(D):
aj = 0
cj = 0
for i in xrange(n):
aj += 2 * (X[i,j] ** 2)
cj += 2 * (X[i,j] * (y[i] - w.transpose()*X[i].transpose() + w[j]*X[i,j]))

if cj < -mylambda:
newweight[j,0] = (cj + mylambda)/ aj
elif cj > mylambda:
newweight[j,0] = (cj - mylambda)/ aj
else:
newweight[j,0] = 0

if abs(newweight[j,0] - w[j,0]) > epsilon:
iterate += 1
if abs(newweight[j,0] - w[j,0]) > maxerror:
maxerror = abs(newweight[j,0] - w[j,0])
w[j,0] = newweight[j,0]

N, D, k = 50, 75, 5
(y, X) = data(N, D, k, 1, 123)
X = numpy.matrix(X)
y = numpy.matrix(y).transpose()
foo(X, y, 1, N, D, 0.1)

最佳答案

您可以替换:

aj = 0
cj = 0
for i in xrange(n):
aj += 2 * (X[i,j] ** 2)
cj += 2 * (X[i,j] * (y[i] - w.transpose()*X[i].transpose() + w[j]*X[i,j]))

与:

aj = 2*np.sum(X[:,j].T*X[:,j])
cj = 2*np.sum(np.multiply(X[:, j].T, (y.T - w.T*X.T + w[j] * X[:, j].T)))

关于python - 需要帮助向量化一些 Python 代码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/21209741/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com