为了账号安全,请及时绑定邮箱和手机立即绑定

取一个张量的元素,它也在另一个张量中

取一个张量的元素,它也在另一个张量中

手掌心 2021-07-06 12:08:56
我有两个张量,我必须迭代第一个以只取另一个张量内的元素。只有一个元素t2也在里面t1。这里有一个例子t1 = tf.where(values > 0) # I get some indices example [6, 0], [3, 0]t2 = tf.where(values2 > 0) # I get [4, 0], [3, 0]t3 = .... # [3, 0]我尝试使用运算符来评估和迭代它们,.eval()并检查它们是否t2正在t1使用 operator in,但不起作用。TensorFlow 有没有可以做到这一点的函数?编辑for index in xrange(max_indices):    indices = tf.where(tf.equal(values, (index + 1))).eval() # indices: [[1 0]\n [4 0]\n [9 0]]    cent_indices = tf.where(centers > 0).eval() # cent_indices: [[6 0]\n [9 0]]    indices_list.append(indices)    for cent in cent_indices:        if cent in indices:           centers_list.append(cent)           break第一次迭代cent具有值[6 0]但它进入if条件。回答for index in xrange(max_indices):    indices = tf.where(tf.equal(values, (index + 1))).eval()    cent_indices = tf.where(centers > 0).eval()    indices_list.append(indices)    for cent in cent_indices:        # batch_item is an iterator from an outer loop        if values[batch_item, cent[0]].eval() == (index + 1):           centers_list.append(tf.constant(cent))           break该解决方案与我的任务有关,但如果您正在寻找一维张量中的解决方案,我建议您查看 tf.sets.set_intersection
查看完整描述

1 回答

?
炎炎设计

TA贡献1808条经验 获得超4个赞

那是你想要的吗?我只使用了这两个测试用例。


x = tf.constant([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 1]])

y = tf.constant([[1, 2, 3, 4, 3, 6], [1, 2, 3, 4, 5, 1]])

# x = tf.constant([[1, 2], [4, 5], [7, 7]])

# y = tf.constant([[7, 7], [3, 5]])


def match(xiterations, yiterations, yvalues, xvalues ):

    for i in range(xiterations):

        for j in range(yiterations):

            if (np.array_equal(yvalues[j], xvalues[i])):

                print( yvalues[j])


with tf.Session() as sess:

    xindex = tf.where( x > 4 )

    yindex = tf.where( y > 4 )


    xvalues = xindex.eval()

    yvalues = yindex.eval()


    xiterations =  tf.shape(xvalues)[0].eval()

    yiterations =  tf.shape(yvalues)[0].eval()


    print(tf.shape(xvalues)[0].eval())

    print(tf.shape(yvalues)[0].eval())


    if tf.shape(xvalues)[0].eval() >= tf.shape(yvalues)[0].eval():

        match( xiterations, yiterations, yvalues, xvalues)

    else:

        match( yiterations, xiterations, xvalues, yvalues)


查看完整回答
反对 回复 2021-07-21
  • 1 回答
  • 0 关注
  • 225 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信