
How to get the top-k elements of each row in a 2D tensor?

拉莫斯之舞 2022-08-25 15:25:26
How can I get the top-k elements of each row in a 2D tensor (with a different k per row) in an elegant way, instead of using a for loop like the one below?

import torch

elements = torch.rand(5,10)
topk_list = [2,3,1,2,0] # means top2 for 1st row, top3 for 2nd row, top1 for 3rd row,....
index_list = [] # record the topk index in elements
for i in range(5):
    index_list.append(elements[i].topk(topk_list[i]))

2 Answers

HUX布斯

Contributed 1876 experience points, received 6+ likes

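If your k values don't vary much and you want to vectorize the code, you can first take the top max(k) elements of every row in a single call, and then select the desired results from those with a mask.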


# Code from OP
import torch

elements = torch.rand(5,10)
topk_list = [2,3,1,2,0] # means top2 for 1st row, top3 for 2nd row, top1 for 3rd row,....
index_list = [] # record the topk index in elements

for i in range(5):
    index_list.append(elements[i].topk(topk_list[i]))

# Print the result
print(index_list)

# Get topk for max_k
max_k = max(topk_list)
topk_vals, topk_inds = elements.topk(max_k, dim=-1)

# Select desired topk using mask
mask = torch.arange(max_k)[None, :] < torch.tensor(topk_list)[:, None]
vals, inds = topk_vals[mask], topk_inds[mask]
rows, _ = mask.nonzero().T
print("-" * 10)
print("rows", rows)
print("inds", inds)
print("vals", vals)

# Or split
vals_per_row = vals.split(topk_list)
inds_per_row = inds.split(topk_list)
print("-" * 10)
print("vals_per_row", vals_per_row)
print("inds_per_row", inds_per_row)

# Or zip (for loop but should be cheap)
index_list = zip(vals_per_row, inds_per_row)
print("-" * 10)
print("zipped results", list(index_list))

This gives the following output:


[torch.return_types.topk(
values=tensor([0.8148, 0.7443]),
indices=tensor([8, 4])), torch.return_types.topk(
values=tensor([0.7529, 0.7352, 0.6354]),
indices=tensor([8, 1, 9])), torch.return_types.topk(
values=tensor([0.8792]),
indices=tensor([7])), torch.return_types.topk(
values=tensor([0.9626, 0.8728]),
indices=tensor([6, 2])), torch.return_types.topk(
values=tensor([]),
indices=tensor([], dtype=torch.int64))]
----------
rows tensor([0, 0, 1, 1, 1, 2, 3, 3])
inds tensor([8, 4, 8, 1, 9, 7, 6, 2])
vals tensor([0.8148, 0.7443, 0.7529, 0.7352, 0.6354, 0.8792, 0.9626, 0.8728])
----------
vals_per_row (tensor([0.8148, 0.7443]), tensor([0.7529, 0.7352, 0.6354]), tensor([0.8792]), tensor([0.9626, 0.8728]), tensor([]))
inds_per_row (tensor([8, 4]), tensor([8, 1, 9]), tensor([7]), tensor([6, 2]), tensor([], dtype=torch.int64))
----------
zipped results [(tensor([0.8148, 0.7443]), tensor([8, 4])), (tensor([0.7529, 0.7352, 0.6354]), tensor([8, 1, 9])), (tensor([0.8792]), tensor([7])), (tensor([0.9626, 0.8728]), tensor([6, 2])), (tensor([]), tensor([], dtype=torch.int64))]
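
For reuse, the masking idea in this answer could be wrapped in a small function. The following is only a sketch: `topk_per_row` is a hypothetical helper name (not part of PyTorch), and it assumes `topk_list` is non-empty, has one entry per row, and that no k exceeds the number of columns.

```python
import torch


def topk_per_row(elements, topk_list):
    """Return (values_per_row, indices_per_row) with a different k per row.

    Hypothetical helper, not a PyTorch API. Assumes len(topk_list) equals
    elements.shape[0] and every k is between 0 and elements.shape[1].
    """
    ks = torch.as_tensor(topk_list, device=elements.device)
    max_k = int(ks.max())
    # One batched topk call using the largest k requested by any row.
    topk_vals, topk_inds = elements.topk(max_k, dim=-1)
    # mask[i, j] is True when j < ks[i], i.e. column j is within row i's top-k.
    # For topk_list = [2, 3, 1, 2, 0] the rows of the mask are
    # [T, T, F], [T, T, T], [T, F, F], [T, T, F], [F, F, F].
    mask = torch.arange(max_k, device=elements.device)[None, :] < ks[:, None]
    # Boolean indexing flattens in row-major order, so split() can cut the
    # flat result back into one chunk per row.
    sizes = ks.tolist()
    return topk_vals[mask].split(sizes), topk_inds[mask].split(sizes)


torch.manual_seed(0)
elements = torch.rand(5, 10)
topk_list = [2, 3, 1, 2, 0]
vals_per_row, inds_per_row = topk_per_row(elements, topk_list)
print(vals_per_row)
print(inds_per_row)
```

The trade-off is that every row pays for `max_k` elements, so this is most attractive when the k values are close to each other, as the answer notes.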


Posted 2022-08-25
跃然一笑

Contributed 1826 experience points, received 6+ likes

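Whether something is elegant is always debatable. Using a hard-coded range in the for loop can definitely be improved: you could at least use range(len(topk_list)), so that the code can be reused with different topk lists.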

You can improve it further by using enumerate:

for i, n in enumerate(topk_list): 
    index_list.append(elements[i].topk(n))

Or even:

index_list = [ elements[i].topk(n) for i, n in enumerate(topk_list) ]

But this is only syntactic sugar.
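
In the same spirit, a possible variant (a sketch, not something from the original answers) drops the index variable entirely by zipping the rows with their k values. It relies on the fact that iterating over a 2D tensor yields its rows; the length check is an added assumption, since `zip` would otherwise silently stop at the shorter input.

```python
import torch

elements = torch.rand(5, 10)
topk_list = [2, 3, 1, 2, 0]

# Fail early if the list of k values does not line up with the rows
# (zip would otherwise silently stop at the shorter of the two).
assert len(topk_list) == elements.shape[0]

# Iterating over a 2D tensor yields its rows, so no index variable is needed.
index_list = [row.topk(k) for row, k in zip(elements, topk_list)]

for result in index_list:
    print(result.values, result.indices)
```

This is still a Python-level loop over rows, so it is a readability change rather than a vectorization.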


Posted 2022-08-25