为了账号安全,请及时绑定邮箱和手机立即绑定

numpy 2d:如何仅获取第二列中允许的值的第一列中最大元素的索引

numpy 2d:如何仅获取第二列中允许的值的第一列中最大元素的索引

慕雪6442864 2023-09-19 14:11:16
帮助找到解决问题的高性能方法:我在神经网络(answers_weight)之后得到了结果,答案类别(相同长度)以及当前请求允许的类别:answers_weight = np.asarray([0.9, 3.8, 3, 0.6, 0.7, 0.99]) # ~3kk itemsanswers_category = [1, 2, 1, 5, 3, 1] # same size as answers_weight: ~3kk itemscategories_allowed1 = [1, 5, 8]res = np.stack((answers_weight, answers_category), axis=1)我需要知道最大元素的索引(在answers_weight数组中),但跳过不允许的类别(2,3)。在final中,索引必须= 2(“3.0”,因为“3.8”必须被跳过,因为类别不允许)
查看完整描述

2 回答

?
天涯尽头无女友

TA贡献1831条经验 获得超9个赞

最简单的方法是使用 numpy 的 masked_arrays 根据 allowed_categories 来屏蔽权重,然后查找argmax:


np.ma.masked_where(~np.isin(answers_category,categories_allowed1),answers_weight).argmax()

#2

另一种使用掩码的方法(假设最大权重是唯一的):


mask = np.isin(answers_category, categories_allowed1)

np.argwhere(answers_weight==answers_weight[mask].max())[0,0]

#2


查看完整回答
反对 回复 2023-09-19
?
一只萌萌小番薯

TA贡献1795条经验 获得超7个赞

我也使用面膜解决了这个问题


inds = np.arange(res.shape[0])

# a mask is an array [False  True False False  True False]

mask = np.all(res[:,1][:,None] != categories_allowed1,axis=1)


allowed_inds = inds[mask]

# max_ind is not yet the real answer because the not allowed values are not taken into account

max_ind = np.argmax(res[:,0][mask])

real_ind = allowed_inds[max_ind]


查看完整回答
反对 回复 2023-09-19
  • 2 回答
  • 0 关注
  • 86 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信