Tępe pierwsze wystąpienie wartości większej niż istniejąca wartość


144

Mam tablicę 1D w numpy i chcę znaleźć pozycję indeksu, w którym wartość przekracza wartość w tablicy numpy.

Na przykład

aa = range(-10,10)

Znajdź pozycję, w aaktórej wartość 5zostanie przekroczona.


2
Powinno być jasne, czy nie może być rozwiązania (ponieważ np. Odpowiedź argmax nie zadziała w tym przypadku (max (0,0,0,0) = 0), jak skomentował
ambrus

Odpowiedzi:


199

To jest trochę szybsze (i wygląda ładniej)

np.argmax(aa>5)

Ponieważ argmaxzatrzyma się na pierwszym True(„W przypadku wielokrotnego wystąpienia wartości maksymalnych zwracane są indeksy odpowiadające pierwszemu wystąpieniu.”) I nie zapisuje kolejnej listy.

In [2]: N = 10000

In [3]: aa = np.arange(-N,N)

In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop

In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop

In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop

103
Tylko słowo ostrzeżenia: jeśli w tablicy wejściowej nie ma wartości True, np.argmax z radością zwróci 0 (co nie jest tym, czego chcesz w tym przypadku).
ambrus,

8
Wyniki są prawidłowe, ale wyjaśnienie wydaje mi się nieco podejrzane. argmaxnie wydaje się zatrzymywać na pierwszym True. (Można to sprawdzić, tworząc tablice boolowskie z pojedynczym Truew różnych pozycjach). Szybkość jest prawdopodobnie wyjaśniona przez fakt, że argmaxnie ma potrzeby tworzenia listy wyników.
DrV

1
Myślę, że masz rację, @DrV. Moje wyjaśnienie miało dotyczyć tego, dlaczego daje prawidłowy wynik pomimo pierwotnego zamiaru, a nie szuka maksimum, a nie dlaczego jest szybszy, ponieważ nie mogę twierdzić, że rozumiem wewnętrzne szczegóły argmax.
askewchan

1
@George, obawiam się, że nie wiem dokładnie dlaczego. Mogę tylko powiedzieć, że jest szybszy w konkretnym przykładzie, który pokazałem, więc nie uważałbym tego ogólnie za szybszy bez (i) wiedzy, dlaczego tak jest (patrz komentarz @ DrV) lub (ii) testowania większej liczby przypadków (np. Czy aajest posortowany, jak w odpowiedzi @ Michael).
askewchan

3
@DrV, właśnie uruchomiłem argmaxtablice Boolean z 10 milionami elementów z pojedynczym Truew różnych pozycjach przy użyciu NumPy 1.11.2 i pozycji spraw True. Więc 1.11.2 argmaxwydaje się „zwierać” na tablicach boolowskich.
Ulrich Stern

96

biorąc pod uwagę posortowaną zawartość tablicy, istnieje jeszcze szybsza metoda: wyszukiwanie posortowane .

import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]

# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop

19
To naprawdę najlepsza odpowiedź, zakładając, że tablica jest posortowana (co w rzeczywistości nie jest określone w pytaniu). Możesz uniknąć niezręczności +1dziękinp.searchsorted(..., side='right')
askewchan

3
Myślę, że sideargument ma znaczenie tylko wtedy, gdy w posortowanej tablicy występują powtarzające się wartości. Nie zmienia znaczenia zwracanego indeksu, który jest zawsze indeksem, do którego można wstawić wartość zapytania, przesuwając wszystkie kolejne wpisy w prawo i zachowując posortowaną tablicę.
Gus

@Gus, sidedziała, gdy ta sama wartość znajduje się zarówno w posortowanej, jak i wstawionej tablicy, niezależnie od powtarzających się wartości w obu. Powtarzające się wartości w posortowanej tablicy tylko wyolbrzymiają efekt (różnica między stronami to liczba razy, gdy wstawiana wartość pojawia się w posortowanej tablicy). side nie zmienia znaczenia zwracanego indeksu, chociaż nie zmienia wynikowej tablicy po wstawieniu wartości do posortowanej tablicy w tych indeksach. Subtelne, ale ważne rozróżnienie; w rzeczywistości ta odpowiedź daje zły indeks, jeśli go N/2nie ma aa.
askewchan

