gpt4 book ai didi

python - 当我在 tensorflow 中实现 unpool 时,tf.scatter_add 出现了一个奇怪的错误

转载 作者:太空宇宙 更新时间:2023-11-04 04:27:46 25 4
gpt4 key购买 nike

我正在尝试使用 tf.scatter_add 在 tensorflow 中实现 unpool,但我遇到了一个奇怪的错误,这是我的代码:

import tensorflow as tf
import numpy as np
import random

tf.reset_default_graph()

mat = list(range(64))
random.shuffle(mat)
mat = np.array(mat)
mat = np.reshape(mat, [1,8,8,1])
M = tf.constant(mat, dtype=tf.float32)
pool1, argmax1 = tf.nn.max_pool_with_argmax(M, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
pool2, argmax2 = tf.nn.max_pool_with_argmax(pool1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
pool3, argmax3 = tf.nn.max_pool_with_argmax(pool2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')


def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
x_shape = x.get_shape().as_list()
argmax_shape = argmax.get_shape().as_list()
assert not(x_shape[0] is None and batch_size is None), "must input batch_size if number of batch is alterable"
if x_shape[0] is None:
x_shape[0] = batch_size
if argmax_shape[0] is None:
argmax_shape[0] = x_shape[0]
if unpool_shape is None:
unpool_shape = [x_shape[i] * strides[i] for i in range(4)]
x_unpool = tf.get_variable(name=name, shape=[np.prod(unpool_shape)], initializer=tf.zeros_initializer(), trainable=False)
argmax = tf.cast(argmax, tf.int32)
argmax = tf.reshape(argmax, [np.prod(argmax_shape)])
x = tf.reshape(x, [np.prod(argmax.get_shape().as_list())])
x_unpool = tf.scatter_add(x_unpool , argmax, x)
x_unpool = tf.reshape(x_unpool , unpool_shape)
return x_unpool


unpool2 = unpool(pool3, argmax3, strides=[1,2,2,1], name='unpool3')
unpool1 = unpool(unpool2, argmax2, strides=[1,2,2,1], name='unpool2')
unpool0 = unpool(unpool1, argmax1, strides=[1,2,2,1], name='unpool1')


with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
mat_out = mat[:,:,:,0]
pool1_out = sess.run(pool1)[0,:,:,0]
pool2_out = sess.run(pool2)[0,:,:,0]
pool3_out = sess.run(pool3)[0,:,:,0]
argmax1_out = sess.run(argmax1)[0,:,:,0]
argmax2_out = sess.run(argmax2)[0,:,:,0]
argmax3_out = sess.run(argmax3)[0,:,:,0]
unpool2_out = sess.run(unpool2)[0,:,:,0]
unpool1_out = sess.run(unpool1)[0,:,:,0]
unpool0_out = sess.run(unpool0)[0,:,:,0]
print(unpool2_out)
print(unpool1_out)
print(unpool0_out)

输出:

[[ 0.  0.]
[ 0. 63.]]
[[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 126. 0.]
[ 0. 0. 0. 0.]]
[[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 315. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]]

位置是对的,但是值是错的。 unpool2 是对的,unpool1 是期望值的两倍,unpool2 是期望值的五倍。我不知道出了什么问题,谁能告诉我如何修复这个错误?

提前致谢。

最佳答案

其实答案很简单。为了方便,我重命名了一些变量,看这段代码:

def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
x_shape = x.get_shape().as_list()
argmax_shape = argmax.get_shape().as_list()
assert not(x_shape[0] is None and batch_size is None), "must input batch_size if number of batch is alterable"
if x_shape[0] is None:
x_shape[0] = batch_size
if argmax_shape[0] is None:
argmax_shape[0] = x_shape[0]
if unpool_shape is None:
unpool_shape = [x_shape[i] * strides[i] for i in range(4)]
x_unpool = tf.get_variable(name=name, shape=[np.prod(unpool_shape)], initializer=tf.zeros_initializer(), trainable=False)
argmax = tf.cast(argmax, tf.int32)
argmax = tf.reshape(argmax, [np.prod(argmax_shape)])
x = tf.reshape(x, [np.prod(argmax.get_shape().as_list())])
x_unpool_add = tf.scatter_add(x_unpool , argmax, x)
x_unpool_reshape = tf.reshape(x_unpool_add , unpool_shape)
return x_unpool_reshape

x_unpool_add 是 tf.scatter_add 的一个操作,每次我们计算 x_unpool_reshape 时,x_unpool_add 都会被调用。所以如果我们计算 unpool2 两次,x_unpool 将添加 x 两次。在我的原始代码中,我依次计算unpool0、unpool1、unpool2,先调用unpool1的x_unpool_add,然后在计算unpool2时,由于需要计算unpool1,x_unpool_add会再次调用,所以等于调用了两次x_unpool_add,值是错误的。如果我们直接计算 unpool2,我们会得到正确的结果。所以用 tf.scatter_update 替换 tf.scatter_add 可以避免这个错误。

这段代码可以直观地重现:

import tensorflow as tf

t1 = tf.get_variable(name='t1', shape=[1], dtype=tf.float32, initializer=tf.zeros_initializer())
t2 = tf.get_variable(name='t2', shape=[1], dtype=tf.float32, initializer=tf.zeros_initializer())
d = tf.scatter_add(t1, [0], [1])
e = tf.scatter_add(t2, [0], d)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
d_out1 = sess.run(d)
d_out2 = sess.run(d)
e_out = sess.run(e)
print(d_out1)
print(d_out2)
print(e_out)

输出:

[1.]
[2.]
[3.]

关于python - 当我在 tensorflow 中实现 unpool 时,tf.scatter_add 出现了一个奇怪的错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53204794/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com