为了账号安全,请及时绑定邮箱和手机立即绑定

使用方括号对 Pytorch 张量进行子集化

使用方括号对 Pytorch 张量进行子集化

呼啦一阵风 2022-07-26 16:00:51
我遇到了一行代码,用于在 PyTorch 中将 3D 张量简化为 2D 张量。3D 张量x的大小torch.Size([500, 50, 1])和这行代码:x = x[lengths - 1, range(len(lengths))]用于减少x到大小为 的 2D 张量torch.Size([50, 1])。lengths也是一个torch.Size([50])包含值的形状张量。请任何人解释这是如何工作的?谢谢你。
查看完整描述

2 回答

?
一只萌萌小番薯

TA贡献1795条经验 获得超7个赞

这里的关键特性是将张量的值lengths作为 的索引传递x。这里简化的例子,我交换了容器的尺寸,所以 index dimenson 首先:


container = torch.arange(0, 50 )

container = f.reshape((5, 10))

>>>tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],

        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],

        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],

        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],

        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])


indices = torch.arange( 2, 7, dtype=torch.long )

>>>tensor([2, 3, 4, 5, 6])


print( container[ range( len(indices) ), indices] )

>>>tensor([ 2, 13, 24, 35, 46])    

注意:我们从一行中得到一件事(range( len(indices) )产生连续的行号),列号由索引[ row_number ]


查看完整回答
反对 回复 2022-07-26
?
素胚勾勒不出你

TA贡献1827条经验 获得超9个赞

在被这种行为难住之后,我对此进行了更多挖掘,发现它与多维 NumPy 数组的索引行为一致。使这种违反直觉的原因是两个数组必须具有相同的长度这一不太明显的事实,即在这种情况下len(lengths)。


事实上,它的工作原理如下: *lengths确定您访问第一个维度的顺序。即,如果您有一个一维数组a = [0, 1, 2, ...., 500],并使用 list 访问它b = [300, 200, 100],那么结果a[b] = [301, 201, 101](这也解释了lengths - 1运算符,它只会导致访问的值与分别在b、 或lengths中使用的索引相同)。*range(len(lengths))然后 * 只需选择第 - 行i中的第 - 个元素i。如果您有一个方阵,您可以将其解释为矩阵的对角线。由于您只能访问前两个维度上每个位置的单个元素,因此可以将其存储在一个维度中(从而将您的 3D 张量减少到 2D)。后一个维度简单地保持“原样”。


如果你想玩这个,我强烈建议将range()值更改为更长/更短的值,这将导致以下错误:


IndexError:形状不匹配:索引数组无法与形状(x,)(y,)一起广播


其中x和y是您的特定长度值。


要以长形式编写此访问方法以了解“幕后”发生的情况,还请考虑以下示例:


import torch

x = torch.randint(500, 50, 1)

lengths = torch.tensor([2, 30, 1, 4])  # random examples to explore

diag = list(range(len(lengths)))  # [0, 1, 2, 3]

result = []

for i, row in enumerate(lengths):

    temp_tensor = x[row, :, :]  # temp_tensor.shape = [1, 50, 1]

    temp_tensor = temp_tensor.squeeze(0)[diag[i]]  # temp_tensor.shape = [1, 1]

    result.append(temp.tensor)


# back to pytorch

result = torch.tensor(result)

result.shape  # [4, 1]


查看完整回答
反对 回复 2022-07-26
  • 2 回答
  • 0 关注
  • 66 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信