1 回答
TA贡献1775条经验 获得超11个赞
您需要修改搜索功能。
具体来说,看这一行:
best_states.append(imported_images[i])
如果要在返回的图像和文件名之间进行映射,则需要记录并返回该索引,i. 考虑添加一个best_states_index变量并返回两者,或者简单地替换imported_images[i]并i使用它来访问文件名和图像数据。
更明确地说:
def search(image):
hidden_states = [sess.run(hidden_state(X, mask, W, b),
feed_dict={X: im.reshape(1, pixels), mask:
np.random.binomial(1, 1-corruption_level, (1, pixels))})
for im in image_set]
query = sess.run(hidden_state(X, mask, W, b),
feed_dict={X: image.reshape(1,pixels), mask: np.random.binomial(1, 1-corruption_level, (1, pixels))})
starting_state = int(np.random.random()*len(hidden_states)) #choose random starting state
best_states = [imported_images[starting_state]]
best_states_index = [starting_state]
distance = euclidean_distance(query[0], hidden_states[starting_state][0]) #Calculate similarity between hidden states
for i in range(len(hidden_states)):
dist = euclidean_distance(query[0], hidden_states[i][0])
if dist <= distance:
distance = dist #as the method progresses, it gets better at identifying similiar images
best_states.append(imported_images[i])
best_states_index.append(i)
if len(best_states)>0:
return best_states, best_states_index
else:
return best_states[len(best_states)-101:], best_states_index[len(best_states)-101:]
添加回答
举报