Jak działa parametr class_weight w scikit-learn?


116

Mam wiele problemów ze zrozumieniem, jak działa class_weightparametr regresji logistycznej scikit-learn.

Sytuacja

Chcę użyć regresji logistycznej, aby przeprowadzić klasyfikację binarną na bardzo niezrównoważonym zestawie danych. Klasy są oznaczone jako 0 (negatywne) i 1 (pozytywne), a obserwowane dane są w stosunku około 19: 1, przy czym większość próbek daje wynik negatywny.

Pierwsza próba: ręczne przygotowanie danych treningowych

Dane, które miałem, podzieliłem na rozłączne zbiory do treningu i testów (około 80/20). Następnie ręcznie pobierałem losowe próbki danych szkoleniowych, aby uzyskać dane szkoleniowe w proporcjach innych niż 19: 1; od 2: 1 do 16: 1.

Następnie wytrenowałem regresję logistyczną na tych różnych podzbiorach danych szkoleniowych i wykreśliłem pamięć (= TP / (TP + FN)) jako funkcję różnych proporcji treningowych. Oczywiście przywołanie obliczono na rozłącznych próbkach TEST, które miały obserwowane proporcje 19: 1. Uwaga, chociaż trenowałem różne modele na różnych danych uczących, obliczyłem przypomnienia dla nich wszystkich na tych samych (rozłącznych) danych testowych.

Wyniki były zgodne z oczekiwaniami: przywołanie było około 60% przy proporcjach treningu 2: 1 i spadło dość szybko, gdy osiągnęło 16: 1. Było kilka proporcji 2: 1 -> 6: 1, gdzie zapamiętanie było przyzwoicie powyżej 5%.

Druga próba: wyszukiwanie sieci

Następnie chciałem przetestować różne parametry regularyzacji, więc użyłem GridSearchCV i utworzyłem siatkę kilku wartości Cparametru oraz class_weightparametru. Aby przetłumaczyć moje n: m proporcje negatywnych: pozytywnych próbek szkoleniowych na język słownikowy class_weight, pomyślałem, że po prostu określę kilka słowników w następujący sposób:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

a także włączyłem Nonei auto.

Tym razem wyniki były totalnie oszukane. Wszystkie moje przypomnienia wyszły niewielkie (<0,05) dla każdej wartości class_weightoprócz auto. Mogę więc tylko założyć, że moje rozumienie, jak ustawić class_weightsłownik, jest błędne. Co ciekawe, class_weightwartość „auto” w wyszukiwaniu w siatce wynosiła około 59% dla wszystkich wartości C, a zgadłem, że wynosi 1: 1?

Moje pytania

  1. Jak prawidłowo wykorzystujesz class_weightdane treningowe, aby osiągnąć inną równowagę w porównaniu z tym, co faktycznie je dajesz? Konkretnie, do jakiego słownika należy przejść, aby class_weightzastosować n: m proporcje negatywnych: pozytywnych próbek szkoleniowych?

  2. Jeśli przekażesz różne class_weightsłowniki do GridSearchCV, czy podczas walidacji krzyżowej zrównoważy on dane krotności treningowej zgodnie ze słownikiem, ale użyje prawdziwych podanych proporcji próbki do obliczenia mojej funkcji oceniania na krotnie testowym? Ma to kluczowe znaczenie, ponieważ każdy wskaźnik jest dla mnie przydatny tylko wtedy, gdy pochodzi z danych w obserwowanych proporcjach.

  3. Jaka jest autowartość class_weight, jeśli chodzi o proporcje? Czytam dokumentację i zakładam, że „równoważy dane odwrotnie proporcjonalnie do ich częstotliwości”, czyli po prostu daje wynik 1: 1. Czy to jest poprawne? Jeśli nie, czy ktoś może wyjaśnić?


Kiedy używa się class_weight, funkcja straty zostaje zmodyfikowana. Na przykład, zamiast entropii krzyżowej, staje się zważoną entropią krzyżową. directiondatascience.com/…
prashanth

Odpowiedzi:


123

Po pierwsze, samo przywołanie może nie być dobre. Możesz po prostu osiągnąć 100% wycofania, klasyfikując wszystko jako pozytywną klasę. Zwykle sugeruję użycie AUC do wyboru parametrów, a następnie znalezienia progu dla punktu pracy (powiedzmy danego poziomu dokładności), który Cię interesuje.

