Jakie jest intuicyjne wyjaśnienie techniki maksymalizacji oczekiwań? [Zamknięte]


109

Maksymalizacja oczekiwań (EM) jest rodzajem probabilistycznej metody klasyfikacji danych. Proszę poprawić mnie, jeśli się mylę, jeśli nie jest to klasyfikator.

Jakie jest intuicyjne wyjaśnienie tej techniki EM? Co expectationtu jest i co się dzieje maximized?


12
Jaki jest algorytm maksymalizacji oczekiwań? , Nature Biotechnology 26 , 897–899 (2008) zawiera ładny obraz, który ilustruje sposób działania algorytmu.
chl

@chl w części b na zdjęciu ładne , jak oni uzyskać wartości rozkładu prawdopodobieństwa na Z (tj 0.45xA, 0.55xB, itd.)?
Noob Saibot

3
Możesz spojrzeć na to pytanie math.stackexchange.com/questions/25111/ ...
v4r

3
Zaktualizowany link do zdjęcia, o którym wspomniał @chl.
n1k31t4

Odpowiedzi:


120

Uwaga: kod odpowiadający za tę odpowiedź można znaleźć tutaj .


Załóżmy, że mamy pewne dane pobrane z dwóch różnych grup, czerwonej i niebieskiej:

wprowadź opis obrazu tutaj

Tutaj możemy zobaczyć, który punkt danych należy do grupy czerwonej lub niebieskiej. Ułatwia to znalezienie parametrów charakteryzujących każdą grupę. Na przykład, średnia grupy czerwonej wynosi około 3, średnia grupy niebieskiej wynosi około 7 (i moglibyśmy znaleźć dokładną średnią, gdybyśmy chcieli).

Jest to ogólnie znane jako oszacowanie maksymalnego prawdopodobieństwa . Biorąc pod uwagę pewne dane, obliczamy wartość parametru (lub parametrów), który najlepiej wyjaśnia te dane.

Teraz wyobraź sobie, że nie możemy zobaczyć, która wartość była próbkowana z której grupy. Dla nas wszystko wygląda na fioletowe:

wprowadź opis obrazu tutaj

Tutaj wiemy, że istnieją dwie grupy wartości, ale nie wiemy, do której grupy należy dana wartość.

Czy nadal możemy oszacować średnie dla grupy czerwonej i niebieskiej, które najlepiej pasują do tych danych?

Tak, często możemy! Maksymalizacja oczekiwań daje nam na to sposób. Bardzo ogólna idea algorytmu jest taka:

  1. Zacznij od wstępnego oszacowania, jaki może być każdy parametr.
  2. Oblicz prawdopodobieństwo że każdy parametr tworzy punkt danych.
  3. Oblicz wagi dla każdego punktu danych, wskazując, czy jest bardziej czerwony, czy bardziej niebieski, w oparciu o prawdopodobieństwo, że zostanie wytworzony przez parametr. Połącz wagi z danymi ( oczekiwanie ).
  4. Oblicz lepsze oszacowanie parametrów przy użyciu danych skorygowanych o wagę ( maksymalizacja ).
  5. Powtarzaj kroki od 2 do 4, aż oszacowanie parametru osiągnie zbieżność (proces przestanie dawać inną ocenę).

Te kroki wymagają dalszych wyjaśnień, więc omówię problem opisany powyżej.

Przykład: szacowanie średniej i odchylenia standardowego

W tym przykładzie użyję Pythona, ale kod powinien być dość łatwy do zrozumienia, jeśli nie znasz tego języka.

Załóżmy, że mamy dwie grupy, czerwoną i niebieską, z wartościami rozłożonymi jak na powyższym obrazku. W szczególności każda grupa zawiera wartość pobraną z rozkładu normalnego z następującymi parametrami:

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

Oto ponownie obraz tych czerwonych i niebieskich grup (aby uniknąć konieczności przewijania w górę):

wprowadź opis obrazu tutaj

