gpt4 book ai didi

使用索引在 max_pool_with_argmax 之后进行 Tensorflow 反池化

转载 作者:行者123 更新时间:2023-11-30 09:24:37 25 4
gpt4 key购买 nike

在尝试实现 Google 论文中的 U-SegNet 时,我在使用 argmax 索引实现反池化(unpooling)操作时遇到了问题。

完整代码:

import tensorflow as tf


def unpool(pool, ind, ksize=[1, 2, 2, 1], name=None):
    """Scatter pooled values back to their argmax positions (question's version).

    NOTE(review): this is the failing version from the question. With batch
    size > 1 it raises "Invalid indices: [2,0] = [1, 21] does not index into
    [4,16]" (see the traceback below it). The fix is shown in the answer.

    Args:
        pool: pooled tensor; indexed as 4-D [N, H, W, C] below (NHWC assumed).
        ind: argmax indices returned by tf.nn.max_pool_with_argmax, same shape
            as ``pool``.
        ksize: pooling window; only ksize[1] and ksize[2] are read.
        name: unused -- the scope below is the literal string 'name'.

    Returns:
        Tensor of shape [N, H * ksize[1], W * ksize[2], C].
    """
    # NOTE(review): the `name` argument is ignored; the scope is hard-coded to
    # 'name', which is why the failing op appears as "name/ScatterNd".
    with tf.variable_scope('name') as scope:
        input_shape = tf.shape(pool)
        # Output shape: the pooled shape scaled up by the window in H and W.
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        # Total number of pooled elements (last entry of the cumulative product).
        flat_input_size = tf.cumprod(input_shape)[-1]
        # Scatter target is 2-D: [batch, flattened per-example output size].
        flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
        # Batch index of every pooled element, broadcast over H, W and C.
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                                 shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        # BUG (the subject of the question): for batch > 0 the argmax value in
        # ind_ already carries that example's batch offset -- the error shows
        # index 21 for batch 1 when each example has only 16 slots -- so the
        # pair [b, ind_] falls outside flat_output_shape. The answer removes
        # that offset from ind_ before this concat.
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, tf.stack(output_shape))

        # Attach a static shape derived from pool's static shape (the reshape
        # above was given a dynamically computed shape).
        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
        ret.set_shape(set_output_shape)
        return ret

with tf.Session() as sess:
    # Minimal repro: one random 4x4 single-channel image (batch size 1, the
    # case reported as working), 2x2 max pooling, then unpooling.
    x = tf.random_normal([1, 4, 4, 1])
    y, ind = tf.nn.max_pool_with_argmax(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    z = unpool(y, ind)

    # Evaluate input, pooled output and unpooled output in a single run.
    x_, y_, z_ = sess.run((x, y, z))

批量大小为 1 时它工作正常,但批量大小大于 1 时会崩溃,并报出以下错误:

2018-09-22 16:33:57.010504: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-09-22 16:33:57.082638: W tensorflow/core/framework/op_kernel.cc:1275] OP_REQUIRES failed at scatter_nd_op.cc:119 : Invalid argument: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
Traceback (most recent call last):
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1278, in _do_call
return fn(*args)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1263, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
[[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "tst.py", line 39, in <module>
x_, y_, z_ = sess.run([x, y, z])
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
run_metadata_ptr)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1100, in _run
feed_dict_tensor, options, run_metadata)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1272, in _do_run
run_metadata)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1291, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
[[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

Caused by op 'name/ScatterNd', defined at:
File "tst.py", line 37, in <module>
z = unpool(y, ind)
File "tst.py", line 20, in unpool
ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6788, in scatter_nd
"ScatterNd", indices=indices, updates=updates, shape=shape, name=name)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
return func(*args, **kwargs)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3155, in create_op
op_def=op_def)
File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Invalid indices: [2,0] = [1, 21] does not index into [4,16]
[[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

问题可能出在哪里?又该如何解决?

反池化(unpooling)函数取自 GitHub 上的这个 issue,但其中没有说明批量大小大于 1 时如何进行反池化。

我的tf.__version__是1.10。

最佳答案

@Tofik.AI 你用的是哪个 TensorFlow 版本?根据最新的文档,原来的写法是不正确的:在该版本中,max_pool_with_argmax 返回的扁平索引已经包含了批次偏移,需要先换算成单个样本内部的偏移,再与批次坐标拼接。我的实现:

def unpool(pool, ind, ksize=(1, 2, 2, 1), name=None):
    """Reverse a max-pooling step using the argmax indices it produced.

    Every value of ``pool`` is scattered back to the position recorded for it
    in ``ind``; all other output positions are zero.

    Args:
        pool: pooled activations, indexed as 4-D [N, H, W, C] (NHWC assumed).
        ind: argmax tensor returned together with ``pool`` by
            ``tf.nn.max_pool_with_argmax``; same shape as ``pool``. int32 and
            int64 index dtypes are both handled.
        ksize: pooling window as [1, kh, kw, 1]; only ksize[1] and ksize[2]
            are read. Assumes stride == window and that the pre-pooling height
            and width were multiples of the window, so the pre-pooling size is
            exactly [N, H * kh, W * kw, C].
        name: optional name for the op scope (defaults to "unpool").

    Returns:
        Tensor of shape [N, H * ksize[1], W * ksize[2], C].
    """
    with tf.name_scope(name, 'unpool', [pool, ind]):
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0],
                        input_shape[1] * ksize[1],
                        input_shape[2] * ksize[2],
                        input_shape[3]]

        flat_input_size = tf.reduce_prod(input_shape)
        # Number of output slots belonging to one example.
        per_example_size = output_shape[1] * output_shape[2] * output_shape[3]
        flat_output_shape = tf.stack([output_shape[0], per_example_size])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))

        # Batch coordinate of every pooled element, broadcast over H, W and C.
        batch_range = tf.reshape(
            tf.range(tf.cast(input_shape[0], ind.dtype), dtype=ind.dtype),
            shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))

        # Depending on the TensorFlow version (and include_batch_in_index),
        # argmax is either ((b * H + y) * W + x) * C + c or (y * W + x) * C + c.
        # The in-example offset is smaller than per_example_size, so taking the
        # remainder recovers it in both cases. (The previous `ind - b * size`
        # only worked in the first case and went negative in the second.)
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        ind_ = tf.floormod(ind_, tf.cast(per_example_size, ind.dtype))
        ind_ = tf.concat([b, ind_], 1)

        # scatter_nd requires `shape` to use the same dtype as the indices.
        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, ind.dtype))
        ret = tf.reshape(ret, tf.stack(output_shape))

        # Attach a static shape derived from pool's static shape (the reshape
        # above was given a dynamically computed shape).
        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0],
                            set_input_shape[1] * ksize[1],
                            set_input_shape[2] * ksize[2],
                            set_input_shape[3]]
        ret.set_shape(set_output_shape)
        return ret

关于使用索引在 max_pool_with_argmax 之后进行 Tensorflow 反池化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52457221/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com