- 2014 年 7 月 7 日
- 瓦西利斯·弗里尼奥提斯(Vasilis Vryniotis)
- 。 1条评论
在之前的文章中,我们详细讨论了 Dirichlet过程混合模型 以及如何将它们用于聚类分析。 在本文中,我们将介绍Java的实现 两种不同的DPMM型号:可用于对高斯数据进行聚类的Dirichlet多元正态混合模型和用于对文档进行聚类的Dirichlet-多项式混合模型。 Java代码在GPL v3许可下是开源的,可以从以下位置免费下载 Github上.
更新:Datumbox机器学习框架现在是开源的,免费提供给 下载。 检出com.datumbox.framework.machinelearning.clustering软件包,以了解Java中Dirichlet Process Mixture模型的实现。
Java中的Dirichlet Process Mixture Model实现
该代码使用Gibbs Sampler实现了Dirichlet Process Mixture模型,并使用Apache Commons Math 3.3作为矩阵库。 它已获得GPLv3许可,因此可以随意使用,修改和自由分发它,您可以从以下位置下载Java实现: Github上。 请注意,您可以在前5篇文章中找到集群方法的所有理论部分,并在源代码中找到用于实现的详细Javadoc注释。
下面我们列出了有关代码的高级描述:
1. DPMM类
DPMM是一个抽象类,就像各种不同模型的基础一样,实现了 中餐厅流程 并包含 折叠的吉布斯采样器。 它具有公共方法cluster(),该方法将数据集作为点列表接收,并负责执行聚类分析。 该类的其他有用方法是getPointAssignments()和getClusterList(),getPointAssignments()用于在完成集群后检索集群分配,getClusterList()用于获取已识别集群的列表。 DPMM包含静态嵌套的抽象类Cluster。 它包含几种有关点管理和后pdf估计的抽象方法,这些方法用于估计聚类分配。
2. GaussianDPMM类
GaussianDPMM是Dirichlet多元正态混合模型的实现,并且扩展了DPMM类。 它包含估计高斯假设下的概率所需的所有方法。 此外,它包含静态嵌套类Cluster,该类实现了DPMM.Cluster类的所有抽象方法。
3. MultinomialDPMM类
MultinomialDPMM实现Dirichlet-多项混合模型并扩展了DPMM类。 与GaussianDPMM类类似,它包含在Multinomial-Dirichlet假设下估算概率所需的所有方法,并包含静态嵌套类Cluster,该类实现了DPMM.Cluster的抽象方法。
4. SRS课
SRS类用于从频率表执行简单随机采样。 Gibbs Sampler使用它在迭代过程的每个步骤中估计新的群集分配。
5.积分课
Point类用作元组,用于存储记录的数据及其ID。
6. Apache Commons Math库
Apache Commons Math 3.3库用于矩阵乘法,它是我们实现的唯一依赖项。
7. DPMMExample类
此类包含有关如何使用Java实现的示例。
使用Java实现
代码的用户能够配置混合模型的所有参数,包括模型类型和超参数。 在以下代码片段中,我们可以看到如何初始化和执行算法:
List<Point> pointList = new ArrayList<>(); //add records in pointList //Dirichlet Process parameter Integer dimensionality = 2; double alpha = 1.0; //Hyper parameters of Base Function int kappa0 = 0; int nu0 = 1; RealVector mu0 = new ArrayRealVector(new double[]{0.0, 0.0}); RealMatrix psi0 = new BlockRealMatrix(new double[][]{{1.0,0.0},{0.0,1.0}}); //Create a DPMM object DPMM dpmm = new GaussianDPMM(dimensionality, alpha, kappa0, nu0, mu0, psi0); int maxIterations = 100; int performedIterations = dpmm.cluster(pointList, maxIterations); //get a list with the point ids and their assignments Map<Integer, Integer> zi = dpmm.getPointAssignments();
下面我们可以看到在包含300个数据点的综合数据集上运行算法的结果。 这些点最初是由3种不同的分布生成的:N([10,50],I),N([50,10],I)和N([150,100],I)。
图1:演示数据集的散点图
该算法运行10次迭代后,确定了以下3个聚类中心:[10.17,50.11],[49.99,10.13]和[149.97,99.81]。 最后,由于我们以贝叶斯方式对待所有事物,因此我们不仅能够提供聚类中心的单点估计,而且能够通过使用 公式 .
图2:集群中心概率的散点图
在上图中,我们绘制了这些概率; 红色区域表示处于群集中心的可能性很高,黑色区域表示处于概率较低的可能性。
要在现实世界的应用程序中使用Java实现,您必须编写外部代码,以将原始数据集转换为所需的格式。 此外,如果您想像上面看到的那样可视化输出,可能还需要其他代码。 最后请注意,Apache Commons Math库包含在项目中,因此无需额外配置即可运行演示。
如果您在有趣的项目中使用该实现,请给我们留言,我们将在博客中介绍您的项目。 另外,如果您喜欢这篇文章,请花一点时间在Twitter或Facebook上分享。