Kiedy widzimy kolor każdego punktu (tj. Do której grupy należy), bardzo łatwo jest oszacować średnią i odchylenie standardowe dla każdej grupy. Po prostu przekazujemy wartości czerwony i niebieski do funkcji wbudowanych w NumPy. Na przykład:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

Ale co, jeśli nie widzimy kolorów punktów? Oznacza to, że zamiast czerwonego lub niebieskiego każdy punkt został pokolorowany na fioletowo.

Aby spróbować odzyskać średnią i parametry odchylenia standardowego dla grup czerwonych i niebieskich, możemy użyć maksymalizacji oczekiwań.

Naszym pierwszym krokiem ( krok 1 powyżej) jest odgadnięcie wartości parametrów dla średniej i odchylenia standardowego każdej grupy. Nie musimy inteligentnie zgadywać; możemy wybrać dowolne liczby:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

Te oszacowania parametrów dają krzywe dzwonowe, które wyglądają następująco:

wprowadź opis obrazu tutaj

To są złe szacunki. Oba środki (pionowe przerywane linie) wyglądają na daleko od wszelkiego rodzaju „środka”, na przykład w przypadku rozsądnych grup punktów. Chcemy poprawić te szacunki.

Następnym krokiem ( krok 2 ) jest obliczenie prawdopodobieństwa pojawienia się każdego punktu danych pod bieżącymi domysłami parametrów:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Tutaj po prostu umieściliśmy każdy punkt danych w funkcji gęstości prawdopodobieństwa dla rozkładu normalnego, używając naszych aktualnych przypuszczeń na temat średniej i odchylenia standardowego dla czerwieni i niebieskiego. To mówi nam na przykład, że przy naszych aktualnych domysłach punkt danych przy 1,761 jest znacznie bardziej prawdopodobny, że będzie czerwony (0,189) niż niebieski (0,00003).

Dla każdego punktu danych możemy zamienić te dwie wartości prawdopodobieństwa na wagi ( krok 3 ), aby sumowały się do 1 w następujący sposób:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

Dzięki naszym bieżącym szacunkom i nowo obliczonym wagom możemy teraz obliczyć nowe oszacowania średniej i odchylenia standardowego grup czerwonych i niebieskich ( krok 4 ).

Dwukrotnie obliczamy średnią i odchylenie standardowe przy użyciu wszystkich punktów danych, ale z różnymi wagami: raz dla wag czerwonych i raz dla wag niebieskich.

Kluczową intuicją jest to, że im większa waga koloru w punkcie danych, tym bardziej punkt danych wpływa na następne oszacowania parametrów tego koloru. Powoduje to „ciągnięcie” parametrów we właściwym kierunku.

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

Mamy nowe szacunki parametrów. Aby je ponownie ulepszyć, możemy wrócić do kroku 2 i powtórzyć proces. Robimy to do osiągnięcia zbieżności szacunków lub po wykonaniu pewnej liczby iteracji ( krok 5 ).

W przypadku naszych danych pierwsze pięć iteracji tego procesu wygląda następująco (ostatnie iteracje mają silniejszy wygląd):

wprowadź opis obrazu tutaj

Widzimy, że średnie już zbiegają się na niektórych wartościach, a kształty krzywych (regulowane odchyleniem standardowym) również stają się bardziej stabilne.

Jeśli będziemy kontynuować przez 20 iteracji, otrzymamy co następuje:

wprowadź opis obrazu tutaj

Proces EM zbiegał się do następujących wartości, które okazują się bardzo zbliżone do rzeczywistych wartości (gdzie widzimy kolory - brak ukrytych zmiennych):

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

W powyższym kodzie mogłeś zauważyć, że nowe oszacowanie odchylenia standardowego zostało obliczone przy użyciu oszacowania średniej z poprzedniej iteracji. Ostatecznie nie ma znaczenia, czy najpierw obliczymy nową wartość dla średniej, ponieważ właśnie znajdujemy (ważoną) wariancję wartości wokół jakiegoś centralnego punktu. Nadal będziemy widzieć zbieżność szacunków parametrów.