Jak wskazano w powyższym komentarzu, ta odpowiedź jest oddzielona o jeden, jeśli N/2nie ma aa. Prawidłowa forma to np.searchsorted(aa, N/2, side='right')(bez +1). W przeciwnym razie obie formularze mają ten sam indeks. Rozważ przypadek testowy Nbycia nieparzystym (i N/2.0wymuszenia float, jeśli używasz Pythona 2).
askewchan

21

To też mnie zainteresowało i porównałem wszystkie sugerowane odpowiedzi z perfplotem . (Zastrzeżenie: jestem autorem perfplot.)

Jeśli wiesz, że przeglądana tablica jest już posortowana , to

numpy.searchsorted(a, alpha)

jest dla Ciebie. Jest to operacja działająca w czasie stałym, tj. Prędkość nie zależy od rozmiaru tablicy. Nie możesz być szybszy niż to.

Jeśli nie wiesz nic o swojej tablicy, nie pomylisz się

numpy.argmax(a > alpha)

Już posortowane:

wprowadź opis obrazu tutaj

Nieposortowany:

wprowadź opis obrazu tutaj

Kod do odtworzenia fabuły:

import numpy
import perfplot


alpha = 0.5

def argmax(data):
    return numpy.argmax(data > alpha)

def where(data):
    return numpy.where(data > alpha)[0][0]

def nonzero(data):
    return numpy.nonzero(data > alpha)[0][0]

def searchsorted(data):
    return numpy.searchsorted(data, alpha)

out = perfplot.show(
    # setup=numpy.random.rand,
    setup=lambda n: numpy.sort(numpy.random.rand(n)),
    kernels=[
        argmax, where,
        nonzero,
        searchsorted
        ],
    n_range=[2**k for k in range(2, 20)],
    logx=True,
    logy=True,
    xlabel='len(array)'
    )

4
np.searchsortednie jest stały. Właściwie to O(log(n)). Ale twój przypadek testowy faktycznie porównuje najlepszy przypadek searchsorted(którym jest O(1)).
MSeifert

@MSeifert Jakiego rodzaju tablicy wejściowej / alfa potrzebujesz, aby zobaczyć O (log (n))?
Nico Schlömer

1
Uzyskanie pozycji w indeksie sqrt (length) doprowadziło do bardzo złych wyników. Napisałem również tutaj odpowiedź , w tym ten test.
MSeifert

Wątpię searchsorted(lub jakikolwiek algorytm) może pokonać O(log(n))binarne wyszukiwanie posortowanych, równomiernie rozłożonych danych. EDYCJA: searchsorted to wyszukiwanie binarne.
Mateen Ulhaq

16
In [34]: a=np.arange(-10,10)

In [35]: a
Out[35]:
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

In [36]: np.where(a>5)
Out[36]: (array([16, 17, 18, 19]),)

In [37]: np.where(a>5)[0][0]
Out[37]: 16

8

Tablice, które mają stały krok między elementami

W przypadku rangetablicy lub innej liniowo rosnącej tablicy możesz po prostu obliczyć indeks programowo, bez potrzeby iteracji po tablicy:

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('no value greater than {}'.format(val))
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    # For linearly decreasing arrays or constant arrays we only need to check
    # the first element, because if that does not satisfy the condition
    # no other element will.
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

Prawdopodobnie można by to trochę poprawić. Upewniłem się, że działa poprawnie dla kilku przykładowych tablic i wartości, ale to nie znaczy, że nie może tam być błędów, zwłaszcza biorąc pod uwagę, że używa pływaków ...

>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16]  # double check
6

>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15

Biorąc pod uwagę, że może obliczyć pozycję bez żadnej iteracji, będzie to stały czas ( O(1)) i prawdopodobnie może pokonać wszystkie inne wymienione podejścia. Jednak wymaga stałego kroku w tablicy, w przeciwnym razie da błędne wyniki.

Ogólne rozwiązanie przy użyciu numba

Bardziej ogólnym podejściem byłoby użycie funkcji numba:

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

To zadziała dla dowolnej tablicy, ale musi iterować po tablicy, więc w przeciętnym przypadku będzie to O(n):

