1 回答
TA贡献1794条经验 获得超8个赞
这是一种方法：
import tensorflow as tf
import math
def solution_tf(foo, bar):
    """Return, for each point in ``foo``, the closest real-valued point in ``bar``.

    Args:
        foo: Floating-point tensor of query points, expected shape ``[N, D]``.
        bar: Complex tensor of candidates, expected shape ``[N, K, D]``;
            ``bar[i]`` holds the ``K`` candidates for ``foo[i]``. A candidate
            counts as real only if all of its ``D`` coordinates have an
            imaginary part exactly equal to zero.

    Returns:
        Tensor of shape ``[N, D]`` and dtype ``foo.dtype``. Row ``i`` is the
        real candidate in ``bar[i]`` with the smallest squared Euclidean
        distance to ``foo[i]``. Rows where ``bar[i]`` has no real candidate
        are filled with NaN.
    """
    foo = tf.convert_to_tensor(foo)
    bar = tf.convert_to_tensor(bar)
    # Use the tf.math.* spellings: they exist in TF >= 1.13 and TF 2.x, whereas
    # the bare tf.real / tf.imag / tf.squared_difference aliases were removed
    # in TF 2.
    bar_r = tf.cast(tf.math.real(bar), foo.dtype)
    bar_i = tf.math.imag(bar)
    # [N, K] mask: True where every coordinate of the candidate is real.
    is_real = tf.reduce_all(tf.equal(bar_i, 0), axis=-1)
    # [N, K] squared distances. foo is expanded to [N, 1, D] so that it
    # broadcasts against all K candidates.
    dist = tf.reduce_sum(
        tf.math.squared_difference(tf.expand_dims(foo, 1), bar_r), axis=-1)
    # Give complex candidates infinite distance so the minimum never picks them.
    inf = tf.fill(tf.shape(dist), tf.constant(math.inf, dist.dtype))
    dist = tf.where(is_real, dist, inf)
    # Index of the nearest real candidate in each row.
    idx = tf.argmin(dist, axis=1)
    rows = tf.range(tf.shape(foo, out_type=idx.dtype)[0])
    nearest = tf.gather_nd(bar_r, tf.stack([rows, idx], axis=1))
    # If a row has no real candidate, every distance is inf and argmin lands on
    # a complex candidate. Report NaN for that row instead of its real part.
    # The condition is broadcast to the full [N, D] shape explicitly because
    # tf.where broadcasts differently in TF 1 (a rank-1 condition matches the
    # first dimension) and in TF 2 (NumPy-style, aligned on trailing dims).
    has_real = tf.broadcast_to(
        tf.expand_dims(tf.reduce_any(is_real, axis=1), 1), tf.shape(nearest))
    nan = tf.fill(tf.shape(nearest), tf.constant(math.nan, nearest.dtype))
    return tf.where(has_real, nearest, nan)
# Demo: run the example in TF1-style graph mode. tf.compat.v1.Session is
# available in TF >= 1.13 and TF 2.x, whereas plain tf.Session was removed in
# TF 2. The demo is guarded so that importing this module does not build and
# run a graph.
if __name__ == "__main__":
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
        foo = tf.constant([
            [1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]], dtype=tf.float32)
        # bar[i] holds two candidate points for foo[i]. A candidate with a
        # nonzero imaginary part in any coordinate is not eligible.
        bar = tf.constant([
            [[2, 3, 4], [1 + 0.1j, 2 + 0.1j, 3 + 0.1j]],
            [[6, 5, 4], [4, 5, 7]],
            [[1j, 1j, 1j], [0, 0, 0]]], dtype=tf.complex64)
        sol_tf = solution_tf(foo, bar)
        print(sess.run(sol_tf))
        # Expected output:
        # [[2. 3. 4.]
        #  [4. 5. 7.]
        #  [0. 0. 0.]]
添加回答
举报