co jeśli nawet nie znamy liczby normalnych rozkładów, z których to pochodzi? Tutaj masz przykład rozkładów k = 2, czy możemy również oszacować k i zestawy parametrów k?
stackit

1
@stackit: Nie jestem pewien, czy istnieje prosty, ogólny sposób obliczenia najbardziej prawdopodobnej wartości k jako części procesu EM w tym przypadku. Głównym problemem jest to, że musielibyśmy rozpocząć EM od oszacowań dla każdego z parametrów, które chcemy znaleźć, a to oznacza, że ​​musimy znać / oszacować k przed rozpoczęciem. W tym miejscu można jednak oszacować odsetek punktów należących do grupy za pośrednictwem EM. Może gdybyśmy przeszacowali k, proporcja wszystkich grup oprócz dwóch spadłaby prawie do zera. Nie eksperymentowałem z tym, więc nie wiem, jak dobrze by to działało w praktyce.
Alex Riley,

1
@AlexRiley Czy możesz powiedzieć trochę więcej o formułach obliczania nowej średniej i szacunkowych odchyleń standardowych?
Lemon

2
@AlexRiley Dzięki za wyjaśnienie. Dlaczego nowe szacunki odchylenia standardowego są obliczane przy użyciu starego domysłu średniej? A jeśli nowe szacunki średniej zostaną znalezione jako pierwsze?
GoodDeeds

1
@Lemon GoodDeeds Kaushal - przepraszam za spóźnioną odpowiedź na Twoje pytania. Próbowałem zmienić odpowiedź, aby odnieść się do podniesionych przez Ciebie punktów. Udostępniłem również cały kod użyty w tej odpowiedzi w zeszycie tutaj (który zawiera również bardziej szczegółowe wyjaśnienia niektórych punktów, które poruszyłem).
Alex Riley,

36

EM to algorytm maksymalizacji funkcji wiarygodności, gdy niektóre zmienne w twoim modelu są niezauważone (np. Gdy masz zmienne latentne).

Możesz uczciwie zapytać, jeśli po prostu próbujemy zmaksymalizować funkcję, dlaczego nie wykorzystamy po prostu istniejącej maszyny do maksymalizacji funkcji. Cóż, jeśli spróbujesz to zmaksymalizować, biorąc pochodne i ustawiając je na zero, okaże się, że w wielu przypadkach warunki pierwszego rzędu nie mają rozwiązania. Istnieje problem typu kura i jajko, aby rozwiązać parametry modelu, musisz znać dystrybucję nieobserwowanych danych; ale rozkład twoich nieobserwowanych danych jest funkcją parametrów twojego modelu.

EM próbuje obejść ten problem poprzez iteracyjne odgadywanie rozkładu nieobserwowanych danych, a następnie szacowanie parametrów modelu poprzez maksymalizację czegoś, co jest dolną granicą rzeczywistej funkcji wiarygodności i powtarzanie aż do zbieżności:

Algorytm EM

Zacznij od odgadnięcia wartości parametrów modelu

E-krok: dla każdego punktu danych, który ma brakujące wartości, użyj równania modelu, aby znaleźć rozkład brakujących danych, biorąc pod uwagę aktualne przypuszczenie parametrów modelu i dane obserwowane (zwróć uwagę, że rozwiązujesz rozkład dla każdego brakującego wartość, a nie wartość oczekiwana). Teraz, gdy mamy rozkład dla każdej brakującej wartości, możemy obliczyć oczekiwanie funkcji wiarygodności w odniesieniu do nieobserwowanych zmiennych. Jeśli nasze przypuszczenie dla parametru modelu było poprawne, to oczekiwane prawdopodobieństwo będzie rzeczywistym prawdopodobieństwem zaobserwowanych przez nas danych; jeśli parametry nie były prawidłowe, będzie to tylko dolna granica.

M-step: Teraz, gdy mamy oczekiwaną funkcję prawdopodobieństwa bez nieobserwowanych zmiennych, zmaksymalizuj funkcję tak, jak w przypadku w pełni obserwowanego, aby uzyskać nowe oszacowanie parametrów modelu.

