다음은 평균 및 표준 편차를 추정하는 데 사용되는 예상 최대화 (EM)의 예입니다. 코드는 파이썬으로되어 있지만, 언어에 익숙하지 않아도 쉽게 따라갈 수 있어야합니다.
EM의 동기
아래 표시된 빨간색과 파란색 점은 각각 특정 평균과 표준 편차가있는 두 개의 서로 다른 정규 분포에서 가져옵니다.
적색 분포에 대한 "true"평균 및 표준 편차 매개 변수의 합리적인 근사값을 계산하기 위해 적색 점을 매우 쉽게보고 각 위치를 기록한 다음 익숙한 공식을 사용합니다 (청색 그룹과 유사). .
이제 두 그룹의 점이 있다는 것을 알고 있지만 어떤 점이 어떤 그룹에 속하는지 알 수 없습니다. 즉, 색상이 숨겨집니다.
포인트를 두 그룹으로 나누는 방법은 전혀 분명하지 않습니다. 이제 빨간색 분포 또는 파란색 분포의 모수에 대한 위치 및 계산 추정값 만 볼 수 없습니다.
여기서 EM을 사용하여 문제를 해결할 수 있습니다.
EM을 사용하여 모수 추정
위의 포인트를 생성하는 데 사용되는 코드는 다음과 같습니다. 점이 도출 된 정규 분포의 실제 평균과 표준 편차를 볼 수 있습니다. 변수 red
와 blue
각각 빨간색과 파란색 그룹의 각 지점의 위치를 잡아 :
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue)))
각 점의 색상을 볼 수 있다면 라이브러리 함수를 사용하여 평균과 표준 편차를 복구하려고 시도합니다.
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
그러나 색상이 숨겨져 있기 때문에 EM 프로세스를 시작합니다 ...
먼저 각 그룹의 매개 변수 값을 추측합니다 ( 1 단계 ). 이 추측은 좋을 필요는 없습니다.
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
꽤 나쁜 추측-수단은 점 그룹의 "중간"에서 멀리 떨어져있는 것처럼 보입니다.
EM을 계속 유지하고 이러한 추측을 개선하기 위해 평균 및 표준 편차에 대한 이러한 추측 아래에 각 데이터 포인트 (비밀 색상에 관계없이)가 나타날 가능성을 계산합니다 ( 2 단계 ).
변수 both_colours
는 각 데이터 포인트를 보유합니다. 이 함수 stats.norm
는 주어진 모수를 사용하여 정규 분포에서 점의 확률을 계산합니다.
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
예를 들어, 현재 추측에 따르면 1.761의 데이터 포인트는 파란색 (0.00003)보다 빨간색 (0.189) 일 가능성이 훨씬 높습니다.
이 두 가능성 값을 가중치 ( 3 단계 )로 변환하여 다음과 같이 1의 합을 구할 수 있습니다 .
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
현재 추정치와 새로 계산 된 가중치를 사용하여 모수에 대한 새 추정치를 더 잘 계산할 수 있습니다 ( 4 단계 ). 평균에 대한 함수와 표준 편차에 대한 함수가 필요합니다.
def estimate_mean(data, weight):
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
이들은 데이터의 평균 및 표준 편차에 대한 일반적인 기능과 매우 유사합니다. 차이점은 weight
각 데이터 포인트에 가중치를 할당하는 매개 변수 사용입니다 .
이 가중치는 EM의 핵심입니다. 데이터 포인트에서 색상의 가중치가 클수록 데이터 포인트는 해당 색상의 매개 변수에 대한 다음 추정에 더 많은 영향을 미칩니다. 궁극적으로 각 매개 변수를 올바른 방향으로 당기는 효과가 있습니다.
새로운 추측은 다음 기능으로 계산됩니다.
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
그런 다음 2 단계 이후의 새로운 추측으로 EM 프로세스가 반복됩니다. 주어진 반복 횟수 (예 : 20) 또는 매개 변수가 수렴 될 때까지 단계를 반복 할 수 있습니다.
5 번의 반복 후, 초기 잘못된 추측이 나아지기 시작합니다.
20 회 반복 후 EM 프로세스는 다소 수렴되었습니다.
비교를 위해 다음은 색상 정보가 숨겨지지 않은 경우 계산 된 값과 비교 한 EM 프로세스의 결과입니다.
| EM guess | Actual
----------+----------+--------
Red mean | 2.910 | 2.802
Red std | 0.854 | 0.871
Blue mean | 6.838 | 6.932
Blue std | 2.227 | 2.195
참고 :이 답변은 여기 에서 스택 오버플로에 대한 답변에서 수정되었습니다 .