2 回答
TA贡献1831条经验 获得超9个赞
最简单的方法是使用 numpy 的 masked_arrays 根据 allowed_categories 来屏蔽权重,然后查找argmax:
np.ma.masked_where(~np.isin(answers_category,categories_allowed1),answers_weight).argmax()
#2
另一种使用掩码的方法(假设最大权重是唯一的):
mask = np.isin(answers_category, categories_allowed1)
np.argwhere(answers_weight==answers_weight[mask].max())[0,0]
#2
TA贡献1795条经验 获得超7个赞
我也使用面膜解决了这个问题
inds = np.arange(res.shape[0])
# a mask is an array [False True False False True False]
mask = np.all(res[:,1][:,None] != categories_allowed1,axis=1)
allowed_inds = inds[mask]
# max_ind is not yet the real answer because the not allowed values are not taken into account
max_ind = np.argmax(res[:,0][mask])
real_ind = allowed_inds[max_ind]
添加回答
举报