- 2017 年 2 月 25 日
- ヴァシリス・ヴリニオティス
- 。 3コメント
によって導入されたALSアルゴリズム Hu et al。は、特に暗黙的なデータセット(クリックなど)がある場合に、Recommenderシステムの問題で使用される非常に人気のある手法です。 大量のデータを適切に処理でき、さまざまな機械学習フレームワークで多くの優れた実装を見つけることができます。 SparkのアルゴリズムはMLlibコンポーネントに組み込まれており、コードの可読性とアーキテクチャを改善するために最近リファクタリングされています。
Sparkの実装では、アイテムとユーザーのIDが整数の範囲内の数値(整数型または整数の範囲内のLong)である必要があります。これは、操作の高速化とメモリ消費の削減に役立つため、妥当です。 コードを読んでいるときに気付いたのは、これらのid列がdoubleにキャストされ、次にfit / predictメソッドの最初でIntegerにキャストされていることです。 これは少しハックに見え、ガベージコレクタに不必要な負担がかかるのを見てきました。 ここに行があります ALSコード idをdoubleにキャストします。
これが行われる理由を理解するには、checkedCast()を読み取る必要があります。
このUDFはDoubleを受け取り、その範囲をチェックして、それを整数にキャストします。 このUDFは、スキーマ検証に使用されます。 問題は、醜いダブルキャスティングを使用せずにこれを実現できるかどうかです。 はい:
protected val checkedCast = udf { (n: Any) => n match { case v: Int => v // Avoid unnecessary casting case v: Number => val intV = v.intValue() // True for Byte/Short, Long within the Int range and Double/Float with no fractional part. if (v.doubleValue == intV) { intV } else { throw new IllegalArgumentException(s"ALS only supports values in Integer range " + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") } case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.") } }
上記のコードは、入力を受け取り、値が数値であることを確認し、それ以外の場合は例外を発生させる、変更されたcheckedCast()を示しています。 入力はAnyなので、残りのコードからDoubleステートメントへのすべてのキャストを安全に削除できます。 さらに、ALSは整数の範囲内のIDを必要とするため、大多数の人々が実際に整数型を使用することを期待するのは妥当です。 その結果、3行目では、このメソッドは整数を明示的に処理して、キャストを行わないようにします。 他のすべての数値については、入力が整数の範囲内にあるかどうかをチェックします。 このチェックは7行目で行われます。
これを別の方法で記述し、許可されたすべてのタイプを明示的に処理することができます。 残念ながら、これはコードの重複につながります。 代わりに、ここで行うことは、数値を整数に変換し、それを元の数値と比較することです。 値が同一の場合、次のいずれかが当てはまります。
- 値はByteまたはShortです。
- 値はLongですが、整数の範囲内です。
- 値はDoubleまたはFloatですが、小数部分はありません。
コードが適切に実行されることを確認するために、Sparkの標準の単体テストでテストし、メソッドの動作を確認して、さまざまな正当な値と不正な値がないか手動でテストしました。 ソリューションが少なくとも元のソリューションと同じ速さであることを確認するために、以下のスニペットを使用して何度もテストしました。 これは ALSSuiteクラス Sparkでは:
test("Speed difference") { val (training, test) = genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01) val runs = 100 var totalTime = 0.0 println("Performing "+runs+" runs") for(i <- 0 until runs) { val t0 = System.currentTimeMillis testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1) val secs = (System.currentTimeMillis - t0)/1000.0 println("Run "+i+" executed in "+secs+"s") totalTime += secs } println("AVG Execution Time: "+(totalTime/runs)+"s") }
いくつかのテストの後、新しい修正が元の修正よりもわずかに速いことがわかります。
Code |
実行数 |
総実行時間 |
実行ごとの平均実行時間 |
元の | 100 | 588.458s | 5.88458s |
固定の | 100 | 566.722s | 5.66722s |
私は確認するために実験を複数回繰り返し、結果は一貫しています。 ここでは、XNUMXつの実験の詳細な出力を見つけることができます 元のコード と 修正します。 小さなデータセットの違いは小さいですが、過去にこの修正を使用して、GCオーバーヘッドの顕著な削減を達成することができました。 これを確認するには、Sparkをローカルで実行し、SparkインスタンスにJavaプロファイラーをアタッチします。 私が開いた チケット フォルダーとその下に プルリクエスト 公式のSparkリポジトリ 合併されるかどうかは定かではないので、こちらで共有したいと思います そして現在はSpark 2.2の一部です。
どんな考え、コメント、批判も歓迎します! 🙂