02
18
03
40
为了账号安全,请及时绑定邮箱和手机立即绑定

经典永不过时:详解高斯混合模型及其变体!

本文深入讲解了高斯混合模型(GMM),并展示了它在聚类任务中的优势,比K-means更为出色。

图1:高斯混合模型的示意图 [AI生成的图片]

开始

在深度学习(DL)和变压器成为焦点的时代,很容易忘记经典算法如K-meansDBSCAN,和GMM。但是这里有一个大胆的见解:在处理现实世界的聚类和异常检测问题时,这些统计学工具仍然不可或缺,仍然具有惊人的持久力。

考虑日常的聚类难题:客户分类、社交网络分析或图像分割。K-means 通过其简单的基于质心的聚类解决了这些问题几十年。当数据呈现不规则形状时,DBSCAN 则使用其基于密度的聚类识别出让 K-means 无能为力的非凸簇。

但现实世界的数据很少形成整洁独立的气泡。这时,高斯混合模型及其变体就登场了!高斯混合模型承认聚类分配的基本不确定性。通过建模正常行为的概率分布,它们可以识别不符合预期模式的观测值,而无需事先标记的样本。所以在追逐最新的神经网络架构为你的聚类或分割任务之前,不妨考虑一下统计学中的经典模型,比如高斯混合模型。

许多人可以自信地谈论K均值的工作原理,但我敢用我的钱打赌,当谈到高斯混合模型时,没有多少人会那么自信。本文将用浅显易懂的方式来(我会尽力!)讨论高斯混合模型及其变体背后的数学,并说明它为什么值得你在下次聚类任务中更多关注。

记住,经典永远不会过时。它们等待那个完美的时机,重新夺回在那些过时疲惫的趋势之前的焦点。

本是什么样的高斯混合模型呢?

高斯混合是由多个高斯分布组成的,每个分布用 k ∈ {1,…, K} 来标识,其中 K 表示我们数据集中的簇的数量,您需要提前知道这个数量。混合中的每个高斯分布(标记为 k)包含以下参数:

  • 一个表示中心位置的均值 μ
  • 一个协方差矩阵 Σ。在多变量情况下,这类似于描述椭球的维度。
  • 一个定义高斯函数的权重比例 π,其中 π ≥ 0 且所有 kπ 之和等于 1。

数学上,它可以写成:

其中 p(x) 表示在点 x 的概率密度,而 N(x |μ , Σ) 表示均值为 μ 且协方差矩阵为 Σ 的多元高斯分布的概率密度。

方程式看起来都挺吓人,但别担心。先别急,我们来看看这个多元高斯函数 Nx | μ , Σ)以及它们各自的维度。假设数据集包含 N = 500 个三维数据点(D=3),那么数据集 x 是一个 500 × 3 的矩阵,μ 是一个 1 × 3 的向量,而 Σ 是一个 3 × 3 的矩阵。该函数的输出会是一个 500 × 1 的向量。

当我们使用GMM时,会遇到一个循环难题。

  • 要确定每个数据点属于哪个高斯簇,我们需要知道每个高斯簇的参数(均值、协方差、权重)。
  • 但是要准确估计这些参数,我们需要知道哪些数据点属于每个高斯簇。

为了打破这个循环,这里就是期望最大化(EM)算法登场的时候了,它会进行有根据的假设,然后迭代地调整这些假设。

使用EM算法进行参数估计(期望最大化算法)

EM算法通过以下步骤来确定GMM参数的最优值。

  • 第一步:初始化 — 随机初始化每个高斯聚类的参数(μ, Σ, π)。
  • 第二步:期望(E步骤) — 计算每个数据点对每个高斯聚类的归属程度,然后为每个数据点计算一组“责任”,这些“责任”表示该数据点来自每个聚类的概率。
  • 第三步:最大化(M步骤) — 使用数据集中的所有实例来更新每个高斯聚类,每个实例的权重是它属于该聚类的估计概率(即“责任”)。具体来说,新均值是所有数据点的加权平均,其中权重是“责任”。新协方差是围绕每个新均值的加权偏差。最后,新混合权重是每个组件收到的总责任的比例。需要注意的是,每个聚类的更新主要取决于其最负责的数据点
  • 第四步:重复 — 返回第二步,使用这些更新的参数,并继续直到变化变得最小(收敛)。

