Como obter o índice de um elemento máximo em uma matriz numpy ao longo de um eixo

118

Eu tenho uma matriz NumPy bidimensional. Eu sei como obter os valores máximos sobre os eixos:

>>> a = array([[1,2,3],[4,3,1]])
>>> amax(a,axis=0)
array([4, 3, 3])

Como posso obter os índices dos elementos máximos? Então, eu gostaria como saídaarray([1,1,0])

Peter Smit
fonte

Respostas:

141
>>> a.argmax(axis=0)

array([1, 1, 0])
eumiro
fonte
1
isso funciona bem para inteiros, mas o que posso fazer para valores flutuantes e os números entre 0 e 1
Priyom saha
100
>>> import numpy as np
>>> a = np.array([[1,2,3],[4,3,1]])
>>> i,j = np.unravel_index(a.argmax(), a.shape)
>>> a[i,j]
4
chama
fonte
11
Observe que essa resposta é enganosa. Ele calcula o índice do elemento máximo da matriz em todos os eixos, não ao longo de um determinado eixo, como o OP pergunta: está errado. Além disso, se houver mais de um máximo, ele recupera os índices apenas do primeiro máximo: isso deve ser destacado. Tente com a = np.array([[1,4,3],[4,3,1]])para ver se ele volta i,j==0,1, e negligencia a solução em i,j==1,0. Para os índices de todos os máximos, use i,j = where(a==a.max().
gg349
36

argmax()retornará apenas a primeira ocorrência de cada linha. http://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html

Se você precisar fazer isso para uma matriz em forma, funciona melhor do que unravel:

import numpy as np
a = np.array([[1,2,3], [4,3,1]])  # Can be of any shape
indices = np.where(a == a.max())

Você também pode alterar suas condições:

indices = np.where(a >= 1.5)

O texto acima fornece os resultados na forma que você solicitou. Como alternativa, você pode converter para uma lista de coordenadas x, y:

x_y_coords =  zip(indices[0], indices[1])
SevakPrime
fonte
2
Isso não funcionou para mim ... Você quer dizer indices = np.where(a==a.max())na linha 3?
atomh33ls
Você está certo, atomh33ls! Obrigado por perceber isso. Corrigi essa declaração para incluir o segundo sinal de igual para a condicional adequada.
SevakPrime
@SevakPrime, houve um segundo erro apontado por @ atomh33ls, em .max()vez de .argmax(). Edite a resposta
gg349
@ gg349, depende do que você deseja. argmax fornece-o ao longo de um eixo que parece ser a maneira que o OP deseja, tendo aprovado essa resposta por eumiro.
SevakPrime,
Vejo que a correção @ atomh33ls e proponho leva ao índice do (s) maior (is) elemento (s) da matriz, enquanto o OP estava perguntando sobre os maiores elementos ao longo de um determinado eixo. Observe, entretanto, que sua solução atual leva a x_y_coord = [(0, 2), (1, 1)]que NÃO corresponde à resposta @eumiro e está errada. Por exemplo, tente a = array([[7,8,9],[10,11,12]])ver se o seu código não tem nenhum acerto nesta entrada. Você também mencionou que funciona melhor que unravel, mas a solução postada por @blas responde ao problema do máximo absoluto, não apenas ao longo de um eixo.
gg349
3
v = alli.max()
index = alli.argmax()
x, y = index/8, index%8
ahmed
fonte