gpt4 book ai didi

python - 使用 TensorFlow 对图像进行自定义上采样

转载 作者:太空宇宙 更新时间:2023-11-03 21:08:39 26 4
gpt4 key购买 nike

我在 TensorFlow 中实现层函数时遇到问题。也许有更多经验的人知道如何解决这个问题。该函数的用法应如下所示:

在:一个名为 A[B x W x H x 2] 张量

输出:一个名为 B 的新张量,其大小为 [B x p*W x q*W],填充如下:

for b from 0 to B: #loop over batches
for w from 0 to W: # loop over width
for h from 0 to H: # loop over height
B[b,w*p:w*p+p,h*q:h*q+q] = tf.random.normal(shape=[p,q],
mean=A[b,w,h,0],
stddev=A[b,w,h,1])

我基本上想做的是使用“随机(高斯)插值”对图像进行上采样。

我无法创建一个空张量并填充它,就像我通常根据伪代码所做的那样。我尝试使用 TensorFlow 的 tf.map_fn() 函数,不幸的是它不起作用。

这个想法是稍后使用该层作为均值池或最大池的替代方案。

也许有更简单的方法来做到这一点?

任何帮助表示赞赏。谢谢。

最佳答案

您可以通过矢量化方式来完成此操作(这应该比循环或映射快得多),如下所示:

import tensorflow as tf
import numpy as np

def gaussian_upsampling(A, p, q):
s = tf.shape(A)
B, W, H, C = s[0], s[1], s[2], s[3]
# Add two dimensions to A for tiling
A_exp = tf.expand_dims(tf.expand_dims(A, 2), 4)
# Tile A along new dimensions
A_tiled = tf.tile(A_exp, [1, 1, p, 1, q, 1])
# Reshape
A_tiled = tf.reshape(A_tiled, [B, W * p, H * q, C])
# Extract mean and std
mean_tiled = A_tiled[:, :, :, 0]
std_tiled = A_tiled[:, :, :, 1]
# Make base random value
rnd = tf.random.normal(shape=[B, W * p, H * q], mean=0, stddev=1, dtype=A.dtype)
# Scale and shift random value
return rnd * std_tiled + mean_tiled

# Test
with tf.Graph().as_default(), tf.Session() as sess:
tf.random.set_random_seed(100)
mean = tf.constant([[[ 1.0, 2.0, 3.0],
[ 4.0, 5.0, 6.0]],
[[ 7.0, 8.0, 9.0],
[10.0, 11.0, 12.0]]])
std = tf.constant([[[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6]],
[[0.7, 0.8, 0.9],
[1.0, 1.1, 1.2]]])
A = tf.stack([mean, std], axis=-1)
with np.printoptions(precision=2, suppress=True):
print(sess.run(gaussian_upsampling(A, 3, 2)))

输出:

[[[ 0.94  0.97  1.82  1.67  2.89  2.96]
[ 1.04 0.78 2.23 2.02 2.95 3.04]
[ 0.9 0.96 1.84 1.98 2.74 3.06]
[ 3.89 4.12 5.72 4.32 6.02 5.7 ]
[ 3.47 4.27 4.39 4.85 6.38 5.32]
[ 3.21 3.98 4.64 4.31 5.72 5.96]]

[[ 8.15 7.08 7.33 7.78 8.75 9.95]
[ 7.37 7.29 8.27 8.26 8.56 8.17]
[ 5.91 7.95 7.9 7.81 8.43 8.64]
[11.12 11.49 11.95 11.74 11.43 12.3 ]
[ 9.98 9.66 9.21 10.2 12.78 12.13]
[ 8.33 10.37 11.88 11.44 12.96 11.73]]]

关于python - 使用 TensorFlow 对图像进行自定义上采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55224109/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com