通常,人们在M步时会感到困惑,因为很多术语会让人感到困惑。我将使用之前的例子(500个3维数据点和3个高斯聚类)来更具体地解释它。

为了更新,我们采用加权平均的方式,其中每个点的贡献根据其“责任”值来决定,这与对应的高斯簇相关联。从数学上讲,对于第k个高斯簇,

新的权重计算公式是:((\sum[responsibility_{ik} \times point_i]) / (\sum \text{簇 } k \text{ 的权重})).

为了更新协方差,我们使用类似的加权方法。对于每个点,计算它与这个新均值之间的距离,然后将这个_deviation_与其_transpose_相乘以得到一个矩阵。然后,根据该点的r值对该矩阵进行加权,最后将所有点的加权矩阵相加。除以该簇的总责任,得到最终结果。

为了更新权重,我们只需将集群的“贡献”加起来,然后除以总数据点数。这样,对于第$k$个集群,我们可以通过这种方式来计算。

比如说对于高斯分布簇2。

  • 责任的总分数是200(满分500分)
  • 加权得分分别是400、600和800
  • 加权平方偏差的总和形成了一个协方差矩阵

那么:

  • 第 2 个聚类的每个聚类中心 = (400, 600, 800)/200 = (2, 3, 4);
  • 每个聚类的新混合比例 = 200/500 = 0.4;
  • 每个聚类的新协方差矩阵 = (加权偏差的总和)/200。

现在应该容易理解多了吧!

使用GMM进行聚类

现在我已经估算出了每个高斯聚类的位置、大小、形状、方向和相对权重,GMM 可以很容易地把数据点归到最可能的聚类中(硬聚类方法),或者估算它属于某个特定聚类的概率(软聚类方法)。

GMM在Python中的实现非常简单直接,这要感谢经典的_scikit-learn_库的支持。这里我提供了一个使用内置GMM进行聚类任务的示例代码,代码使用的数据点是随机生成的,包含3个聚类。如图2。

图2:用于聚类的数据点(data points)(图源:作者)

下面给出的是用 Python 写的代码:

    sklearn.mixture导入GaussianMixture模型  

    gm = GaussianMixture(n_components=3, n_init=10, random_state=42)  
    gm.fit(X)  

    # 查看GM估计的参数(权重、均值和协方差)  
    # 权重、均值和协方差  
    gm.weights_  
    gm.means_  
    gm.covariances_  

    # 对数据点进行预测  

    # 硬聚类  
    gm.predict(X)  

    # 软分配  
    gm.predict_proba(X).round(2)

图3展示了聚类位置、决策的界限以及GMM(混合高斯模型)的密度分布的轮廓(它估计了模型在任何位置的密度特性)。

图3:训练后的GMM模型的聚类的位置、决策界限和密度曲线 [图由作者绘制]

看起来GMM确实找到了一个很好的解决方案!但值得注意的是,实际数据并不总是这么符合高斯分布且维度较低。当问题是高维度和众多聚类时,EM算法可能会难以达到收敛到最优解。为解决这个问题,你可以减少GMM需要学习的参数数量。一种方法是限制聚类形状和方向的变化范围。这可以通过对协方差矩阵进行约束来实现,这可以通过设置covariance_type超参数来完成。

    gm_full = GaussianMixture(n_components=3,   
                              n_init=10,  
                              covariance_type="full", # 默认(default)  
                              random_state=42)
  • "full"(默认):无约束,所有聚类可以采用任意大小的椭球形状 [1]。

  • "spherical":所有聚类必须是球形的,但它们可以有不同的直径(即不同的差异)。

  • "diag":聚类可以采用任意大小的椭球形状,但椭球的轴必须平行于坐标轴(也就是说,协方差矩阵必须是对角的)。

  • "tied":所有聚类的形状必须相同,可以是任意椭球(也就是说,它们共享同一个协方差矩阵)。