Powtarzaj do zbieżności.


5
Nie rozumiem twojego kroku E. Częściowo problem polega na tym, że kiedy się tego uczę, nie mogę znaleźć ludzi, którzy używają tej samej terminologii. Więc co masz na myśli mówiąc o równaniu modelu? Nie wiem, co masz na myśli, rozwiązując rozkład prawdopodobieństwa?
user678392,

27

Oto prosty przepis na zrozumienie algorytmu maksymalizacji oczekiwań:

1- Przeczytaj ten samouczek EM autorstwa Do i Batzoglou.

2- Możesz mieć w głowie znaki zapytania, spójrz na wyjaśnienia na tej stronie wymiany stosów matematycznych .

3- Spójrz na ten kod, który napisałem w Pythonie, który wyjaśnia przykład w dokumencie instruktażowym EM w punkcie 1:

Ostrzeżenie: kod może być niechlujny / nieoptymalny, ponieważ nie jestem programistą Pythona. Ale spełnia swoje zadanie.

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

Uważam, że twój program da wynik zarówno z A, jak i B do 0,66, implementuję go również za pomocą scala, również stwierdzam, że wynik to 0,66, czy możesz pomóc to sprawdzić?
zjffdu,

Korzystając z arkusza kalkulacyjnego, znajduję wyniki 0,66 tylko wtedy, gdy moje początkowe przypuszczenia są równe. W przeciwnym razie mogę odtworzyć wyniki samouczka.
soakley

@zjffdu, ile iteracji wykonuje EM, zanim zwróci ci 0.66? Jeśli zainicjujesz z równymi wartościami, może utknąć na lokalnym maksimum i zobaczysz, że liczba iteracji jest wyjątkowo niska (ponieważ nie ma poprawy).
Zhubarb

Możesz również obejrzeć ten slajd autorstwa Andrew Ng i notatki z kursu Harvardu
Minh Phan

16

Technicznie termin „EM” jest nieco niedookreślony, ale zakładam, że odnosisz się do techniki analizy skupień Gaussian Mixture Modeling, która jest przykładem ogólnej zasady EM.

W rzeczywistości analiza skupień EM nie jest klasyfikatorem . Wiem, że niektórzy uważają tworzenie klastrów za „klasyfikację nienadzorowaną”, ale w rzeczywistości analiza skupień to coś zupełnie innego.

Kluczowa różnica i wielkie niezrozumienie klasyfikacji, które ludzie zawsze mają w analizie skupień, jest takie, że: w analizie klastrów nie ma „poprawnego rozwiązania” . Jest to metoda odkrywania wiedzy , w rzeczywistości ma na celu znalezienie czegoś nowego ! To sprawia, że ​​ocena jest bardzo trudna. Często jest oceniany przy użyciu znanej klasyfikacji jako odniesienia, ale nie zawsze jest to właściwe: klasyfikacja, którą posiadasz, może, ale nie musi, odzwierciedlać to, co jest w danych.

Podam przykład: masz duży zbiór danych klientów, w tym dane dotyczące płci. Metoda dzieląca ten zestaw danych na „mężczyzna” i „kobieta” jest optymalna, gdy porównuje się go z istniejącymi klasami. W myśleniu „przewidywania” jest to dobre, ponieważ w przypadku nowych użytkowników można teraz przewidzieć ich płeć. W myśleniu „odkrywania wiedzy” jest to właściwie złe, ponieważ chciałeś odkryć jakąś nową strukturę danych. Metoda, która np. Podzieliłaby dane na osoby starsze i dzieci, uzyskałaby jednak gorsze wyniki, jak to możliwe w odniesieniu do klasy mężczyzn / kobiet. Byłby to jednak doskonały wynik grupowania (gdyby nie podano wieku).

Wróćmy teraz do EM. Zasadniczo zakłada się, że dane składają się z wielu wielowymiarowych rozkładów normalnych (zwróć uwagę, że jest to bardzo mocne założenie, zwłaszcza gdy ustalasz liczbę klastrów!). Następnie próbuje znaleźć optymalny model lokalny, na przemian ulepszając model i przypisanie obiektu do modelu .