>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16

Reper

Mimo że Nico Schlömer przedstawił już pewne wzorce, pomyślałem, że przydatne może być uwzględnienie moich nowych rozwiązań i przetestowanie pod kątem różnych „wartości”.

Konfiguracja testu:

import numpy as np
import math
import numba as nb

def first_index_using_argmax(val, arr):
    return np.argmax(arr > val)

def first_index_using_where(val, arr):
    return np.where(arr > val)[0][0]

def first_index_using_nonzero(val, arr):
    return np.nonzero(arr > val)[0][0]

def first_index_using_searchsorted(val, arr):
    return np.searchsorted(arr, val) + 1

def first_index_using_min(val, arr):
    return np.min(np.where(arr > val))

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('empty array')
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

funcs = [
    first_index_using_argmax, 
    first_index_using_min, 
    first_index_using_nonzero,
    first_index_calculate_range_like, 
    first_index_numba, 
    first_index_using_searchsorted, 
    first_index_using_where
]

from simple_benchmark import benchmark, MultiArgument

a wykresy zostały wygenerowane przy użyciu:

%matplotlib notebook
b.plot()

pozycja jest na początku

b = benchmark(
    funcs,
    {2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

wprowadź opis obrazu tutaj

Najlepiej działa funkcja numba, po której następuje funkcja obliczeniowa i funkcja posortowana z wyszukiwaniem. Inne rozwiązania działają znacznie gorzej.

pozycja jest na końcu

b = benchmark(
    funcs,
    {2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

wprowadź opis obrazu tutaj

W przypadku małych tablic funkcja numba działa zadziwiająco szybko, jednak w przypadku większych tablic jest lepsza od funkcji obliczającej i funkcji sortowania.

pozycja jest na sqrt (len)

b = benchmark(
    funcs,
    {2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

wprowadź opis obrazu tutaj

To jest bardziej interesujące. Ponownie numba i funkcja obliczająca działają świetnie, jednak w rzeczywistości powoduje to najgorszy przypadek sortowania wyszukiwania, który naprawdę nie działa dobrze w tym przypadku.

Porównanie funkcji, gdy żadna wartość nie spełnia warunku

Innym interesującym punktem jest zachowanie tych funkcji, jeśli nie ma wartości, której indeks powinien zostać zwrócony:

arr = np.ones(100)
value = 2

for func in funcs:
    print(func.__name__)
    try:
        print('-->', func(value, arr))
    except Exception as e:
        print('-->', e)

Z tym wynikiem:

first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0

Searchsorted, argmax i numba po prostu zwracają nieprawidłową wartość. Jednak searchsortedi numbazwróć indeks, który nie jest prawidłowym indeksem dla tablicy.

Funkcje where, min, nonzeroa calculaterzut wyjątek. Jednak tylko wyjątek calculatefaktycznie mówi coś pomocnego.

Oznacza to, że w rzeczywistości należy zawrzeć te wywołania w odpowiedniej funkcji opakowującej, która wyłapuje wyjątki lub nieprawidłowe wartości zwracane i odpowiednio je obsługuje, przynajmniej jeśli nie jesteś pewien, czy wartość może znajdować się w tablicy.


Uwaga: Obliczanie i searchsortedopcje działają tylko w specjalnych warunkach. Funkcja „oblicz” wymaga stałego kroku, a posortowane wyszukiwanie wymaga posortowania tablicy. Mogą więc być przydatne w odpowiednich okolicznościach, ale nie są ogólnymi rozwiązaniami tego problemu. W przypadku, gdy mamy do czynienia z posortowanych list Pythona warto spojrzeć na przepoławiać modułu zamiast korzystania Numpys searchsorted.


3

Chciałbym zaproponować

np.min(np.append(np.where(aa>5)[0],np.inf))

Zwróci to najmniejszy indeks, w którym warunek jest spełniony, podczas gdy wherezwróci nieskończoność, jeśli warunek nigdy nie zostanie spełniony (i zwróci pustą tablicę).


1

Poszedłbym z

i = np.min(np.where(V >= x))

gdzie Vjest wektor (tablica 1d), xjest wartością i ijest wynikiem indeksu.

Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.