3 回答
TA贡献1963条经验 获得超6个赞
因此,网格生成器和采样器是 Spatial Transformer 的子模块(JADERBERG、Max 等人)。这些子模块不可训练,它们可让您应用可学习的以及不可学习的空间变换。theta在这里,我使用这两个子模块,并使用 PyTorch 的函数torch.nn.functional.affine_grid和(这些函数分别是生成器和采样器的实现)来旋转图像torch.nn.functional.affine_sample:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
def get_rot_mat(theta):
theta = torch.tensor(theta)
return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
[torch.sin(theta), torch.cos(theta), 0]])
def rot_img(x, theta, dtype):
rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0],1,1)
grid = F.affine_grid(rot_mat, x.size()).type(dtype)
x = F.grid_sample(x, grid)
return x
#Test:
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#im should be a 4D tensor of shape B x C x H x W with type dtype, range [0,255]:
plt.imshow(im.squeeze(0).permute(1,2,0)/255) #To plot it im should be 1 x C x H x W
plt.figure()
#Rotation by np.pi/2 with autograd support:
rotated_im = rot_img(im, np.pi/2, dtype) # Rotate image by 90 degrees.
plt.imshow(rotated_im.squeeze(0).permute(1,2,0)/255)
在上面的示例中,假设我们将图像im视为一只穿着裙子跳舞的猫:
rotated_im
将是一只穿着裙子逆时针旋转 90 度的跳舞猫:
如果我们用rot_img
等号theta
调用,就会得到以下结果np.pi/4
:
最好的部分是它可以区分输入并具有 autograd 支持!万岁!
TA贡献2011条经验 获得超2个赞
使用 torchvision 应该很简单:
import torchvision.transforms.functional as TF
angle = 30
x = torch.randn(1,3,512,512)
out = TF.rotate(x, angle)
例如如果x是:
out
旋转 30 度为(注:逆时针):
TA贡献1813条经验 获得超2个赞
pytorch 有一个函数:
x = torch.tensor([[0, 1], [2, 3]]) x = torch.rot90(x, 1, [0, 1])
>> tensor([[1, 3], [0, 2]])
以下是文档:https://pytorch.org/docs/stable/ generated/torch.rot90.html
添加回答
举报