gpt4 book ai didi

neural-network - 使用 Theano 进行平均池化

转载 作者:行者123 更新时间:2023-12-02 03:34:19 25 4
gpt4 key购买 nike

我正在尝试使用 Theano 为神经网络实现另一个池化函数,期望已经存在的 maxpool,例如平均池。

使用this source ,其中已经实现了平均池化,我的代码如下所示:

随机初始化只是为了测试:

invals = numpy.random.RandomState(1).rand(3,2,5,5) 

Theano 标量和函数的定义:

pdim = T.scalar('pool dim', dtype='float32')
pool_inp = T.tensor4('pool input', dtype='float32')
pool_sum = TSN.images2neibs(pool_inp, (pdim, pdim))
pool_out = pool_sum.mean(axis=-1)
pool_fun = theano.function([pool_inp, pdim], pool_out, name = 'pool_fun', allow_input_downcast=True)

TSN 是 theano.sandbox.neighbours

以及函数的调用:

pool_dim = 2
temp = pool_fun(invals, pool_dim)
temp.shape = (invals.shape[0], invals.shape[1], invals.shape[2]/pool_dim,
invals.shape[3]/pool_dim)
print ('invals[1,0,:,:]=\n', invals[1,0,:,:])
print ('output[1,0,:,:]=\n',temp[1,0,:,:])

我收到一个错误:

TypeError: neib_shape[0]=2, neib_step[0]=2 and ten4.shape[2]=5 not consistent
Apply node that caused the error: Images2Neibs{valid}(pool input, MakeVector.0, MakeVector.0)
Inputs shapes: [(3, 2, 5, 5), (2,), (2,)]
Inputs strides: [(200, 100, 20, 4), (4,), (4,)]
Inputs types: [TensorType(float32, 4D), TensorType(float32, vector), TensorType(float32, vector)]
Use the Theano flag 'exception_verbosity=high' for a debugprint of this apply node.

我不太明白这个错误。很高兴有任何建议如何纠正这个错误或其他池化技术的例子,在 Theano 中编程。

谢谢!

编辑:忽略边框,效果很好

pool_sum = TSN.images2neibs(pool_inp, (pdim, pdim), mode='ignore_borders')

invals[1,0,:,:]=
[[ 0.01936696 0.67883553 0.21162812 0.26554666 0.49157316]
[ 0.05336255 0.57411761 0.14672857 0.58930554 0.69975836]
[ 0.10233443 0.41405599 0.69440016 0.41417927 0.04995346]
[ 0.53589641 0.66379465 0.51488911 0.94459476 0.58655504]
[ 0.90340192 0.1374747 0.13927635 0.80739129 0.39767684]]
output[1,0,:,:]=
[[ 0.33142066 0.30330223]
[ 0.42902038 0.64201581]]

最佳答案

invals 在最后两个维度中的形状为 (5, 5),但是您希望合并 (2, 2) 子集.这仅在您忽略边框(即 invals 的最后一列和最后一行)时有效。

关于neural-network - 使用 Theano 进行平均池化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24654389/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com