2 回答
TA贡献1812条经验 获得超5个赞
按照评论中的建议使用tf.where。我在下面提供了一个示例代码,并在必要时进行了评论。
x = tf.random_uniform(shape=[5, 3], minval=0, maxval=1, dtype=tf.float32)
#same shape as x and only contains -1
c1 = tf.multiply(tf.ones(x.shape, tf.int32), -1)
#same shape as x and only contains 1
c2 = tf.multiply(tf.ones(x.shape, tf.int32), 1)
var = tf.zeros([5, 3], tf.int32)
#assign 1 element wise if x< 0.25 else 0
r1 = tf.where(tf.less(x, 0.25), c1, var)
#assign -1 element wise if x> 0.75 else 0
r2 = tf.where(tf.greater(x, 0.75), c2, var)
r = tf.add(r1, r2)
with tf.Session() as sess:
_x, _r = sess.run([x, r])
print(_x)
print(_r)
示例结果
[[0.6438687 0.79183984 0.40236235]
[0.7848805 0.0117377 0.6858672 ]
[0.6067281 0.5176437 0.9839716 ]
[0.15617108 0.28574145 0.31405795]
[0.28515983 0.6034068 0.9314337 ]]
[[ 0 1 0]
[ 1 -1 0]
[ 0 0 1]
[-1 0 0]
[ 0 0 1]]
希望这可以帮助。
添加回答
举报