Oto przykład metody Expectation Maximization (EM) zastosowanej do oszacowania średniej i odchylenia standardowego. Kod jest w języku Python, ale powinien być łatwy do naśladowania, nawet jeśli nie znasz języka.
Motywacja do EM
Przedstawione poniżej czerwone i niebieskie punkty pochodzą z dwóch różnych rozkładów normalnych, z których każdy ma określoną średnią i standardowe odchylenie:
Aby obliczyć rozsądne aproksymacje „prawdziwych” średnich i standardowych parametrów odchylenia dla rozkładu czerwonego, możemy bardzo łatwo spojrzeć na czerwone punkty i zapisać położenie każdego z nich, a następnie użyć znanych wzorów (i podobnie dla grupy niebieskiej) .
Rozważmy teraz przypadek, w którym wiemy, że istnieją dwie grupy punktów, ale nie widzimy, który punkt należy do której grupy. Innymi słowy, kolory są ukryte:
Nie jest wcale oczywiste, jak podzielić punkty na dwie grupy. Nie jesteśmy teraz w stanie spojrzeć na pozycje i obliczyć oszacowań parametrów rozkładu czerwonego lub niebieskiego.
Tutaj EM można wykorzystać do rozwiązania problemu.
Wykorzystanie EM do oszacowania parametrów
Oto kod używany do generowania punktów pokazanych powyżej. Możesz zobaczyć rzeczywiste średnie i standardowe odchylenia rozkładów normalnych, z których zostały narysowane punkty. Zmienne red
i blue
utrzymują pozycje każdego punktu odpowiednio w grupie czerwonej i niebieskiej:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue)))
Gdybyśmy mogli zobaczyć kolor każdego punktu, spróbowalibyśmy odzyskać średnie i standardowe odchylenia za pomocą funkcji bibliotecznych:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Ale ponieważ kolory są przed nami ukryte, rozpoczniemy proces EM ...
Po pierwsze, domyślamy się tylko wartości parametrów każdej grupy ( krok 1 ). Te domysły nie muszą być dobre:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Całkiem złe domysły - środki wyglądają, jakby były daleko od jakiegokolwiek „środka” grupy punktów.
Aby kontynuować EM i poprawić te przypuszczenia, obliczamy prawdopodobieństwo, że każdy punkt danych (niezależnie od jego tajnego koloru) pojawi się pod tymi przypuszczeniami dla średniej i odchylenia standardowego ( krok 2 ).
Zmienna both_colours
przechowuje każdy punkt danych. Funkcja stats.norm
oblicza prawdopodobieństwo punktu o rozkładzie normalnym na podstawie podanych parametrów:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
To mówi nam na przykład, że przy naszych bieżących przypuszczeniach punkt danych na 1,761 jest znacznie bardziej czerwony (0,189) niż niebieski (0,00003).
Możemy zamienić te dwie wartości prawdopodobieństwa na wagi ( krok 3 ), aby sumowały się do 1 w następujący sposób:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Dzięki naszym bieżącym oszacowaniom i naszym nowo obliczonym wagom możemy teraz obliczyć nowe, prawdopodobnie lepsze oszacowania parametrów ( krok 4 ). Potrzebujemy funkcji dla średniej i funkcji dla odchylenia standardowego:
def estimate_mean(data, weight):
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
Wyglądają one bardzo podobnie do zwykłych funkcji do średniej i odchylenia standardowego danych. Różnica polega na zastosowaniu weight
parametru, który przypisuje wagę do każdego punktu danych.
Ta waga jest kluczem do EM. Im większa waga koloru w punkcie danych, tym bardziej punkt danych wpływa na następne oszacowania parametrów tego koloru. Ostatecznie powoduje to pociągnięcie każdego parametru we właściwym kierunku.
Nowe domysły są obliczane za pomocą tych funkcji:
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
Proces EM jest następnie powtarzany z tymi nowymi domysłami, począwszy od kroku 2. Możemy powtórzyć kroki dla danej liczby iteracji (powiedzmy 20) lub dopóki nie zobaczymy, że parametry się zbiegają.
Po pięciu iteracjach widzimy, że nasze początkowe błędne domysły zaczynają się poprawiać:
Po 20 iteracjach proces EM zbliżył się mniej więcej:
Dla porównania, oto wyniki procesu EM w porównaniu z wartościami obliczonymi, gdy informacja o kolorze nie jest ukryta:
| EM guess | Actual
----------+----------+--------
Red mean | 2.910 | 2.802
Red std | 0.854 | 0.871
Blue mean | 6.838 | 6.932
Blue std | 2.227 | 2.195
Uwaga: ta odpowiedź została zaadaptowana z mojej odpowiedzi na temat Przepełnienia stosu tutaj .