gpt4 book ai didi

tensorflow - 如何获得填充的 bool 掩码

转载 作者:行者123 更新时间:2023-12-02 04:25:43 24 4
gpt4 key购买 nike

使用带填充 'SAME'tf.extract_image_patches 将导致一些包含填充的补丁(这很好)。

有没有一种简单的方法来获得一个 TensorFlow bool 掩码来掩蔽所有包含填充的补丁?或者我需要重新实现填充过程吗?

最佳答案

我目前的解决方案是添加一个表示位标志的附加 channel 。提取图像 block 后,对于正在填充的 channel ,位标志为 0,对于非填充 channel ,位标志为 1

完整解决方案:

input_tensor = tf.random.normal([10, 28, 28, 1])
window_shape, strides, padding = (4, 4), (2, 2), 'SAME'

# ----------------------------

bits = tf.ones([tf.shape(input_tensor)[0], input_tensor.shape[1], input_tensor.shape[2], 1])
input_for_patching = tf.concat([input_tensor, bits], axis=-1)

patches = tf.extract_image_patches(input_for_patching, ksizes=(1, *window_shape, 1), strides=(1, *strides, 1), rates=(1, 1, 1, 1), padding=padding)

patches_shape = patches.shape

patches = tf.reshape(patches, [-1, *window_shape, input_tensor.shape[3] + 1])

padding_mask = tf.to_float(tf.reduce_all(tf.equal(patches[:, :, :, -1:], 1.0), [1, 2, 3]))

patches = tf.reshape(patches[:, :, :, :-1], [-1, patches_shape[1], patches_shape[2], window_shape[0] * window_shape[1] * input_tensor.shape[3]])

上面代码中的 padding_mask 是我需要的。

如果有人有更短、更优雅和/或更集成的版本,请随时分享。

关于tensorflow - 如何获得填充的 bool 掩码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54618125/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com