EM 算法的 Python 實現

from scipy import stats
import numpy as np

# Five experiments of ten tosses each, spelled as H/T strings for
# readability and converted to rows of 0/1 integers (1 = heads, 0 = tails).
observations = np.array([
    [1 if toss == "H" else 0 for toss in sequence]
    for sequence in (
        "HTTTHHTHTH",
        "HHHHTHHHHH",
        "HTHHHHHTHH",
        "HTHTTTHHTT",
        "THHHTHHHTH",
    )
])
def em(observations, prior, tol=1e-6, iterations=10):
    """Repeat ``em_single`` until the estimates converge or the budget runs out.

    Arguments
    ---------
    observations : passed unchanged to ``em_single`` on every step.
    prior : [theta_A, theta_B], the starting estimates.
    tol : stop as soon as the largest change of any estimate in one step
        is below ``tol``.
    iterations : maximum number of EM steps; with 0 (or fewer) the prior
        is returned as-is.

    Returns
    -------
    [estimates, iteration] : the most recent [theta_A, theta_B] and the
    number of steps whose change was >= ``tol`` (equal to ``iterations``
    when the budget was exhausted without converging).
    """
    # Bind before the loop so a zero/negative budget returns the prior
    # rather than raising UnboundLocalError at the return statement.
    new_prior = prior
    iteration = 0
    while iteration < iterations:
        print("iter %s" % iteration)
        new_prior = em_single(prior, observations)
        # Converge on the largest change over *all* estimates, not theta_A
        # alone, so the loop cannot stop while theta_B is still moving.
        delta_change = max(abs(old - new) for old, new in zip(prior, new_prior))
        if delta_change < tol:
            break
        prior = new_prior
        iteration += 1
        print('\n')
    return [new_prior, iteration]

def em_single(priors, observations, verbose=True):
    """Perform a single EM update for the two-coin model.

    Each row of ``observations`` is modelled as one experiment whose tosses
    all came from the same coin, A or B; the row's entries are summed as its
    head count (1 = heads, 0 = tails).

    E step: weight every row by the relative binomial likelihood of its head
    count under ``theta_A`` versus ``theta_B``.
    M step: re-estimate each coin's head probability as weighted heads
    divided by weighted tosses.

    Arguments
    ---------
    priors : [theta_A, theta_B], current head probabilities.
    observations : m x n matrix, or any sequence of 0/1 sequences
        (rows may differ in length).
    verbose : when true (the default), print each row's weights and the
        running weighted counts.

    Returns
    -------
    new_priors : [new_theta_A, new_theta_B]. A coin that receives no weight
        at all (e.g. ``observations`` is empty) keeps its current value
        instead of failing on a 0/0 division.
    """
    counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
    theta_A = priors[0]
    theta_B = priors[1]

    # E step
    for observation in observations:
        len_observation = len(observation)
        num_heads = np.sum(observation)  # works for arrays and plain lists
        num_tails = len_observation - num_heads
        # Compare the likelihoods in log space, shifted by the larger one:
        # for long rows both raw pmfs can underflow to 0.0, and 0/0 gave nan.
        log_A = stats.binom.logpmf(num_heads, len_observation, theta_A)
        log_B = stats.binom.logpmf(num_heads, len_observation, theta_B)
        shift = max(log_A, log_B)
        contribution_A = np.exp(log_A - shift)
        contribution_B = np.exp(log_B - shift)
        total = contribution_A + contribution_B
        weight_A = contribution_A / total
        weight_B = contribution_B / total
        if verbose:
            print("A :%0.2f,\tB :%0.2f" % (weight_A, weight_B))

        counts['A']['H'] += weight_A * num_heads
        counts['A']['T'] += weight_A * num_tails
        counts['B']['H'] += weight_B * num_heads
        counts['B']['T'] += weight_B * num_tails
        if verbose:
            print(counts)

    # M step: weighted heads / weighted tosses per coin. A coin with zero
    # total weight has nothing to learn from, so it keeps its prior value.
    new_priors = []
    for coin, previous in (('A', theta_A), ('B', theta_B)):
        weighted_tosses = counts[coin]['H'] + counts[coin]['T']
        if weighted_tosses > 0:
            new_priors.append(counts[coin]['H'] / weighted_tosses)
        else:
            new_priors.append(previous)
    return new_priors


# Run EM from the starting guesses theta_A = 0.6, theta_B = 0.5 and print
# the returned [estimates, iteration] pair.
print(em(observations, [0.6, 0.5]))


# Binomial(n=10, p=0.5) probability of each possible head count
# k = 0..10 for ten tosses of a fair coin.
n = 10
k = np.arange(n + 1)
pcoin = stats.binom.pmf(k, n, 0.5)
print(pcoin)
相關文章
相關標籤/搜索