为了账号安全,请及时绑定邮箱和手机立即绑定

tensorflow in_top_k 的输入应该是 1 级还是 2 级?

tensorflow in_top_k 的输入应该是 1 级还是 2 级?

叮当猫咪 2021-06-29 13:57:27
我尝试尝试使用 in_top_k 函数来查看该函数到底在做什么。但我发现了一些非常令人困惑的行为。首先我编码如下import numpy as npimport tensorflow as tftarget = tf.constant(np.random.randint(2, size=30).reshape(30,-1), dtype=tf.int32, name="target")pred = tf.constant(np.random.rand(30,1), dtype=tf.float32, name="pred")result = tf.nn.in_top_k(pred, target, 1)init = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init)    targetVal = target.eval()    predVal = pred.eval()    resultVal = result.eval()然后它生成以下错误:ValueError: Shape must be rank 1 but is rank 2 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [30,1], [30,1], [].然后我将代码更改为import numpy as npimport tensorflow as tftarget = tf.constant(np.random.randint(2, size=30), dtype=tf.int32, name="target")pred = tf.constant(np.random.rand(30,1).reshape(-1), dtype=tf.float32, name="pred")result = tf.nn.in_top_k(pred, target, 1)init = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init)    targetVal = target.eval()    predVal = pred.eval()    resultVal = result.eval()但现在错误变成了ValueError: Shape must be rank 2 but is rank 1 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [30], [30], [].那么输入应该是 1 级还是 2 级?
查看完整描述

1 回答

  • 1 回答
  • 0 关注
  • 135 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信