作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我需要创建一个与自动梯度兼容的剪切矩阵,适用于 B、C、H、W 张量,并为剪切值获取输入值(可能随机生成)。如何为此生成剪切矩阵?
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
# Load image
def preprocess_simple(image_name, image_size):
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
image = Image.open(image_name).convert('RGB')
return Loader(image).unsqueeze(0)
# Save image
def deprocess_simple(output_tensor, output_name):
output_tensor.clamp_(0, 1)
Image2PIL = transforms.ToPILImage()
image = Image2PIL(output_tensor.squeeze(0))
image.save(output_name)
def get_shear_mat(theta):
...
return shear_mat
def shear_img(x, theta, dtype):
shear_mat = get_shear_mat(theta)
grid = F.affine_grid(shear_mat , x.size()).type(dtype)
x = F.grid_sample(x, grid)
return x
# Shear tensor
test_input = # Test image
shear_values = (3,4) # Example values
sheared_tensor = shear_img(test_input, shear_values)
最佳答案
假设m
是剪切因子,那么theta = atan(1/m)
是剪切角。您现在可以选择水平 剪切或垂直 剪切。以下是如何实现 get_shear_mat
,这样您就可以通过设置 ax=0
来选取水平剪切,通过设置 ax=1
来选取垂直剪切:
def get_shear_mat(theta, ax=0):
assert ax in [0, 1]
m = 1 / torch.tan(torch.tensor(theta))
if ax == 0: # Horizontal shear
shear_mat = torch.tensor([[1, m, 0],
[0, 1, 0]])
else: # Vertical shear
shear_mat = torch.tensor([[1, 0, 0],
[m, 1, 0]])
return shear_mat
请注意,剪切映射只是原始图像中的点 (x,y)
到水平剪切点 (x+my,y)
的映射, 和 (x,y+mx)
用于垂直剪切。这正是我们通过如上所述定义 shear_mat
所做的。
对 shear_img
的可选修改以支持第一行中的批输入操作。还向 shear_img
添加一个参数 - ax
来定义我们想要水平 (ax=0
) 还是垂直 (ax=1
) 剪切:
def shear_img(x, ax, theta, dtype):
shear_mat = get_shear_mat(theta, ax)[None, ...].type(dtype).repeat(x.shape[0], 1, 1)
grid = F.affine_grid(shear_mat , x.size()).type(dtype)
x = F.grid_sample(x.type(dtype), grid)
return x
让我们在图像上测试这个实现:
# Let im be a 4D tensor of shape BxCxHxW (an image or a batch of images):
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor # Set type of data
sheared_im = shear_img(im, 0, np.pi/4, dtype) #Horizontal shear by shear angle of pi/4
plt.imshow(sheared_im.squeeze(0).permute(1,2,0)/255)
plt.show()
如果 im
是我们穿裙子跳舞的猫:
那么我们的剧情就是:
如果我们想要垂直剪切:
sheared_im = shear_img(im, 1, np.pi/4, dtype) # Vertical shear by shear angle of pi/4
plt.imshow(sheared_im.squeeze(0).permute(1, 2, 0)/255)
plt.show()
我们得到:
万岁!
关于python - 如何为 PyTorch 的 F.affine_grid 和 F.grid_sample 创建剪切矩阵?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64394325/
我需要创建一个与自动梯度兼容的剪切矩阵,适用于 B、C、H、W 张量,并为剪切值获取输入值(可能随机生成)。如何为此生成剪切矩阵? import torch import torch.nn.funct
我是一名优秀的程序员,十分优秀!