Jak to class_weightdziała: penalizuje błędy w próbkach class[i]z class_weight[i]zamiast 1. Tak więc wyższa waga klasy oznacza, że ​​chcesz położyć większy nacisk na klasę. Z tego, co mówisz, wydaje się, że klasa 0 występuje 19 razy częściej niż klasa 1. Dlatego należy zwiększyć wartość class_weightklasy 1 względem klasy 0, powiedzmy {0: .1, 1: .9}. Jeśli class_weightsuma nie wynosi 1, zasadniczo zmieni parametr regularyzacji.

Aby dowiedzieć się class_weight="auto", jak to działa, możesz rzucić okiem na tę dyskusję . W wersji deweloperskiej możesz użyć class_weight="balanced", co jest łatwiejsze do zrozumienia: w zasadzie oznacza to replikowanie mniejszej klasy, aż będziesz mieć tyle próbek, co w większej, ale w niejawny sposób.


1
Dzięki! Szybkie pytanie: wspomniałem o przypomnieniu dla jasności i faktycznie próbuję zdecydować, którego AUC użyć jako mojej miary. Rozumiem, że powinienem albo maksymalizować obszar pod krzywą ROC, albo obszar pod krzywą przypominania w stosunku do krzywej precyzji, aby znaleźć parametry. Po wybraniu parametrów w ten sposób uważam, że wybieram próg klasyfikacji, przesuwając się po krzywej. Czy to miałeś na myśli? Jeśli tak, to na którą z dwóch krzywych warto spojrzeć, jeśli moim celem jest uchwycenie jak największej liczby TP? Dziękuję również za waszą pracę i wkład w scikit-learn !!!
kilgoretrout

1
Myślę, że użycie ROC byłoby bardziej standardowym sposobem, ale nie sądzę, że będzie ogromna różnica. Potrzebujesz jednak pewnych kryteriów, aby wybrać punkt na krzywej.
Andreas Mueller

3
@MiNdFrEaK Myślę, że Andrew ma na myśli to, że estymator replikuje próbki w klasie mniejszości, tak aby próbka różnych klas była zrównoważona. To po prostu nadpróbkowanie w niejawny sposób.
Shawn TIAN

8
@MiNdFrEaK i Shawn Tian: klasyfikatory oparte na SV nie generują więcej sampli mniejszych klas, gdy używasz „zbalansowanego”. Dosłownie karze błędy popełnione w mniejszych klasach. Twierdzenie inaczej jest błędem i wprowadza w błąd, zwłaszcza w przypadku dużych zbiorów danych, kiedy nie można pozwolić sobie na tworzenie większej liczby próbek. Ta odpowiedź musi zostać zredagowana.
Pablo Rivas

4
scikit-learn.org/dev/glossary.html#term-class-weight Wagi klas będą używane w różny sposób w zależności od algorytmu: w przypadku modeli liniowych (takich jak liniowe SVM lub regresja logistyczna) wagi klas będą zmieniać funkcję straty o ważenie ubytku każdej próbki wagą jej klasy. W przypadku algorytmów opartych na drzewie wagi klas będą używane do ponownego ważenia kryterium podziału. Należy jednak pamiętać, że to ponowne wyważenie nie uwzględnia wagi próbek w każdej klasie.
prashanth

2

Pierwsza odpowiedź jest dobra, aby zrozumieć, jak to działa. Ale chciałem zrozumieć, jak powinienem go używać w praktyce.

PODSUMOWANIE

  • w przypadku danych umiarkowanie niezrównoważonych BEZ szumów nie ma dużej różnicy w stosowaniu wag klas
  • dla średnio niezrównoważonych danych Z szumem i silnie niezrównoważonych lepiej jest zastosować wagi klas
  • param class_weight="balanced"działa przyzwoicie pod nieobecność ciebie, który chcesz optymalizować ręcznie
  • dzięki class_weight="balanced"czemu wychwytujesz więcej prawdziwych zdarzeń (wyższa dokładność PRAWDA), ale także masz większe prawdopodobieństwo otrzymania fałszywych alertów (niższa dokładność PRAWDA)
    • w rezultacie całkowity% TRUE może być wyższy niż w rzeczywistości z powodu wszystkich fałszywych alarmów
    • AUC może wprowadzić Cię w błąd, jeśli problemem są fałszywe alarmy
  • nie ma potrzeby zmiany progu decyzyjnego na% nierównowagi, nawet przy silnym braku równowagi, ok, aby zachować 0,5 (lub gdzieś w pobliżu w zależności od tego, czego potrzebujesz)

NB

Wynik może się różnić w przypadku korzystania z RF lub GBM. sklearn nie ma class_weight="balanced" dla GBM, ale lightgbm maLGBMClassifier(is_unbalance=False)

KOD

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.