为了显示与默认设置的不同,如图4所示,当将covariance_type设置为 "tied" 时,EM算法找到的结果。

图4:使用共享簇的GMM模型对同一任务进行聚类的结果(作者提供)

讨论训练GMM的计算复杂度同样非常重要。这主要取决于数据点的数量 m 、维度 n 、聚类的数量 k ,以及协方差矩阵的四种约束(如上所述)。如果 covariance_type 设置为 "spherical""diag",则计算复杂度为 O(kmn),前提是数据具有聚类结构。如果 covariance_type 设置为 "tied""full",则计算复杂度变为 O(kmn² + kn³),这将无法很好地扩展 [1]。

找到合适的聚类数量

这个例子相当简单,部分原因是生成数据集时已经知道了聚类的数量这一信息。但在没有这个信息的情况下进行训练时,需要一些衡量标准来帮助找到最合适的聚类数量。

对于GMM模型,你可以尝试找到使_理论信息标准_最小化的模型,比如贝叶斯信息准则(BIC)和赤池信息准则(AIC),定义如下。

BIC = log(m) p — 2 log(L)

AIC = 2 p — 2 log(L)

这里,m 表示数据点的数量,p 表示模型要学习的参数数量,而 L 表示模型的似然函数的最大值。用 Python 来计算这些值非常简单。

这里有两个命令:gm.bic(X)  gm.aic(X)。这些可能是程序中的函数调用。gm.bic(X) 可能代表某种计算或处理,而 gm.aic(X) 则可能表示另一种处理方式。

BIC 和 AIC 会惩罚参数更多的模型(例如,更多的聚类),并奖励数据拟合更好的模型。值越低,模型拟合数据越好。在实践中,您可以设定一系列的聚类数量 k,绘制 BIC 或 AIC 对不同 k 的变化曲线,选择 BIC 或 AIC 最低的点。

GMM的各种版本

我发现GMM的一些变体非常实用且方便,经常可以在其经典形式上做出进一步改进。

  • 基于贝叶斯的GMM:可以将不必要的聚类的权重设置为等于或接近零。在实践中,您可以将聚类数量 n_components 设置为一个您有充分理由相信大于实际最优聚类数量的值,然后它将自动为您处理学习过程。
  • 鲁棒GMM:它解决了GMM对异常值过于敏感的问题。它不是最大化标准对数似然,而是使用鲁棒估计器来减少远离聚类中心的点的权重。从而提供了更稳定的结果。
  • 在线/增量GMM:在线/增量GMM解决了标准GMM的计算和内存限制问题。参数会在看到每个新数据点或小批量数据后进行更新,而不是需要整个数据集来进行参数更新。它还包含一个“遗忘机制”,使模型可以忘记较早的数据并更好地适应非平稳分布。
GMM 与 K-means

由于实际数据常常很复杂,GMM 在聚类和分段任务中通常表现更佳。我通常先运行 K-means 作为基准线,然后尝试 GMM 或其变体,看看额外的复杂性是否带来了实际的改进。但让我们将它们并列比较,看看这两种经典算法之间的主要的不同之处。

图5:K-means与GMM在不同方面的对比 [图片由作者提供]

GMM的,优点——实用异常检测工具!

使用高斯混合模型(GMM)来进行异常检测任务很简单:任何位于低密度区域的实例都可以被视为异常。然而,关键是必须确定要使用的密度阈值。例如,我将使用GMM来识别异常网络流量

这项任务包括以下功能:

  • 数据包大小(字节)
  • 包间间隔时间(毫秒)
  • 连接持续时间(秒)
  • 特定协议值
  • 包负载熵
  • TCP窗口大小

原始的数据集看起来是这样的,如下所示。

图6:网络流量数据集的头部信息和值格式 [图由作者绘制]

