2016-11-01 17 views
0

У меня есть ndarray, и я хочу установить все не максимальные элементы в последнем измерении равными нулю.Numpy: заполнение не максимальных элементов ndarray нулями

a = np.array([[[1,8,3,4],[6,7,10,6],[11,12,15,4]], 
       [[4,2,3,4],[4,7,9,8],[41,14,15,3]], 
       [[4,22,3,4],[16,7,9,8],[41,12,15,43]] 
      ]) 
print(a.shape) 
(3,3,4) 

можно получить индексы максимальных элементов от np.argmax():

b = np.argmax(a, axis=2) 
b 
array([[1, 2, 2], 
     [0, 2, 0], 
     [1, 0, 3]]) 

Очевидно, что Ь имеет размерность 1 меньше, чем. Теперь я хочу получить новый 3-мерный массив, который имеет все нули, кроме тех, где указаны максимальные значения.

Я хочу, чтобы получить этот массив:

np.array([[[0,1,0,0],[0,0,1,0],[0,0,1,0]], 
      [[1,0,0,1],[0,0,1,0],[1,0,0,0]], 
      [[0,1,0,0],[1,0,0,0],[0,0,0,1]] 
     ]) 

Один из способов добиться этого, я попытался создать эти временные массивы

b = np.repeat(b[:,:,np.newaxis], 4, axis=2) 
t = np.repeat(np.arange(4).reshape(4,1), 9, axis=1).T.reshape(b.shape) 

z = np.zeros(shape=a.shape, dtype=int) 
z[t == b] = 1 
z 
array([[[0, 1, 0, 0], 
    [0, 0, 1, 0], 
    [0, 0, 1, 0]], 

    [[1, 0, 0, 0], 
    [0, 0, 1, 0], 
    [1, 0, 0, 0]], 

    [[0, 1, 0, 0], 
    [1, 0, 0, 0], 
    [0, 0, 0, 1]]]) 

Любая идея, как получить это более эффективный способ?

ответ

1

Вот один способ, который использует вещание:

In [108]: (a == a.max(axis=2, keepdims=True)).astype(int) 
Out[108]: 
array([[[0, 1, 0, 0], 
     [0, 0, 1, 0], 
     [0, 0, 1, 0]], 

     [[1, 0, 0, 1], 
     [0, 0, 1, 0], 
     [1, 0, 0, 0]], 

     [[0, 1, 0, 0], 
     [1, 0, 0, 0], 
     [0, 0, 0, 1]]]) 
+0

OMG, это здорово! Большое спасибо @Warren :) –