为了账号安全,请及时绑定邮箱和手机立即绑定

关于EM算法原理的分析与理解(Python实现)

标签:
Python

本文的计算公式出自《统计学习方法》,写这篇文章主要是想把自己对这个算法的思路理清,并把自己的理解记录下来,同时分享出来,希望能够帮助到打算入门机器学习的人。

定义:
概率模型有时既含有观测变量,又含有隐变量或潜在变量。如果概率模型的变量都是观测变量,那么给定数据,可以直接用极大似然估计法,或贝叶斯估计法估计模型参数,但是,当模型含有隐变量时,就不能简单地使用这些估计方法了。EM算法就是含有隐变量的概率模型参数的极大似然估计法,或极大后验概率估计法。

算法原理:
一般用Y表示观测随机变量的数据,Z表示隐随机变量的数据,theta是估计的模型参数。


EM算法通过迭代求

webp

L_theta1.PNG


的极大似然估计,每次迭代包含两步:E步,求期望;M步,求极大化。
将极大似然函数展开得到:

webp

9.PNG


该极大似然函数的展开式的意思是:在模型参数中隐变量Z的条件概率与在隐变量下观测变量的条件概率乘积是一次操作中观测变量的概率,而求和则是所有操作中观测变量的概率。
该极大似然函数是无法求解的,事实上EM算法是通过不断的迭代近似极大化极大似然函数的。
为了新模型参数估计值能使极大似然函数极大化,则考虑第i+1次迭代与第i次迭代的差,通过不断求解差值的极大化,从而求得第i+1次迭代的极大化。


webp

6.PNG


webp

7.PNG


webp

8.PNG


看到这里就可以知道,通过取全数据对数似然的期望最大化,就可以最大化式子(9.12)

下面给出EM算法的一般步骤:


webp

10.PNG

Python实现:
以《统计学习方法》书上的三硬币例子为例,过程就不描述了,自己参考书上。主要是通过这个例子讲解一下EM算法,一步一步的说明EM算法的过程。
三硬币模型极大似然函数:


webp

11.PNG


webp

12.PNG


webp

13.PNG


(9.5)式子就是在第I次迭代中,根据每个观测变量分别求隐变量的条件概率(硬币A出现正面,选择硬币B得到观测数据的概率)。对应EM算法的第2步。


webp

14.PNG


(9.6)~(9.7)是如何得到的?我会在下一篇《EM算法在高斯混合模型学习中的应用》详细推导是如何得来的。目前我们只需要理解这几个公式的意思即可。
(9.6)公式:对第i次迭代中求出的所有操作中硬币A出现正面的概率期望。
(9.7)公式:对第i次迭代中硬币A出现正面,然后选择硬币B得到正面占硬币B所有操作中的概率。
(9.8)公式:对第i次迭代中硬币A出现反面,然后选择硬币C得到正面占硬币C所有操作中的概率。
不断的迭代直到满足条件则停止迭代,算法结束。
需要注意的是,EM算法对初始值是敏感的,不同的初始值会得到的参数不同。

完整代码如下:

from numpy import *import numpy as npimport matplotlib.pyplot as pltimport randomdef create_sample_data(m, n):
    mat_y = mat(zeros((m, n)))    for i in range(m):        for j in range(n):            #通过产生随机数,每一行表示一次实验结果
            mat_y[i, j] = random.randint(0, 1)    return mat_t#EM算法def em(arr_y, theta, tol, iterator_num):
    PI = 0
    P = 0
    Q = 0
    m,n = shape(arr_y)
    mat_y = arr_y.getA()    for i in range(iterator_num):
        miu = []
        PI = copy(theta[0])
        P = copy(theta[1])
        Q = copy(theta[2])        for j in range(m):
            miu_value = (PI * (P**mat_y[j]) * ((1 - P)**(1 - mat_y[j]))) / \
                (PI * (P**mat_y[j]) * ((1 - P)**(1 - mat_y[j])) + (1 - PI) * (Q**mat_y[j]) * ((1 - Q)**(1 - mat_y[j])))
            miu.append(miu_value)
        
        sum1 = 0.0
        for j in range(m):
            sum1 += miu[j]
        theta[0] = sum1 / m

        sum1 = 0.0
        sum2 = 0.0
        for j in range(m):
            sum1 += miu[j] * mat_y[j]
            sum2 += miu[j]
        theta[1] = sum1 / sum2

        sum1 = 0.0
        sum2 = 0.0
        for j in range(m):
            sum1 += (1 - miu[j]) * mat_y[j]
            sum2 += (1 - miu[j])
        theta[2] = sum1 / sum2

        print("-------------------")
        print(theta)        if(abs(theta[0] - PI) <= tol and abs(theta[1] - P) <= tol \            and abs(theta[2] - Q) <= tol):
            print("break")            break
    return PI,P,Qdef main():
    #mat_y = create_sample_data(100, 1)
    mat_y = mat(zeros((10, 1)))
    mat_y[0,0] = 1
    mat_y[1,0] = 1
    mat_y[2,0] = 0
    mat_y[3,0] = 1
    mat_y[4,0] = 0
    mat_y[5,0] = 0
    mat_y[6,0] = 1
    mat_y[7,0] = 0
    mat_y[8,0] = 1
    mat_y[9,0] = 1
    theta = [0.5, 0.6, 0.5]
    print(mat_y)
    PI,P,Q = em(mat_y, theta, 0.001, 100)
    print(PI, P, Q)

main()

输入数据(与《统计学习方法》的输入数据一样):

1
1
0
1
0
0
1
0
1
1

本文的输出结果:


webp

15.PNG


《统计学习方法》上的输出结果:


webp

16.PNG


两者对比结果是一样的。
可以改变模型参数再实验一下(依旧取书上的参数,看看是否输出依旧一样的)

def main():
    #mat_y = create_sample_data(100, 1)
    mat_y = mat(zeros((10, 1)))
    mat_y[0,0] = 1
    mat_y[1,0] = 1
    mat_y[2,0] = 0
    mat_y[3,0] = 1
    mat_y[4,0] = 0
    mat_y[5,0] = 0
    mat_y[6,0] = 1
    mat_y[7,0] = 0
    mat_y[8,0] = 1
    mat_y[9,0] = 1
    theta = [0.4, 0.6, 0.7]
    print(mat_y)
    PI,P,Q = em(mat_y, theta, 0.001, 100)
    print(PI, P, Q)

输出结果:


webp

17.PNG


书上的输出结果:


webp

18.PNG


经过对比还是完全一样的,算法正确。



作者:幸福洋溢的季节
链接:https://www.jianshu.com/p/154ee3354b59


点击查看更多内容
1人点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消