代码片段将展示如何预处理数据的方式、训练并评估模型,以及与常见异常检测方法的比较。

    import numpy as np  
    import pandas as pd  
    import matplotlib.pyplot as plt  
    from sklearn.mixture import GaussianMixture  
    from sklearn.preprocessing import StandardScaler  
    from sklearn.model_selection import train_test_split  
    from sklearn.metrics import precision_recall_curve, average_precision_score  
    from sklearn.metrics import f1_score  

    # raw_df 已在之前展示过  
    df = raw_df.copy()  

    X = df.drop(columns=['is_anomaly'])  
    y = df['is_anomaly']  

    # 划分数据  
    X_train, X_test, y_train, y_test = train_test_split(  
        X, y, test_size=0.3, random_state=42, stratify=y  
    )  

    # 标准化特征  
    scaler = StandardScaler()  
    X_train_scaled = scaler.fit_transform(X_train)  
    X_test_scaled = scaler.transform(X_test)  

    # 我们将使用仅包含正常流量的数据进行训练  
    # 这是异常检测中常见的方法  
    X_train_normal = X_train_scaled[y_train == 0]  

    # 尝试不同的组件数量以找到最佳拟合效果  
    n_components_range = range(1, 10)  
    bic_scores = []  
    aic_scores = []  

    for n_components in n_components_range:  
        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=42)  
        gmm.fit(X_train_normal)  
        bic_scores.append(gmm.bic(X_train_normal))  
        aic_scores.append(gmm.aic(X_train_normal))  

    # 根据 BIC 选择最优组件数  
    optimal_components = n_components_range[np.argmin(bic_scores)]  
    print(f"根据 BIC 选择的最优组件数: {optimal_components}")  

    # 训练最终模型  
    gmm = GaussianMixture(n_components=optimal_components, covariance_type='full', random_state=42)  
    gmm.fit(X_train_normal)  

    # 计算负对数概率(分数越高表示越异常)  
    # gmm_train_scores 对评估中的阈值百分位数非常重要  
    gmm_train_scores = -gmm.score_samples(X_train_scaled)  
    gmm_test_scores = -gmm.score_samples(X_test_scaled)  

    def evaluate_model(y_true, anomaly_scores, threshold_percentile=3):  
        """  
        评估模型性能  

        参数:  
        y_true: 真实标签(0 表示正常,1 表示异常)  
        anomaly_scores: 分数越高表示越异常  
        阈值百分位设定: 阈值选择的百分位数  

        返回:  
        性能指标字典  
        """  
        # 基于训练分数计算阈值  
        threshold = np.percentile(anomaly_scores, 100 - threshold_percentile)  

        # 预测异常  
        y_pred = (anomaly_scores > threshold).astype(int)  

        # 计算各个评估指标  
        f1 = f1_score(y_true, y_pred)  
        precision, recall, _ = precision_recall_curve(y_true, anomaly_scores)  
        avg_precision = average_precision_score(y_true, anomaly_scores)  

        return {  
            'f1_score': f1,  
            'avg_precision': avg_precision,  
            'precision_curve': precision,  
            'recall_curve': recall,  
            'threshold': threshold,  
            'y_pred': y_pred  
        }  

    # 计算各项指标  
    gmm_results = evaluate_model(y_test, gmm_test_scores)

让我们加入几个常见的异常检测方法,即孤立森林一类支持向量机,和局部离群点因子(LOF),并检查它们的性能。由于不规则的交通模式是罕见的,所以我将使用PR-AUC这一指标作为模型有效性的评估指标。结果如图7所示,其中结果越接近1,模型越准确。

图7:使用PR-AUC(精确率-召回率曲线下的面积)指标对GMM、孤立森林、单类SVM和LOF进行网络流量检测任务的对比分析。[作者制作]

结果显示,GMM在识别不规则的网络流量方面很强大,并且优于其他常用的方法!GMM特别适合用于异常检测任务,特别是当正常行为包括多种不同的模式,或者你需要概率异常分数的时候。

实际案例通常比我在文中展示的步骤更复杂,但希望这篇博客文章能为你理解GMM的工作原理以及如何将其应用于你的聚类或异常检测任务提供一个坚实的基础。

参考资料

[Aurelien Geron. 手动手学机器学习:基于scikit-learn、Keras和TensorFlow. O'Reilly, 2023,]

[2] 贝叶斯信息标准,维基百科: https://en.wikipedia.org/wiki/Bayesian_information_criterion

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

0 评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号

举报

0/150
提交
取消