Java PlatoBlockchain Data Intelligence のディリクレ過程混合モデルを使用したクラスタリング。垂直検索。あい。

Javaでのディリクレプロセス混合モデルによるクラスタリング

以前の記事では、 ディリクレプロセス混合モデル そして、それらをクラスター分析でどのように使用できるか。 この記事では、Javaの実装を紹介します。 XNUMXつの異なるDPMMモデル:ガウスデータのクラスター化に使用できるディリクレ多変量正規混合モデルと、ドキュメントのクラスター化に使用されるディリクレ多項混合モデル。 JavaコードはGPL v3ライセンスの下でオープンソースであり、以下から無料でダウンロードできます。 githubの.

更新:Datumbox Machine Learning Frameworkがオープンソースになり、無料で ダウンロード。 パッケージcom.datumbox.framework.machinelearning.clusteringをチェックして、Javaでのディリクレプロセス混合モデルの実装を確認してください。

Javaでのディリクレプロセス混合モデルの実装

コードは、ギブスサンプラーを使用してディリクレプロセス混合モデルを実装し、Apache Commons Math 3.3をマトリックスライブラリとして使用します。 GPLv3の下でライセンスされているので、自由に使用、変更、自由に再配布してください。Java実装は、次の場所からダウンロードできます。 githubの。 クラスタリングメソッドの理論的な部分はすべて、前の5つの記事と、ソースコードに実装するための詳細なJavadocコメントに記載されています。

以下に、コードの概要を示します。

1. DPMMクラス

DPMMは抽象クラスであり、さまざまな異なるモデルのベースのように機能し、 中華レストランのプロセス とが含まれています 折りたたまれたギブスサンプラー。 データセットをポイントのリストとして受け取り、クラスター分析の実行を担当するパブリックメソッドcluster()があります。 このクラスの他の便利なメソッドは、クラスタリングの完了後にクラスター割り当てを取得するために使用されるgetPointAssignments()と、識別されたクラスターのリストを取得するために使用されるgetClusterList()です。 DPMMには、静的にネストされた抽象クラスClusterが含まれています。 これには、クラスター割り当ての推定に使用されるポイントの管理と事後pdfの推定に関するいくつかの抽象的な方法が含まれています。

2. GaussianDPMMクラス

GaussianDPMMは、ディリクレ多変量正規混合モデルの実装であり、DPMMクラスを拡張します。 これには、ガウス仮定の下で確率を推定するために必要なすべてのメソッドが含まれています。 さらに、DPMM.Clusterクラスのすべての抽象メソッドを実装する静的な入れ子クラスClusterが含まれています。

3. MultinomialDPMMクラス

MultinomialDPMMは、ディリクレ多項式混合モデルを実装し、DPMMクラスを拡張します。 GaussianDPMMクラスと同様に、Multinomial-Dirichlet仮定の下で確率を推定するために必要なすべてのメソッドが含まれ、DPMM.Clusterの抽象メソッドを実装する静的ネストクラスClusterが含まれます。

4. SRSクラス

SRSクラスは、頻度テーブルから単純ランダムサンプリングを実行するために使用されます。 Gibbs Samplerは、反復プロセスの各ステップで新しいクラスターの割り当てを推定するために使用します。

5.ポイントクラス

Pointクラスは、レコードのデータとそのIDを格納するタプルとして機能します。

6. Apache Commons 数学ライブラリ

Apache Commons Math 3.3 libはMatrixの乗算に使用され、実装の唯一の依存関係です。

7. DPMMExampleクラス

このクラスには、Java実装の使用例が含まれています。

Java実装の使用

コードのユーザーは、モデルタイプやハイパーパラメーターなど、混合モデルのすべてのパラメーターを構成できます。 次のコードスニペットでは、アルゴリズムがどのように初期化および実行されるかを確認できます。

List<Point> pointList = new ArrayList<>();
//add records in pointList

//Dirichlet Process parameter
Integer dimensionality = 2;
double alpha = 1.0;

//Hyper parameters of Base Function
int kappa0 = 0;
int nu0 = 1;
RealVector mu0 = new ArrayRealVector(new double[]{0.0, 0.0});
RealMatrix psi0 = new BlockRealMatrix(new double[][]{{1.0,0.0},{0.0,1.0}});

//Create a DPMM object
DPMM dpmm = new GaussianDPMM(dimensionality, alpha, kappa0, nu0, mu0, psi0);

int maxIterations = 100;
int performedIterations = dpmm.cluster(pointList, maxIterations);

//get a list with the point ids and their assignments
Map<Integer, Integer> zi = dpmm.getPointAssignments();

以下に、300のデータポイントで構成される合成データセットでアルゴリズムを実行した結果を示します。 ポイントは元々3つの異なる分布によって生成されました:N([10,50]、I)、N([50,10]、I)およびN([150,100]、I)。

散布図1
図1:デモデータセットの散布図

アルゴリズムは10回の反復を実行した後、次の3つのクラスター中心を識別しました:[10.17、50.11]、[49.99、10.13]および[149.97、99.81]。 最後に、すべてをベイズ法で処理するため、クラスター中心の単一点推定だけでなく、 方程式.

scatterplot2-ヒートマップ
図2:クラスターの中心の確率の散布図

上の図では、これらの確率をプロットしています。 赤い領域はクラスターの中心となる可能性が高いことを示し、黒い領域は確率が低いことを示します。

実際のアプリケーションでJava実装を使用するには、元のデータセットを必要な形式に変換する外部コードを記述する必要があります。 さらに、上記のように出力を視覚化する場合は、追加のコードが必要になる場合があります。 最後に、Apache Commons Mathライブラリがプロジェクトに含まれているため、デモを実行するために追加の構成は必要ないことに注意してください。

興味深いプロジェクトで実装を使用する場合は、ご連絡ください。ブログでプロジェクトを取り上げます。 また、この記事が気に入ったら、少し時間を取ってTwitterまたはFacebookで共有してください。

タイムスタンプ:

より多くの データムボックス