gpt4 book ai didi

Python:im2col 的实现利用了 6 维数组的优势?

转载 作者:太空宇宙 更新时间:2023-11-03 12:54:02 25 4
gpt4 key购买 nike

我正在阅读 implementation of im2col来自一本深度学习书(第 7 章,CNN),其目的是将 4 维数组转换为 2 维数组。我不知道为什么在实现中有一个6维数组。我对作者使用的算法背后的想法很感兴趣。

我尝试搜索了很多关于 im2col 实现的论文,但没有一篇像这样使用高维数组。目前我发现对 im2col 过程可视化有用的 Material 是 this paper - HAL Id: inria-00112631 的图片。


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
"""
Parameters
----------
input_data : (batch size, channel, height, width), or (N,C,H,W) at below
filter_h : kernel height
filter_w : kernel width
stride : size of stride
pad : size of padding
Returns
-------
col : two dimensional array
"""
N, C, H, W = input_data.shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1

img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col

最佳答案

让我们尝试想象一下 im2col 做了什么。它以一堆彩色图像作为输入,该堆栈具有图像 ID、颜色 channel 、垂直位置、水平位置等维度。为简单起见,假设我们只有一张图片:

enter image description here

它做的第一件事是填充:

enter image description here

接下来,它将它切割成窗口。窗口的大小由 filter_h/w 控制,重叠由 strides 控制。

enter image description here

这是六个维度的来源:图像 ID(示例中缺少,因为我们只有一张图像)、网格高度/宽度、颜色 channel 。窗口高度/宽度。

enter image description here

目前的算法有点笨拙,它以错误的维度顺序组装输出,然后必须使用 transpose 进行纠正。

最好一开始就把它做好:

def im2col_better(input_data, filter_h, filter_w, stride=1, pad=0):
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
N, C, H, W = img.shape
out_h = (H - filter_h)//stride + 1
out_w = (W - filter_w)//stride + 1
col = np.zeros((N, out_h, out_w, C, filter_h, filter_w))
for y in range(out_h):
for x in range(out_w):
col[:, y, x] = img[
..., y*stride:y*stride+filter_h, x*stride:x*stride+filter_w]
return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))

作为旁注:我们可以使用 stride_tricks 并避免嵌套的 for 循环做得更好:

def im2col_best(input_data, filter_h, filter_w, stride=1, pad=0):
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
N, C, H, W = img.shape
NN, CC, HH, WW = img.strides
out_h = (H - filter_h)//stride + 1
out_w = (W - filter_w)//stride + 1
col = np.lib.stride_tricks.as_strided(img, (N, out_h, out_w, C, filter_h, filter_w), (NN, stride * HH, stride * WW, CC, HH, WW)).astype(float)
return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))

算法做的最后一件事是 reshape ,合并前三个维度(在我们的示例中同样只有两个,因为只有一个图像)。红色箭头显示各个窗口如何排列到第一个新维度:

enter image description here

将最后三个维度的颜色 channel 、窗口中的y坐标、窗口中的x坐标合并到第二个输出维度中。各个像素按黄色箭头所示排列:

enter image description here

关于Python:im2col 的实现利用了 6 维数组的优势?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50292750/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com