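You can use K.dot() directly. Because you have already applied K.l2_normalize to both inputs, the result of the matrix multiplication is the cosine similarity between each input row and each of the six reference rows.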
from keras.models import Model
from keras.layers import Lambda, Input
import keras.backend as K
import numpy as np

N = 100

def dot_similarity(jdes):
    jdes = K.l2_normalize(jdes, axis=-1)  # (?, 100)
    # define it myself: a fixed (6, N) reference matrix, not trained
    jt_six = K.constant(np.random.uniform(0, 1, size=(6, N)))
    jt_six = K.l2_normalize(jt_six, axis=-1)  # (6, 100)
    # (?, 100) x (100, 6) -> (?, 6): one cosine similarity per reference row
    return K.dot(jdes, K.transpose(jt_six))

jdes = Input(shape=(N,))
result = Lambda(dot_similarity)(jdes)
model = Model(jdes, result)
model.summary()  # summary() prints itself and returns None
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 100)               0
_________________________________________________________________
lambda_1 (Lambda)            (None, 6)                 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
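To check the claim that a dot product of L2-normalized rows equals cosine similarity, here is a minimal NumPy sketch outside of Keras. The batch size of 4, the random seed, and the helper `l2_normalize` are illustrative assumptions; the helper only roughly mirrors K.l2_normalize (its epsilon handling differs slightly). It compares the normalized matrix product against an explicit (a · b) / (|a| |b|) computation for every pair of rows.

```python
import numpy as np

# Hypothetical sizes: a batch of 4 inputs, N = 100 features, 6 reference rows
rng = np.random.default_rng(0)
N = 100
jdes = rng.uniform(0, 1, size=(4, N))
jt_six = rng.uniform(0, 1, size=(6, N))

def l2_normalize(x, axis=-1, eps=1e-12):
    # Hypothetical helper: divide each row by its L2 norm (eps guards against division by zero)
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

# Same computation as K.dot(jdes, K.transpose(jt_six)) on normalized inputs
sim_dot = l2_normalize(jdes) @ l2_normalize(jt_six).T  # shape (4, 6)

# Explicit cosine similarity for every (input row, reference row) pair
sim_cos = np.array([[a @ b / (np.linalg.norm(a) * np.linalg.norm(b)) for b in jt_six]
                    for a in jdes])

print(sim_dot.shape)                  # (4, 6)
print(np.allclose(sim_dot, sim_cos))  # True
```

The two results agree up to floating-point error, which is why no separate cosine layer is needed once both operands are normalized.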