EM (Expectation Maximization)은 데이터를 분류하는 일종의 확률 적 방법입니다. 분류자가 아닌 경우 내가 틀렸다면 나를 수정하십시오.
이 EM 기술에 대한 직관적 인 설명은 무엇입니까? 무엇 expectation
여기 존재는 무엇인가 maximized
?
EM (Expectation Maximization)은 데이터를 분류하는 일종의 확률 적 방법입니다. 분류자가 아닌 경우 내가 틀렸다면 나를 수정하십시오.
이 EM 기술에 대한 직관적 인 설명은 무엇입니까? 무엇 expectation
여기 존재는 무엇인가 maximized
?
답변:
참고 :이 답변의 코드는 여기 에서 찾을 수 있습니다 .
빨간색과 파란색의 두 그룹에서 샘플링 된 데이터가 있다고 가정합니다.
여기에서 어떤 데이터 포인트가 빨간색 또는 파란색 그룹에 속하는지 확인할 수 있습니다. 이를 통해 각 그룹을 특징 짓는 매개 변수를 쉽게 찾을 수 있습니다. 예를 들어, 빨간색 그룹의 평균은 약 3이고 파란색 그룹의 평균은 약 7입니다 (원하는 경우 정확한 평균을 찾을 수 있음).
이것은 일반적으로 최대 가능성 추정으로 알려져 있습니다. 일부 데이터가 주어지면 해당 데이터를 가장 잘 설명하는 매개 변수 (또는 매개 변수)의 값을 계산합니다.
이제 어떤 값이 어떤 그룹에서 샘플링되었는지 알 수 없다고 상상해보십시오 . 모든 것이 우리에게 보라색으로 보입니다.
여기에 두 그룹의 값이 있다는 것을 알고 있지만 특정 값이 속한 그룹은 알 수 없습니다.
이 데이터에 가장 적합한 빨간색 그룹과 파란색 그룹의 평균을 여전히 추정 할 수 있습니까?
예, 종종 할 수 있습니다! 기대 최대화 는 우리에게이를위한 방법을 제공합니다. 알고리즘의 가장 일반적인 아이디어는 다음과 같습니다.
이 단계에는 추가 설명이 필요하므로 위에서 설명한 문제를 살펴 보겠습니다.
이 예제에서는 Python을 사용하지만이 언어에 익숙하지 않은 경우 코드를 이해하기가 매우 쉽습니다.
위의 이미지와 같이 값이 분포 된 빨강과 파랑의 두 그룹이 있다고 가정합니다. 특히 각 그룹에는 다음 모수를 사용하여 정규 분포 에서 가져온 값이 포함 됩니다.
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue))) # for later use...
다음은 이러한 빨간색 및 파란색 그룹의 이미지입니다 (위로 스크롤하지 않아도 됨).
각 포인트의 색상 (즉, 그것이 속한 그룹)을 볼 수있을 때 각 그룹의 평균과 표준 편차를 추정하는 것은 매우 쉽습니다. NumPy의 내장 함수에 빨강 및 파랑 값을 전달합니다. 예를 들면 :
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
하지만 포인트의 색상을 볼 수 없다면 어떨까요? 즉, 빨간색 또는 파란색 대신 모든 점이 자주색으로 표시됩니다.
빨간색 및 파란색 그룹에 대한 평균 및 표준 편차 매개 변수를 복구하기 위해 기대 최대화를 사용할 수 있습니다.
첫 번째 단계 ( 단계 1 위)의 각 그룹의 평균 및 표준 편차에 대한 파라미터 값을 추측한다. 우리는 지능적으로 추측 할 필요가 없습니다. 원하는 숫자를 선택할 수 있습니다.
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
이러한 모수 추정값은 다음과 같은 종형 곡선을 생성합니다.
이것은 잘못된 추정치입니다. 예를 들어, 두 가지 모두 (수직 점선) 의미있는 점 그룹에 대해 모든 종류의 "중간"에서 멀리 보입니다. 우리는 이러한 추정치를 개선하고자합니다.
다음 단계 ( 2 단계 )는 현재 매개 변수 추측 아래에 나타나는 각 데이터 포인트의 가능성을 계산하는 것입니다.
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
여기서 우리는 빨강과 파랑에 대한 평균과 표준 편차에 대한 현재 추측을 사용하여 정규 분포에 대한 확률 밀도 함수 에 각 데이터 포인트를 간단히 넣었습니다 . 예를 들어, 현재 추측을 통해 1.761의 데이터 포인트 가 파란색 (0.00003)보다 빨간색 (0.189) 일 가능성 이 훨씬 더 높다는 것을 알 수 있습니다.
각 데이터 포인트에 대해이 두 우도 값을 가중치로 변환하여 ( 3 단계 ) 다음과 같이 합계가 1이되도록 할 수 있습니다.
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
현재 추정치와 새로 계산 된 가중치 를 사용하여 이제 빨간색 및 파란색 그룹의 평균 및 표준 편차에 대한 새 추정치를 계산할 수 있습니다 ( 4 단계 ).
모든 데이터 포인트를 사용하여 평균과 표준 편차를 두 번 계산 하지만 다른 가중치를 사용합니다. 한 번은 빨간색 가중치에 대해 한 번은 파란색 가중치에 대해 한 번입니다.
직관의 핵심은 데이터 포인트에서 색상의 가중치가 클수록 데이터 포인트가 해당 색상의 매개 변수에 대한 다음 추정치에 더 많은 영향을 미친다는 것입니다. 이는 매개 변수를 올바른 방향으로 "당기는"효과가 있습니다.
def estimate_mean(data, weight):
"""
For each data point, multiply the point by the probability it
was drawn from the colour's distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among our data points.
"""
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
"""
For each data point, multiply the point's squared difference
from a mean value by the probability it was drawn from
that distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among the values for the difference of
each data point from the mean.
This is the estimate of the variance, take the positive square
root to find the standard deviation.
"""
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
매개 변수에 대한 새로운 추정치가 있습니다. 다시 개선하기 위해 2 단계로 돌아가 프로세스를 반복 할 수 있습니다. 추정값이 수렴 될 때까지 또는 몇 번의 반복이 수행 된 후에이 작업을 수행합니다 ( 5 단계 ).
데이터의 경우이 프로세스의 처음 5 개 반복은 다음과 같습니다 (최근 반복의 모양이 더 강함).
평균이 이미 일부 값에 수렴되고 있으며 곡선의 모양 (표준 편차에 의해 제어 됨)도 더욱 안정되고 있음을 알 수 있습니다.
20 번 반복하면 다음과 같이됩니다.
EM 프로세스는 다음 값으로 수렴되어 실제 값에 매우 가깝습니다 (색상을 볼 수 있으며 숨겨진 변수 없음).
| EM guess | Actual | Delta
----------+----------+--------+-------
Red mean | 2.910 | 2.802 | 0.108
Red std | 0.854 | 0.871 | -0.017
Blue mean | 6.838 | 6.932 | -0.094
Blue std | 2.227 | 2.195 | 0.032
위의 코드에서 표준 편차에 대한 새로운 추정이 평균에 대한 이전 반복 추정치를 사용하여 계산되었음을 알 수 있습니다. 궁극적으로 우리가 어떤 중심점 주변에서 값의 (가중치) 분산을 찾는 것이므로 먼저 평균에 대한 새 값을 계산하는지 여부는 중요하지 않습니다. 모수에 대한 추정치는 여전히 수렴됩니다.
EM은 모델의 일부 변수가 관찰되지 않을 때 (즉, 잠재 변수가있는 경우) 우도 함수를 최대화하는 알고리즘입니다.
우리가 단지 기능을 최대화하려고한다면, 기능을 최대화하기 위해 기존의 기계를 사용하는 것이 어떨까요? 음, 미분을 취하고 0으로 설정하여이를 최대화하려고하면 많은 경우 1 차 조건에 해가 없다는 것을 알 수 있습니다. 모델 매개 변수를 해결하려면 관측되지 않은 데이터의 분포를 알아야한다는 점에서 닭과 계란 문제가 있습니다. 그러나 관찰되지 않은 데이터의 분포는 모델 매개 변수의 함수입니다.
EM은 관찰되지 않은 데이터에 대한 분포를 반복적으로 추측 한 다음 실제 우도 함수의 하한값을 최대화하고 수렴 될 때까지 반복하여 모델 매개 변수를 추정하여이 문제를 해결하려고합니다.
EM 알고리즘
모델 매개 변수 값에 대한 추측으로 시작
E- 단계 : 결 측값이있는 각 데이터 포인트에 대해 모델 모수에 대한 현재 추측과 관측 된 데이터가 주어지면 모델 방정식을 사용하여 결측 데이터의 분포를 구합니다 (각 결 측값에 대한 분포를 풀고 있음에 유의하십시오. 예상 값이 아닌 값). 이제 각 결 측값에 대한 분포 가 있으므로 관측되지 않은 변수에 대한 우도 함수 의 기대치 를 계산할 수 있습니다 . 모델 모수에 대한 추측이 정확하다면이 예상 우도는 관측 된 데이터의 실제 우도가 될 것입니다. 매개 변수가 올바르지 않으면 하한이됩니다.
M 단계 : 이제 관측되지 않은 변수가없는 예상 우도 함수가 있으므로 완전히 관측 된 경우 에서처럼 함수를 최대화하여 모델 매개 변수의 새로운 추정치를 얻습니다.
수렴 될 때까지 반복합니다.
기대 최대화 알고리즘을 이해하기위한 간단한 방법은 다음과 같습니다.
1- Do 및 Batzoglou의 EM 튜토리얼 문서 를 읽으십시오 .
2- 머릿속에 물음표가있을 수 있습니다.이 수학 스택 교환 페이지 에 대한 설명을 살펴보십시오 .
3- 항목 1의 EM 자습서 문서에서 예제를 설명하는 Python으로 작성한이 코드를보십시오.
경고 : 저는 Python 개발자가 아니기 때문에 코드가 지저분하거나 차선책 일 수 있습니다. 그러나 그것은 일을합니다.
import numpy as np
import math
#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* ####
def get_mn_log_likelihood(obs,probs):
""" Return the (log)likelihood of obs, given the probs"""
# Multinomial Distribution Log PMF
# ln (pdf) = multinomial coeff * product of probabilities
# ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]
multinomial_coeff_denom= 0
prod_probs = 0
for x in range(0,len(obs)): # loop through state counts in each observation
multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
prod_probs = prod_probs + obs[x]*math.log(probs[x])
multinomial_coeff = math.log(math.factorial(sum(obs))) - multinomial_coeff_denom
likelihood = multinomial_coeff + prod_probs
return likelihood
# 1st: Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd: Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd: Coin A, {HTHHHHHTHH}, 8H,2T
# 4th: Coin B, {HTHTTTHHTT}, 4H,6T
# 5th: Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45
# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)
# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50
# E-M begins!
delta = 0.001
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
expectation_A = np.zeros((5,2), dtype=float)
expectation_B = np.zeros((5,2), dtype=float)
for i in range(0,len(experiments)):
e = experiments[i] # i'th experiment
ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B
weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A
weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B
expectation_A[i] = np.dot(weightA, e)
expectation_B[i] = np.dot(weightB, e)
pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A));
pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B));
improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
j = j+1
기술적 용어 "EM은"약간 underspecified,하지만 난 당신이 인 가우시안 혼합 모델링 클러스터 분석 기법을 참조 가정 예를 일반 EM 원리.
실제로 EM 클러스터 분석은 분류자가 아닙니다 . 어떤 사람들은 클러스터링을 "비지도 분류"라고 생각하지만 실제로는 클러스터 분석은 상당히 다릅니다.
클러스터 분석에서 사람들이 항상 가지고있는 주요 차이점과 분류 오해는 클러스터 분석 에서 "올바른 솔루션"이 없다는 것입니다 . 그것은 지식 발견 방법이며 실제로 새로운 것을 찾는 것을 의미합니다 ! 이것은 평가를 매우 어렵게 만듭니다. 종종 알려진 분류를 참조로 사용하여 평가되지만 항상 적절한 것은 아닙니다. 분류가 데이터에있는 내용을 반영 할 수도 있고 반영하지 않을 수도 있습니다.
예를 들어 보겠습니다. 성별 데이터를 포함한 대규모 고객 데이터 세트가 있습니다. 이 데이터 세트를 "남성"및 "여성"으로 분할하는 방법은 기존 클래스와 비교할 때 최적입니다. 새로운 사용자의 경우 이제 성별을 예측할 수 있으므로 "예측"방식으로 이것이 좋다고 생각합니다. 데이터에서 새로운 구조 를 발견하고 싶었 기 때문에 "지식 발견"방식으로 이것은 실제로 좋지 않습니다 . 예를 들어 데이터를 노인과 어린이로 분할하는 방법 은 남성 / 여성 클래스와 관련하여 얻을 수있는 것만 큼 점수 가 더 낮습니다 . 그러나 이는 훌륭한 클러스터링 결과가 될 것입니다 (연령이 지정되지 않은 경우).
이제 EM으로 돌아갑니다. 기본적으로 데이터가 다중 다변량 정규 분포로 구성되어 있다고 가정합니다 (이는 특히 클러스터 수를 수정할 때 매우 강력한 가정입니다!). 그런 다음 모델과 모델에 대한 객체 할당을 교대로 개선하여 이에 대한 로컬 최적 모델을 찾으려고합니다 .
분류 컨텍스트에서 최상의 결과를 얻으려면 더 큰 클러스터 수를 선택하십시오. 클래스 수보다 하거나 클러스터링을 단일 클래스에만 적용하십시오 (클래스 내에 구조가 있는지 확인하기 위해!).
"자동차", "자전거"및 "트럭"을 구분하도록 분류기를 훈련시키고 싶다고 가정 해 보겠습니다. 데이터가 정확히 3 개의 정규 분포로 구성되어 있다고 가정하는 데는 거의 사용되지 않습니다. 그러나 두 가지 이상의 자동차 유형 (및 트럭 및 자전거) 이 있다고 가정 할 수 있습니다 . 따라서이 세 클래스에 대한 분류기를 훈련하는 대신 자동차, 트럭 및 자전거를 각각 10 개의 클러스터 (또는 자동차 10 대, 트럭 3 대, 자전거 3 대 등)로 묶은 다음 분류기를 훈련하여이 30 개의 클래스를 구분 한 다음 클래스 결과를 원래 클래스로 다시 병합하십시오. 예를 들어 Trikes와 같이 특히 분류하기 어려운 클러스터가 하나 있음을 발견 할 수도 있습니다. 그들은 다소 자동차이고 다소 자전거입니다. 또는 배달 트럭은 트럭 이라기보다 대형차와 비슷합니다.
다른 답변이 좋으면 다른 관점을 제공하고 질문의 직관적 인 부분을 다루려고 노력할 것입니다.
EM (Expectation-Maximization) 알고리즘 은 이중성을 사용하는 반복 알고리즘 클래스의 변형입니다.
발췌 (강조 내) :
수학에서 이원성은 일반적으로 말해서 개념, 정리 또는 수학적 구조를 일대일 방식으로 다른 개념, 정리 또는 구조로, 종종 (항상은 아님) 혁명 연산을 통해 변환합니다. A는 B이고 B의 쌍대는 A입니다. 이러한 인볼 루션에는 때때로 고정 된 점이 있으므로 A의 쌍대는 A 자체입니다.
일반적으로 객체 A 의 이중 B는 대칭 또는 호환성 을 유지하는 방식으로 A와 관련 됩니다 . 예 : AB = const
이중성을 사용하는 반복 알고리즘의 예 (이전 의미에서)는 다음과 같습니다.
비슷한 방식으로 EM 알고리즘은 두 개의 이중 최대화 단계로 볼 수도 있습니다 .
.. [EM]은 매개 변수와 관측되지 않은 변수에 대한 분포의 결합 함수를 최대화하는 것으로 간주됩니다. E- 단계는 관측되지 않은 변수에 대한 분포와 관련하여이 함수를 최대화합니다. 매개 변수에 대한 M 단계 ..
이중성을 사용하는 반복 알고리즘에는 평형 (또는 고정) 수렴 점에 대한 명시 적 (또는 암시 적) 가정이 있습니다 (EM의 경우 이는 Jensen의 부등식을 사용하여 증명 됨).
따라서 이러한 알고리즘의 개요는 다음과 같습니다.
참고 하는 전역 최적 그러한 알고리즘의 수렴은 그것을 구성하는 볼 때 모두 감각 최상의 (즉, 양의 X의 도메인 / 파라미터와 Y 도메인 / 매개 변수). 그러나 알고리즘은 전역 최적이 아닌 로컬 최적을 찾을 수 있습니다 .
이것은 알고리즘의 개요에 대한 직관적 인 설명이라고 말하고 싶습니다.
통계적 인수 및 응용 프로그램의 경우 다른 답변이 좋은 설명을 제공했습니다 (이 답변의 참조도 확인하십시오)
받아 들여지는 대답은 Chuong EM Paper를 참조하며 EM을 설명하는 적절한 작업을 수행합니다. 논문을 더 자세히 설명 하는 유튜브 영상 도 있습니다 .
요약하면 다음과 같은 시나리오가 있습니다.
1st: {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd: {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd: {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th: {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th: {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails
Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.
We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.
첫 번째 시도의 질문의 경우 앞면 비율이 B의 편향과 매우 잘 일치하기 때문에 직관적으로 B가 생성했다고 생각할 수 있지만 그 값은 추측 일 뿐이므로 확신 할 수 없습니다.
이를 염두에두고 다음과 같은 EM 솔루션을 생각하고 싶습니다.
이것은 지나치게 단순화 (또는 일부 수준에서는 근본적으로 잘못된 것) 일 수 있지만 직관적 인 수준에서 도움이되기를 바랍니다.
EM은 잠재 변수 Z가있는 모델 Q의 가능성을 최대화하는 데 사용됩니다.
반복적 인 최적화입니다.
theta <- initial guess for hidden parameters
while not converged:
#e-step
Q(theta'|theta) = E[log L(theta|Z)]
#m-step
theta <- argmax_theta' Q(theta'|theta)
e-step : Z의 현재 추정치가 주어지면 예상 로그 가능도 함수를 계산합니다.
m-step :이 Q를 최대화하는 theta 찾기
GMM 예 :
e-step : 현재 gmm-parameter 추정치에 따라 각 데이터 포인트에 대한 라벨 할당 추정
m-step : 새 레이블 할당이 주어지면 새 세타 최대화
K- 평균은 EM 알고리즘이며 K- 평균에 대한 설명 애니메이션이 많이 있습니다.
Zhubarb의 답변에서 인용 한 Do와 Batzoglou의 동일한 기사를 사용하여 Java 에서 해당 문제에 대해 EM을 구현했습니다. . 그의 답변에 대한 의견은 알고리즘이 로컬 최적에 고정되어 있음을 보여 주며 매개 변수 thetaA와 thetaB가 동일한 경우 내 구현에서도 발생합니다.
아래는 매개 변수의 수렴을 보여주는 내 코드의 표준 출력입니다.
thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960
아래는 문제를 해결하기위한 EM의 Java 구현입니다 (Do and Batzoglou, 2008). 구현의 핵심 부분은 매개 변수가 수렴 될 때까지 EM을 실행하는 루프입니다.
private Parameters _parameters;
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
아래는 전체 코드입니다.
import java.util.*;
/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
double _thetaA = 0.0; // Probability of heads for coin A.
double _thetaB = 0.0; // Probability of heads for coin B.
double _delta = 0.00001;
public Parameters(double thetaA, double thetaB)
{
_thetaA = thetaA;
_thetaB = thetaB;
}
/*************************************************************************
Returns true if this parameter is close enough to another parameter
(typically the estimated parameter coming from the maximization step).
*************************************************************************/
public boolean converged(Parameters other)
{
if (Math.abs(_thetaA - other._thetaA) < _delta &&
Math.abs(_thetaB - other._thetaB) < _delta)
{
return true;
}
return false;
}
public double getThetaA()
{
return _thetaA;
}
public double getThetaB()
{
return _thetaB;
}
public String toString()
{
return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
}
}
/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
double _numHeads = 0;
double _numTails = 0;
public Observation(String s)
{
for (int i = 0; i < s.length(); i++)
{
char c = s.charAt(i);
if (c == 'H')
{
_numHeads++;
}
else if (c == 'T')
{
_numTails++;
}
else
{
throw new RuntimeException("Unknown character: " + c);
}
}
}
public Observation(double numHeads, double numTails)
{
_numHeads = numHeads;
_numTails = numTails;
}
public double getNumHeads()
{
return _numHeads;
}
public double getNumTails()
{
return _numTails;
}
public String toString()
{
return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
}
}
/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
// Current estimated parameters.
private Parameters _parameters;
// Observations from the trials. These observations are set once.
private final List<Observation> _observations;
// Estimated observations per coin. These observations are the output
// of the expectation step.
private List<Observation> _expectedObservationsForCoinA;
private List<Observation> _expectedObservationsForCoinB;
private static java.io.PrintStream o = System.out;
/*************************************************************************
Principal constructor.
@param observations The observations from the trial.
@param parameters The initial guessed parameters.
*************************************************************************/
public EM(List<Observation> observations, Parameters parameters)
{
_observations = observations;
_parameters = parameters;
}
/*************************************************************************
Run EM until parameters converge.
*************************************************************************/
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
o.printf("%s\n", estimatedParameters);
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
/*************************************************************************
Given the observations and current estimated parameters, compute new
estimated completions (distribution over the classes) and observations.
*************************************************************************/
private void expectation()
{
_expectedObservationsForCoinA = new ArrayList<Observation>();
_expectedObservationsForCoinB = new ArrayList<Observation>();
for (Observation observation : _observations)
{
int numHeads = (int)observation.getNumHeads();
int numTails = (int)observation.getNumTails();
double probabilityOfObservationForCoinA=
binomialProbability(10, numHeads, _parameters.getThetaA());
double probabilityOfObservationForCoinB=
binomialProbability(10, numHeads, _parameters.getThetaB());
double normalizer = probabilityOfObservationForCoinA +
probabilityOfObservationForCoinB;
// Compute the completions for coin A and B (i.e. the probability
// distribution of the two classes, summed to 1.0).
double completionCoinA = probabilityOfObservationForCoinA /
normalizer;
double completionCoinB = probabilityOfObservationForCoinB /
normalizer;
// Compute new expected observations for the two coins.
Observation expectedObservationForCoinA =
new Observation(numHeads * completionCoinA,
numTails * completionCoinA);
Observation expectedObservationForCoinB =
new Observation(numHeads * completionCoinB,
numTails * completionCoinB);
_expectedObservationsForCoinA.add(expectedObservationForCoinA);
_expectedObservationsForCoinB.add(expectedObservationForCoinB);
}
}
/*************************************************************************
Given new estimated observations, compute new estimated parameters.
*************************************************************************/
private Parameters maximization()
{
double sumCoinAHeads = 0.0;
double sumCoinATails = 0.0;
double sumCoinBHeads = 0.0;
double sumCoinBTails = 0.0;
for (Observation observation : _expectedObservationsForCoinA)
{
sumCoinAHeads += observation.getNumHeads();
sumCoinATails += observation.getNumTails();
}
for (Observation observation : _expectedObservationsForCoinB)
{
sumCoinBHeads += observation.getNumHeads();
sumCoinBTails += observation.getNumTails();
}
return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));
//o.printf("parameters: %s\n", _parameters);
}
/*************************************************************************
Since the coin-toss experiment posed in this article is a Bernoulli trial,
use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
*************************************************************************/
private static double binomialProbability(int n, int k, double p)
{
double q = 1.0 - p;
return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
}
private static long nChooseK(int n, int k)
{
long numerator = 1;
for (int i = 0; i < k; i++)
{
numerator = numerator * n;
n--;
}
long denominator = factorial(k);
return (long)(numerator / denominator);
}
private static long factorial(int n)
{
long result = 1;
for (; n >0; n--)
{
result = result * n;
}
return result;
}
/*************************************************************************
Entry point into the program.
*************************************************************************/
public static void main(String argv[])
{
// Create the observations and initial parameter guess
// from the (Do and Batzoglou, 2008) article.
List<Observation> observations = new ArrayList<Observation>();
observations.add(new Observation("HTTTHHTHTH"));
observations.add(new Observation("HHHHTHHHHH"));
observations.add(new Observation("HTHHHHHTHH"));
observations.add(new Observation("HTHTTTHHTT"));
observations.add(new Observation("THHHTHHHTH"));
Parameters initialParameters = new Parameters(0.6, 0.5);
EM em = new EM(observations, initialParameters);
Parameters finalParameters = em.run();
o.printf("Final result:\n%s\n", finalParameters);
}
}