gpt4 book ai didi

Python 小批量字典学习

转载 作者:行者123 更新时间:2023-11-28 17:25:09 24 4
gpt4 key购买 nike

我想在 python 中使用字典学习实现错误跟踪,使用 sklearn 的 MiniBatchDictionaryLearning ,这样我就可以记录误差是如何随着迭代而减少的。我有两种方法可以做到这一点,但都没有奏效。设置:

  • 输入数据 X,numpy 数组形状 (n_samples, n_features) = (298143, 300)。这些是形状为 (10, 10) 的补丁,由形状为 (642, 480, 3) 的图像生成。
  • 字典学习参数:列数(或原子数)= 100,alpha = 2,变换算法= OMP,总数。迭代次数 = 500(先保持小,就像测试用例一样)
  • 计算错误:学习字典后,我根据学习到的字典再次对原图进行编码。由于编码和原始编码都是相同形状 (642, 480, 3) 的 numpy 数组,所以我现在只是按元素计算欧几里得距离:

    err = np.sqrt(np.sum(reconstruction - original)**2))

我用这些参数进行了测试,完全拟合能够产生非常好的重建,误差很小,这很好。现在介绍两种方法:

方法1:每100次迭代保存学习到的字典,并记录错误。对于 500 次迭代,这给了我们 5 次运行,每次运行 100 次迭代。每次运行后,我计算错误,然后使用当前学习的字典作为下一次运行的初始化。

# Fit an initial dictionary, V, as a first run
dico = MiniBatchDictionaryLearning(n_components = 100,
alpha = 2,
n_iter = 100,
transform_algorithm='omp')
dl = dico.fit(patches)
V = dl.components_

# Now do another 4 runs.
# Note the warm restart parameter, dict_init = V.
for i in range(n_runs):
print("Run %s..." % i, end = "")
dico = MiniBatchDictionaryLearning(n_components = 100,
alpha = 2,
n_iter = n_iterations,
transform_algorithm='omp',
dict_init = V)
dl = dico.fit(patches)
V = dl.components_

img_r = reconstruct_image(dico, V, patches)
err = np.sqrt(np.sum((img - img_r)**2))
print("Err = %s" % err)

问题:错误没有减少,而且相当高。字典也没学好。

方法 2:将输入数据 X 分成 500 个批处理,使用 partial_fit() 方法进行部分拟合。

batch_size = 500
n_batches = X.shape[0] // batch_size
print(n_batches) # 596

for iternum in range(n_batches):
batch = patches[iternum*batch_size : (iternum+1)*batch_size]
V = dico.partial_fit(batch)

问题:这似乎需要大约 5000 倍的时间。

我想知道是否有办法检索拟合过程中的错误?

最佳答案

每次调用 fit 都会重新初始化模型并忘记之前对 fit 的任何调用:这是 scikit-learn 中所有估算器的预期行为。

我认为在循环中使用 partial_fit 是正确的解决方案,但你应该在小批量上调用它(就像在 fit 方法中所做的那样,默认的 batch_size 值仅为 3)然后只计算例如,每 100 次或 1000 次调用 partial_fit 的成本:

batch_size = 3
n_epochs = 20
n_batches = X.shape[0] // batch_size
print(n_batches) # 596


n_updates = 0
for epoch in range(n_epochs):
for i in range(n_batches):
batch = patches[i * batch_size:(i + 1) * batch_size]
dico.partial_fit(batch)
n_updates += 1
if n_updates % 100 == 0:
img_r = reconstruct_image(dico, dico.components_, patches)
err = np.sqrt(np.sum((img - img_r)**2))
print("[epoch #%02d] Err = %s" % (epoch, err))

关于Python 小批量字典学习,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39629931/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com