Aby uzyskać najlepsze wyniki w kontekście klasyfikacji, wybierz liczbę klastrów większą niż liczba klas, a nawet zastosuj grupowanie tylko do pojedynczych klas (aby dowiedzieć się, czy w klasie jest jakaś struktura!).

Załóżmy, że chcesz nauczyć klasyfikatora rozróżniać „samochody”, „rowery” i „ciężarówki”. Zakładanie, że dane składają się z dokładnie trzech rozkładów normalnych, jest mało przydatne. Możesz jednak założyć, że istnieje więcej niż jeden typ samochodów (oraz ciężarówek i motocykli). Więc zamiast trenować klasyfikator dla tych trzech klas, grupujesz samochody, ciężarówki i motocykle w 10 grup (lub może 10 samochodów, 3 ciężarówki i 3 rowery, cokolwiek), następnie trenujesz klasyfikator, aby rozróżniał te 30 klas, a następnie scal wynik klasy z powrotem do klas oryginalnych. Możesz również odkryć, że istnieje jeden klaster, który jest szczególnie trudny do sklasyfikowania, na przykład Trikes. To trochę samochody i trochę motocykle. Albo samochody dostawcze, które bardziej przypominają duże samochody niż ciężarówki.


jak jest niedookreślona EM?
sam boosalis,

Istnieje więcej niż jedna wersja tego. Technicznie rzecz biorąc, możesz też nazwać k-znaczy w stylu Lloyda „EM”. Musisz określić, jakiego modelu używasz.
ZAKOŃCZYŁO - Anony-Mousse

2

Jeśli inne odpowiedzi są dobre, spróbuję przedstawić inną perspektywę i zająć się intuicyjną częścią pytania.

Algorytm EM (Expectation-Maximization) jest wariantem klasy iteracyjnych algorytmów wykorzystujących dualność

Fragment (moje podkreślenie):

W matematyce dualność, ogólnie mówiąc, przekłada pojęcia, twierdzenia lub struktury matematyczne na inne pojęcia, twierdzenia lub struktury, w sposób jeden do jednego, często (ale nie zawsze) za pomocą operacji inwolucyjnej: jeśli A to B, a następnie podwójna z B to A. Takie inwolucje czasami mają stałe punkty , tak że podwójna liczba A jest sama w sobie A

Zwykle podwójne B obiektu A jest w jakiś sposób powiązane z A, co pozwala zachować pewną symetrię lub zgodność . Na przykład AB = const

Przykłady algorytmów iteracyjnych wykorzystujących dualność (w poprzednim znaczeniu) to:

  1. Algorytm euklidesowy dla największego wspólnego dzielnika i jego warianty
  2. Algorytm i warianty bazy wektora Grama – Schmidta
  3. Średnia arytmetyczna - Nierówność średnich geometrycznych i jej warianty
  4. Algorytm oczekiwanie-maksymalizacja i jego warianty (zobacz także informacje-geometryczny widok )
  5. (.. inne podobne algorytmy ..)

W podobny sposób algorytm EM można również postrzegać jako dwa podwójne kroki maksymalizacji :

.. [EM] jest postrzegane jako maksymalizacja łącznej funkcji parametrów i rozkładu względem nieobserwowanych zmiennych. Krok E maksymalizuje tę funkcję w odniesieniu do rozkładu na nieobserwowanych zmiennych; krok M w odniesieniu do parametrów.

W iteracyjnym algorytmie wykorzystującym dualność istnieje jawne (lub niejawne) założenie równowagi (lub ustalonego) punktu zbieżności (dla EM jest to udowodnione za pomocą nierówności Jensena)

Zatem zarys takich algorytmów jest następujący:

  1. Krok podobny do E: znajdź najlepsze rozwiązanie x w odniesieniu do danego y utrzymywanego na stałym poziomie.
  2. Krok podobny do M (podwójny): Znajdź najlepsze rozwiązanie y w odniesieniu do x (jak obliczono w poprzednim kroku) utrzymywanego na stałym poziomie.
  3. Kryterium etapu zakończenia / zbieżności: Powtórz kroki 1, 2 ze zaktualizowanymi wartościami x , y aż do osiągnięcia zbieżności (lub określonej liczby iteracji)

