gpt4 book ai didi

python - 具有 NCHW 格式的 tensorflow.nn.conv2d 中的过滤器形状

转载 作者:太空狗 更新时间:2023-10-30 01:36:58 34 4
gpt4 key购买 nike

正在关注 Tensorflow's best practices for performance ,我正在使用 NCHW 数据格式,但我不确定要在 tensorflow.nn.conv2d 中使用的过滤器形状.

文档说对 NHWC 格式使用 [filter_height, filter_width, in_channels, out_channels],但不清楚如何处理 NCHW。

是否应该使用相同的形状?

最佳答案

使用相同的过滤器形状应该可以。函数参数的唯一变化是步幅。例如,假设您希望您的体系结构适用于两种格式,这也是推荐的:

# input -> Tensor in NCHW format
if use_nchw:
result = tf.nn.conv2d(
input=input,
filter=filter,
strides=[1, 1, stride, stride],
data_format='NCHW')
else:
input_t = tf.transpose(input, [0, 2, 3, 1]) # NCHW to NHWC

result = tf.nn.conv2d(
input=input_t,
filter=filter,
strides=[1, stride, stride, 1])

result = tf.transpose(result, [0, 3, 1, 2]) # NHWC to NCHW

关于python - 具有 NCHW 格式的 tensorflow.nn.conv2d 中的过滤器形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42783286/

34 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com