Losowa prognoza probabilistyczna lasu a głosowanie większościowe


10

Wydaje się, że scikit uczy się przewidywania probabilistycznego zamiast głosowania większością za techniką agregacji modelu bez wyjaśnienia, dlaczego (1.9.2.1. Losowe lasy).

Czy istnieje jasne wyjaśnienie, dlaczego? Czy jest też dobry artykuł lub artykuł przeglądowy na temat różnych technik agregacji modeli, które można zastosować do tworzenia worków w Losowym lesie?

Dzięki!

Odpowiedzi:


10

Na takie pytania zawsze najlepiej odpowiedzieć, patrząc na kod, jeśli znasz biegle język Python.

RandomForestClassifier.predict, przynajmniej w bieżącej wersji 0.16.1, przewiduje klasę o najwyższym oszacowaniu prawdopodobieństwa, podanym przez predict_proba. ( ta linia )

Dokumentacja predict_probamówi:

Przewidywane prawdopodobieństwa klasowe próbki wejściowej są obliczane jako średnie przewidywane prawdopodobieństwa klasowe drzew w lesie. Prawdopodobieństwo klasowe pojedynczego drzewa to ułamek próbek tej samej klasy w liściu.

Różnica w stosunku do oryginalnej metody jest prawdopodobnie taka, że predictdaje przewidywania zgodne z predict_proba. Wynik ten nazywany jest czasem „miękkim głosowaniem”, a nie „twardym” głosowaniem większościowym w pierwotnym dokumencie Breimana. W szybkim wyszukiwaniu nie mogłem znaleźć odpowiedniego porównania wydajności tych dwóch metod, ale obie wydają się dość rozsądne w tej sytuacji.

predictDokumentacja jest w najlepszym razie dość mylące; Mam złożony wniosek ciągnącego go naprawić.

Jeśli zamiast tego chcesz przewidywać większość głosów, oto funkcja, która to umożliwia. Nazwij to predict_majvote(clf, X)raczej niż clf.predict(X). (Na podstawie predict_proba; tylko lekko przetestowane, ale myślę, że powinno działać.)

from scipy.stats import mode
from sklearn.ensemble.forest import _partition_estimators, _parallel_helper
from sklearn.tree._tree import DTYPE
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

def predict_majvote(forest, X):
    """Predict class for X.

    Uses majority voting, rather than the soft voting scheme
    used by RandomForestClassifier.predict.

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csr_matrix``.
    Returns
    -------
    y : array of shape = [n_samples] or [n_samples, n_outputs]
        The predicted classes.
    """
    check_is_fitted(forest, 'n_outputs_')

    # Check data
    X = check_array(X, dtype=DTYPE, accept_sparse="csr")

    # Assign chunk of trees to jobs
    n_jobs, n_trees, starts = _partition_estimators(forest.n_estimators,
                                                    forest.n_jobs)

    # Parallel loop
    all_preds = Parallel(n_jobs=n_jobs, verbose=forest.verbose,
                         backend="threading")(
        delayed(_parallel_helper)(e, 'predict', X, check_input=False)
        for e in forest.estimators_)

    # Reduce
    modes, counts = mode(all_preds, axis=0)

    if forest.n_outputs_ == 1:
        return forest.classes_.take(modes[0], axis=0)
    else:
        n_samples = all_preds[0].shape[0]
        preds = np.zeros((n_samples, forest.n_outputs_),
                         dtype=forest.classes_.dtype)
        for k in range(forest.n_outputs_):
            preds[:, k] = forest.classes_[k].take(modes[:, k], axis=0)
        return preds

W tym głupim syntetycznym przypadku, którego próbowałem, przewidywania za predictkażdym razem zgadzały się z tą metodą.


Świetna odpowiedź, Dougal! Dziękujemy za poświęcenie czasu na dokładne wyjaśnienie tego. Proszę również rozważyć przejście do przepełnienia stosu i odpowiedź na to pytanie .
user1745038

1
Jest też papier, tutaj , która rozwiązuje probabilistyczny prognozy.
user1745038
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.