为了账号安全,请及时绑定邮箱和手机立即绑定

Numpy where() 使用随数组中项目位置而变化的条件

Numpy where() 使用随数组中项目位置而变化的条件

海绵宝宝撒 2023-06-27 14:20:54
我正在尝试使用 numpy 构建一个网格世界。网格为4*4,排列成正方形。第一个和最后一个方格(即 1 和 16)是终端方格。在每个时间步,您可以向任意方向移动一步:上、下、左或右。一旦您进入终端方块之一,就不可能再进行进一步的移动,游戏就会终止。第一列和最后一列是正方形的左边缘和右边缘,而第一行和最后一行代表顶部边缘和底部边缘。如果您位于边缘(例如左边的边缘)并尝试向左移动,您不会向左移动,而是停留在开始所在的方格中。类似地,如果您尝试穿过任何其他边,您也会停留在同一个方格中。虽然网格是一个正方形,但我将它实现为一个数组。states_r 计算右移后状态的位置。1 和 16 保持原样,因为它们是终止状态(请注意,代码使用基于零的计数,因此 1 和 16 在代码中分别是 0 和 15)。其余的方格都加一。states_r 的代码可以工作,但是右边缘上的那些方块,即 (4, 8, 12) 也应该保持在原来的位置,但 states_r 代码不会这样做。State_l 是我尝试包含正方形左边缘的边缘条件。逻辑是相同的,终端状态 (1, 16) 不应移动,左侧边缘的那些方块 (5, 9, 13) 也不应移动。我认为一般逻辑是正确的,但它产生了错误。states = np.arange(16) states_r = states[np.where((states + 1 <= 15) & (states != 0), states + 1, states)] states_l = states[np.where((max(1, (states // 4) * 4) <= states - 1) & (states != 15), states - 1, states)]第一个示例 states_r 有效,它处理终端状态,但不处理边缘条件。第二个例子是我尝试包含边缘条件,但是它给了我以下错误:“具有多个元素的数组的真值是不明确的。”有人可以解释一下如何修复我的代码吗?或者建议另一种解决方案,理想情况下我希望代码速度快(这样我就可以扩展它),所以我想尽可能避免 for 循环?
查看完整描述

2 回答

?
烙印99

TA贡献1829条经验 获得超13个赞

如果我理解正确的话,你需要数组来指示每个状态的下一个状态,具体取决于移动(右、左、上、下)。如果是这样,我猜你的退出执行state_r不正确。我建议切换到网格的 2D 表示,因为如果直接有 x 和 y (至少对我来说),您描述的很多事情会更容易、更直观地处理。


import numpy as np


n = 4

states = np.arange(n*n).reshape(n, n)

states_r, states_l, states_u, states_d = (states.copy(), states.copy(), 

                                          states.copy(), states.copy())

states_r[:, :n-1] = states[:, 1:]

states_l[:, 1:] = states[:, :n-1]

states_u[1:, :] = states[:n-1, :]

states_d[:n-1, :] = states[1:, :]


#        up             [[ 0,  1,  2,  3],

#  left state right      [ 0,  1,  2,  3],

#       down             [ 4,  5,  6,  7],

#                        [ 8,  9, 10, 11]]

#

#  [[ 0,  0,  1,  2],   [[ 0,  1,  2,  3],   [[ 1,  2,  3,  3],

#   [ 4,  4,  5,  6],    [ 4,  5,  6,  7],    [ 5,  6,  7,  7],

#   [ 8,  8,  9, 10],    [ 8,  9, 10, 11],    [ 9, 10, 11, 11],

#   [12, 12, 13, 14]]    [12, 13, 14, 15]]    [13, 14, 15, 15]]

#

#                       [[ 4,  5,  6,  7],

#                        [ 8,  9, 10, 11],

#                        [12, 13, 14, 15],

#                        [12, 13, 14, 15]]


如果你想排除终端状态,你可以这样做:


terminal_states = np.zeros((n, n), dtype=bool)

terminal_states[0, 0] = True

terminal_states[-1, -1] = True

states_r[terminal_states] = states[terminal_states]

states_l[terminal_states] = states[terminal_states]

states_u[terminal_states] = states[terminal_states]

states_d[terminal_states] = states[terminal_states]

如果您更喜欢一维方法:


import numpy as np


n = 4

states = np.arange(n*n)

valid_s = np.ones(n*n, dtype=bool)

valid_s[0] = False

valid_s[-1] = False


states_r = np.where(np.logical_and(valid_s, states % n < n-1), states+1, states)

states_l = np.where(np.logical_and(valid_s, states % n > 0),   states-1, states)

states_u = np.where(np.logical_and(valid_s, states > n-1),     states-n, states)

states_d = np.where(np.logical_and(valid_s, states < n**2-n),  states+n, states)


查看完整回答
反对 回复 2023-06-27
?
慕桂英546537

TA贡献1848条经验 获得超10个赞

另一种无需预分配数组的方法:


states = np.arange(16).reshape(4,4)


states_l = np.hstack((states[:,0][:,None],states[:,:-1],))

states_r = np.hstack((states[:,1:],states[:,-1][:,None]))

states_d = np.vstack((states[1:,:],states[-1,:]))

states_u = np.vstack((states[0,:],states[:-1,:]))

为了将它们全部变为一维,您始终可以使用flatten()/ravel()/reshape(-1)二维数组。


                  [[ 0  1  2  3]

                   [ 0  1  2  3]

                   [ 4  5  6  7]

                   [ 8  9 10 11]]


[[ 0  0  1  2]    [[ 0  1  2  3]    [[ 1  2  3  3]

 [ 4  4  5  6]     [ 4  5  6  7]     [ 5  6  7  7]

 [ 8  8  9 10]     [ 8  9 10 11]     [ 9 10 11 11]

 [12 12 13 14]]    [12 13 14 15]]    [13 14 15 15]]

                  

                  [[ 4  5  6  7]

                   [ 8  9 10 11]

                   [12 13 14 15]

                   [12 13 14 15]]

对于角落,你可以这样做:


states_u[-1,-1] = 15

states_l[-1,-1] = 15


查看完整回答
反对 回复 2023-06-27
  • 2 回答
  • 0 关注
  • 124 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信