为了账号安全,请及时绑定邮箱和手机立即绑定

Numpy 数组到张量

Numpy 数组到张量

婷婷同学_ 2021-06-18 18:44:35
我正在将 numpy 创建的代码更改为 tensorflow 代码。但是,tensorflow 不支持指定每个元素,(例如 x [i] = 7),布尔值(例如.var [x <0.25] = -1)与可能的 numpy 数组很难。如何将以下代码更改为张量?x=np.random.rand((500*300))var=np.zeros((500*300), dtype=np.uint16)var[x<.25] = -1var[x>.75] = 1S=var.reshape((500, 300))请帮我。注意:我尝试了这一步。x=tf.random_uniform((500*300), minval=0, maxval=1, dtype=tf.float32)var=tf.zeros((500*300), int16)var[x<.25] = -1  # How is the change???????var[x>.75] = 1   # How is the change???????S=var.reshape((500, 300))
查看完整描述

2 回答

?
慕雪6442864

TA贡献1812条经验 获得超5个赞

按照评论中的建议使用tf.where。我在下面提供了一个示例代码,并在必要时进行了评论。


x = tf.random_uniform(shape=[5, 3], minval=0, maxval=1, dtype=tf.float32)


#same shape as x and only contains -1

c1 = tf.multiply(tf.ones(x.shape, tf.int32), -1)

#same shape as x and only contains 1

c2 = tf.multiply(tf.ones(x.shape, tf.int32), 1)


var = tf.zeros([5, 3], tf.int32)


#assign 1 element wise if x< 0.25 else 0

r1 = tf.where(tf.less(x, 0.25), c1, var)

#assign -1 element wise if x> 0.75 else 0

r2 = tf.where(tf.greater(x, 0.75), c2, var)


r = tf.add(r1, r2)


with tf.Session() as sess:

    _x, _r = sess.run([x, r])


    print(_x)

    print(_r)

示例结果


[[0.6438687  0.79183984 0.40236235]

 [0.7848805  0.0117377  0.6858672 ]

 [0.6067281  0.5176437  0.9839716 ]

 [0.15617108 0.28574145 0.31405795]

 [0.28515983 0.6034068  0.9314337 ]]


[[ 0  1  0]

 [ 1 -1  0]

 [ 0  0  1]

 [-1  0  0]

 [ 0  0  1]]

希望这可以帮助。


查看完整回答
反对 回复 2021-06-22
  • 2 回答
  • 0 关注
  • 106 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信