为了账号安全,请及时绑定邮箱和手机立即绑定

tensorflow找到到真实点的最小距离

tensorflow找到到真实点的最小距离

繁花不似锦 2022-05-24 16:17:15
我有一个 Bx3 张量,fooB= 批量大小的 3D 点。通过某种幻想,我得到了另一个张量,bar形状为 Bx6x3,其中每个 B 6x3 矩阵对应于foo. 该 6x3 矩阵由 6 个复值 3D 点组成。我想做的是,对于我的每个 B 点,从6 in 中找到最接近对应点 in的实值点,最终得到一个新的 Bx3 ,其中包含与 in点的最近点。barfoomin_barbarfoo在numpy中,我可以使用屏蔽数组来完成这一壮举:foo = np.array([    [1,2,3],    [4,5,6],    [7,8,9]])# here bar is only Bx2x3 for simplicity, but the solution generalizesbar = np.array([    [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],    [[6,5,4],[4,5,7]],    [[1j,1j,1j],[0,0,0]],])#mask complex elements of barbar_with_masked_imag = np.ma.array(bar)candidates = bar_with_masked_imag.imag == 0bar_with_masked_imag.mask = ~candidatesdists = np.sum(bar_with_masked_imag**2, axis=1)mindists = np.argmin(dists, axis=1)foo_indices = np.arange(foo.shape[0])min_bar = np.array(    bar_with_masked_imag[foo_indices,mindists,:],     dtype=float)print(min_bar)#[[2. 3. 4.]# [4. 5. 7.]# [0. 0. 0.]]但是,tensorflow 没有掩码数组等。我如何将其翻译成张量流?
查看完整描述

1 回答

?
幕布斯7119047

TA贡献1794条经验 获得超8个赞

这是一种方法:


import tensorflow as tf

import math


def solution_tf(foo, bar):

    foo = tf.convert_to_tensor(foo)

    bar = tf.convert_to_tensor(bar)

    # Get real and imaginary parts

    bar_r = tf.cast(tf.real(bar), foo.dtype)

    bar_i = tf.imag(bar)

    # Mask of all real-valued points

    m = tf.reduce_all(tf.equal(bar_i, 0), axis=-1)

    # Distance to every corresponding point

    d = tf.reduce_sum(tf.squared_difference(tf.expand_dims(foo, 1), bar_r), axis=-1)

    # Replace distances of complex points with infinity

    d2 = tf.where(m, d, tf.fill(tf.shape(d), tf.constant(math.inf, d.dtype)))

    # Find smallest distances

    idx = tf.argmin(d2, axis=1)

    # Get points with smallest distances

    b = tf.range(tf.shape(foo, out_type=idx.dtype)[0])

    return tf.gather_nd(bar_r, tf.stack([b, idx], axis=1))


# Test

with tf.Graph().as_default(), tf.Session() as sess:

    foo = tf.constant([

        [1,2,3],

        [4,5,6],

        [7,8,9]], dtype=tf.float32)

    bar = tf.constant([

        [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],

        [[6,5,4],[4,5,7]],

        [[1j,1j,1j],[0,0,0]]], dtype=tf.complex64)

    sol_tf = solution_tf(foo, bar)

    print(sess.run(sol_tf))

    # [[2. 3. 4.]

    #  [4. 5. 7.]

    #  [0. 0. 0.]]


查看完整回答
反对 回复 2022-05-24
  • 1 回答
  • 0 关注
  • 88 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信