Zauważ, że kiedy taki algorytm zbiega się do (globalnego) optimum, to znalazł konfigurację, która jest najlepsza z obu względów (tj. Zarówno w domenie / parametrach x, jak i w domenie / parametrach y ). Jednak algorytm może po prostu znaleźć optimum lokalne, a nie optymalne globalne .

powiedziałbym, że jest to intuicyjny opis zarysu algorytmu

W przypadku argumentów statystycznych i zastosowań inne odpowiedzi dały dobre wyjaśnienia (sprawdź również odniesienia w tej odpowiedzi)


2

Przyjęta odpowiedź odwołuje się do Chuong EM Paper , który porządnie wyjaśnia EM. Istnieje również wideo z YouTube, które bardziej szczegółowo wyjaśnia artykuł.

Podsumowując, oto scenariusz:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

W przypadku pytania z pierwszej próby, intuicyjnie myślelibyśmy, że B wygenerował je, ponieważ proporcja orłów bardzo dobrze pasuje do odchylenia B ... ale ta wartość była tylko przypuszczeniem, więc nie możemy być pewni.

Mając to na uwadze, lubię myśleć o rozwiązaniu EM w następujący sposób:

  • Każda próba rzutów pozwala „głosować” na monetę, którą lubi najbardziej
    • Jest to oparte na tym, jak dobrze każda moneta pasuje do jej dystrybucji
    • LUB, z punktu widzenia monety, istnieje duże oczekiwanie, że zobaczymy tę próbę w porównaniu z inną monetą (na podstawie prawdopodobieństwa logów ).
  • W zależności od tego, jak bardzo każda próba podoba się każdej monecie, może zaktualizować przypuszczalny parametr tej monety (odchylenie).
    • Im bardziej próba lubi monetę, tym bardziej może zaktualizować nastawienie monety, aby odzwierciedlić własne!
    • Zasadniczo odchylenia monety są aktualizowane poprzez łączenie tych ważonych aktualizacji we wszystkich próbach, proces zwany ( maksymalizacją ), który odnosi się do próby uzyskania najlepszych przypuszczeń dla odchylenia każdej monety przy danym zestawie prób.

Może to być nadmierne uproszczenie (lub nawet fundamentalnie błędne na niektórych poziomach), ale mam nadzieję, że pomoże to na poziomie intuicyjnym!


1

EM służy do maksymalizacji prawdopodobieństwa modelu Q ze zmiennymi latentnymi Z.

To iteracyjna optymalizacja.

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-step: biorąc pod uwagę bieżące oszacowanie Z, oblicz oczekiwaną funkcję loglikwidencji

m-step: znajdź theta, który maksymalizuje to Q

Przykład GMM:

e-step: oszacowanie przypisań etykiet dla każdego punktu danych przy aktualnym oszacowaniu parametru gmm

m-step: maksymalizuj nowe theta biorąc pod uwagę nowe przypisania etykiet

K-średnie jest również algorytmem EM i istnieje wiele animacji wyjaśniających K-średnich.


1

Korzystając z tego samego artykułu autorstwa Do i Batzoglou, cytowanego w odpowiedzi Zhubarba, zaimplementowałem EM dla tego problemu w Javie . Komentarze do jego odpowiedzi pokazują, że algorytm utknie na lokalnym optimum, co również ma miejsce w mojej implementacji, jeśli parametry thetaA i thetaB są takie same.

Poniżej znajduje się standardowe wyjście mojego kodu, pokazujące zbieżność parametrów.

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

Poniżej znajduje się moja implementacja EM w Javie w celu rozwiązania problemu w (Do i Batzoglou, 2008). Podstawową częścią implementacji jest pętla do uruchamiania EM do momentu zbieżności parametrów.

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

Poniżej znajduje się cały kod.

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.