为了账号安全,请及时绑定邮箱和手机立即绑定

在不使用 get_shape、size 和 shape 函数的情况下提取张量的第一维?

在不使用 get_shape、size 和 shape 函数的情况下提取张量的第一维?

收到一只叮咚 2021-06-04 13:50:08
我用 Keras 写了一个损失函数。它有两个参数,y_true和y_pred。我的第一行代码是:batch = y_pred.get_shape()[0]. 然后在我的batch变量中我有第一个维度y_pred,所以我循环range(batch)并写下我写的内容。那没关系。问题是当我编译所有内容时,我收到一条错误消息,告诉我批处理不是整数,而是张量。然后,作为 Tensorflow 的初学者,我开始思考如何从 中获取一个整数batch,它应该是一个整数,但是一个张量。我试图这样做,sess.run(batch)但这根本没有帮助。所以,我的问题是如何从表示整数变量的张量中获取整数。我想使用一些真正给我一个整数而不是张量的函数。请帮忙。这是我的代码:def custom_loss(y_true, y_pred):    batch = y_pred.get_shape()[0]    list_ones = returnListOnes(batch)    tensor_ones = tf.convert_to_tensor(list_ones)    loss = 0    for i in range(batch):      for j in range(S):        for k in range(S):            lista = returnListOnesIndex(batch, [j,k,0])            lista_bx = returnListOnesIndex(batch, [j,k,1])            lista_by = returnListOnesIndex(batch, [j,k,2])            lista_bw = returnListOnesIndex(batch, [j,k,3])            lista_bh = returnListOnesIndex(batch, [j,k,4])            lista_to_tensor = tf.convert_to_tensor(lista)            lista_bx_to_tensor = tf.convert_to_tensor(lista_bx)            lista_by_to_tensor = tf.convert_to_tensor(lista_by)            lista_bw_to_tensor = tf.convert_to_tensor(lista_bw)            lista_bh_to_tensor = tf.convert_to_tensor(lista_bh)            element = tf.reduce_sum(tf.multiply(lista_to_tensor,y_pred))            element_true = tf.reduce_sum(tf.multiply(lista_to_tensor, y_true))            element_bx = tf.reduce_sum(tf.multiply(lista_bx_to_tensor, y_pred))            element_bx_true = tf.reduce_sum(tf.multiply(lista_bx_to_tensor, y_true))            element_by = tf.reduce_sum(tf.multiply(lista_by_to_tensor, y_pred))            element_by_true = tf.reduce_sum(tf.multiply(lista_by_to_tensor, y_true))            element_bw = tf.reduce_sum(tf.multiply(lista_bw_to_tensor, y_pred))            element_bw_true = tf.reduce_sum(tf.multiply(lista_bw_to_tensor, y_true))正如你所看到的,我想要batch变量是int这样我可以循环并做一些事情。我也用过size,shape它也行不通。
查看完整描述

1 回答

?
米琪卡哇伊

TA贡献1998条经验 获得超6个赞

矢量化代码肯定会更高效,我强烈建议您尝试以不需要循环的方式编写代码。

但是,如果您无法这样做,则可以求助于tf.map_fn.

从您的代码中,我看不出在i您的循环中使用了什么地方。我猜这是一个错误(可能batch应该i在循环内)或我自己的失明 - 否则你可以将结果乘以批量大小......


查看完整回答
反对 回复 2021-06-09
  • 1 回答
  • 0 关注
  • 176 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信