为了账号安全,请及时绑定邮箱和手机立即绑定

在 Keras 中,如何使用 dot() 计算张量与常量矩阵的每一行之间的余弦接近度?

在 Keras 中,如何使用 dot() 计算张量与常量矩阵的每一行之间的余弦接近度?

炎炎设计 2021-10-12 10:53:26
我有一个张量jdes,其是(?, 100)和常数的矩阵jt_six,其具有的形状(6,100)。我试图得到jdes和 的每一行的余弦接近度jt_six的结果,结果应该有 shape (?, 6)。我看到该dot()层能够计算余弦接近度设置,normalize=True但是使用我拥有的代码,我得到的结果形状(6,1)中没有批量大小。任何人都可以帮助我吗?def dot_similarity(jdes):    jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)    jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)    return dot([jt_six, jdes], axes=-1, normalize=True) # (6, 1), need (?, 6)result = Lambda(dot_similarity)(jdes)
查看完整描述

1 回答

?
POPMUISE

TA贡献1765条经验 获得超5个赞

可以K.dot()直接使用。因为您已经使用了K.l2_normalize,所以矩阵乘法的结果是余弦接近度。


from keras.models import Model

import keras.backend as K

from keras.layers import Lambda,Input

import numpy as np


N = 100

def dot_similarity(jdes):

    jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)

    # define it myself

    jt_six = K.constant(np.random.uniform(0, 1, size=(6, N)))

    jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)

    return K.dot(jdes,K.transpose(jt_six))


jdes = Input(shape=(N,))

result = Lambda(dot_similarity)(jdes)

model = Model(jdes,result)

print(model.summary())


_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

input_1 (InputLayer)         (None, 100)               0         

_________________________________________________________________

lambda_1 (Lambda)            (None, 6)                 0         

=================================================================

Total params: 0

Trainable params: 0

Non-trainable params: 0

_________________________________________________________________


查看完整回答
反对 回复 2021-10-12
  • 1 回答
  • 0 关注
  • 337 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信