gpt4 book ai didi

python - NCHW和NHWC网络格式转换Tensorflow模型

转载 作者:行者123 更新时间:2023-12-01 07:20:39 25 4
gpt4 key购买 nike

对于这个 Humanpose Tensorflow 网络,network_cmubase ,它仅接受 NHWC 输入格式。如果我以 NCHW 格式构建网络,则会出现错误

Depth of input (32) is not a multiple of input depth of filter (3) for 'conv1_1/Conv2D' (op: 'Conv2D') with input shapes: [1,3,24,32], [3,3,3,64]. 

我构建网络的代码是

import tensorflow as tf
import numpy as np
from network_cmu import CmuNetwork


def main():
#print(tensor_util.MakeNdarray(n.attr['value'].tensor))
placeholder_input = tf.placeholder(dtype=tf.float32, shape=(1, 3, 24, 32), name="image")
net = CmuNetwork({'image': placeholder_input}, trainable=False)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
#for n in tf.get_default_graph().as_graph_def().node:
# print(n.name)
save_path = saver.save(sess, "cmuThreeOutputs/model.ckpt")

if __name__ == '__main__':
main()

我应该更改什么才能拥有 NCHW 格式的网络?

最佳答案

您可以使用tf.transpose将轴从 NHWC 转移到 NCHW

input_ = tf.convert_to_tensor(np.random.rand(1, 3, 24, 32))
a1 = tf.transpose(input_, perm=[0, 2, 3, 1])
print(a1.shape) # 1, 24, 32, 3

您甚至可以使用 tf.reshape

a2 = tf.reshape(input_, (-1, input_.shape[2], input_.shape[3], input_.shape[1]))
print(a2.shape) # 1, 24, 32, 3

关于python - NCHW和NHWC网络格式转换Tensorflow模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57719499/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com