2 回答
TA贡献1802条经验 获得超4个赞
用于torch.cdist
L2 范数 - 欧氏距离
res = torch.cdist(mat, mat2.permute(1,0), p=2)
在这里,我曾经将frompermute
的 dim 交换为mat2
7,20
20,7
TA贡献2065条经验 获得超13个赞
首先,PyTorch 中的矩阵乘法有一个内置运算符:@。因此,要将 mat 和 mat2 相乘,您只需执行以下操作:
mat @ mat2
(假设尺寸一致,应该可以工作)。
现在,要计算您似乎在第二个块中计算的平方差之和(SSD 或 L2 范数),您可以做一个简单的技巧。由于 L2 范数的平方||m_i - v||^2(其中m_i是矩阵的第 i 行M,v是向量)等于点积<m_i - v, m_i-v>- 根据您获得的点积的线性度:因此您可以通过以下方式<m_i,m_i> - 2<m_i,v> + <v,v>计算向量中每一行的 SSD:计算一次每行的 L2 范数平方、一次每行与向量之间的点积以及一次向量的 L2 范数。这可以在 中完成。然而,对于 2 个矩阵之间的 SSD,您仍然会得到MvO(n^2)O(n^3)。不过,可以通过向量化操作而不是使用循环来进行改进。这是 2 个矩阵的简单实现:
def mat_mat_l2_mult(mat,mat2):
rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True)**2).repeat(1,mat2.shape[1])
cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True)**2).repeat(mat.shape[0], 1)
rows_cols_dot_product = mat @ mat2
ssd = rows_norm -2*rows_cols_dot_product + cols_norm
return ssd.sqrt()
mat = torch.randn([20, 7])
mat2 = torch.randn([7,20])
print(mat_mat_l2_mult(mat, mat2))
所得矩阵的每个单元格将具有中每行和每列之间i,j差异的 L2 范数。imatjmat2
添加回答
举报