Exploration de l'algorithme de recommandation ALS de Spark PlatoBlockchain Data Intelligence. Recherche verticale. Aï.

Exploration de l'algorithme de recommandation ALS de Spark

L'algorithme ALS introduit par Hu et al., est une technique très populaire utilisée dans les problèmes du système de recommandation, en particulier lorsque nous avons des ensembles de données implicites (par exemple, des clics, des likes, etc.). Il peut gérer de gros volumes de données raisonnablement bien et nous pouvons trouver de nombreuses bonnes implémentations dans divers frameworks d'apprentissage automatique. Spark inclut l'algorithme dans le composant MLlib qui a été récemment refactorisé pour améliorer la lisibilité et l'architecture du code.

L'implémentation de Spark nécessite que l'élément et l'ID utilisateur soient des nombres dans une plage d'entiers (type entier ou long dans une plage d'entiers), ce qui est raisonnable car cela peut aider à accélérer les opérations et à réduire la consommation de mémoire. Une chose que j'ai cependant remarquée en lisant le code est que ces colonnes id sont converties en doubles puis en entiers au début des méthodes d'ajustement / prédiction. Cela semble un peu piraté et je l'ai vu mettre une pression inutile sur le ramasse-miettes. Voici les lignes sur le Code ALS qui convertissent les identifiants en double:
Exploration de l'algorithme de recommandation ALS de Spark PlatoBlockchain Data Intelligence. Recherche verticale. Aï.
Exploration de l'algorithme de recommandation ALS de Spark PlatoBlockchain Data Intelligence. Recherche verticale. Aï.

Pour comprendre pourquoi cela est fait, il faut lire le checkCast ():
Exploration de l'algorithme de recommandation ALS de Spark PlatoBlockchain Data Intelligence. Recherche verticale. Aï.

Cet UDF reçoit un Double et vérifie sa plage, puis le convertit en entier. Cet UDF est utilisé pour la validation de schéma. La question est de savoir si nous pouvons y parvenir sans utiliser de vilains doubles moulages? Je crois que oui:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Le code ci-dessus montre un checkedCast () modifié qui reçoit l'entrée, vérifie que la valeur est numérique et déclenche des exceptions dans le cas contraire. Comme l'entrée est Any, nous pouvons supprimer en toute sécurité toutes les instructions de conversion en Double du reste du code. De plus, il est raisonnable de s'attendre à ce que, puisque l'ALS nécessite des identifiants dans une plage d'entiers, la majorité des gens utilisent en fait des types entiers. En conséquence, à la ligne 3, cette méthode gère explicitement les nombres entiers pour éviter de faire un cast. Pour toutes les autres valeurs numériques, il vérifie si l'entrée se trouve dans une plage d'entiers. Cette vérification s'effectue à la ligne 7.

On pourrait écrire cela différemment et gérer explicitement tous les types autorisés. Malheureusement, cela conduirait à un code en double. Au lieu de cela, ce que je fais ici est de convertir le nombre en entier et de le comparer avec le nombre d'origine. Si les valeurs sont identiques, l'une des conditions suivantes est vraie:

  1. La valeur est Byte ou Short.
  2. La valeur est Long mais dans la plage Integer.
  3. La valeur est Double ou Float mais sans aucune partie fractionnaire.

Pour m'assurer que le code fonctionne bien, je l'ai testé avec les tests unitaires standard de Spark et manuellement en vérifiant le comportement de la méthode pour diverses valeurs légales et illégales. Pour m'assurer que la solution est au moins aussi rapide que l'original, j'ai testé plusieurs fois en utilisant l'extrait ci-dessous. Cela peut être placé dans le Classe ALSSuite dans Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Après quelques tests, nous pouvons voir que le nouveau correctif est légèrement plus rapide que l'original:

Code

Nombre de courses

Temps d'exécution total

Temps d'exécution moyen par exécution

ORIGINALE 100 588.458s 5.88458s
Fixé 100 566.722s 5.66722s

J'ai répété les expériences plusieurs fois pour confirmer et les résultats sont cohérents. Ici vous pouvez trouver la sortie détaillée d'une expérience pour le code d'origine et par fixer. La différence est petite pour un minuscule ensemble de données, mais dans le passé, j'ai réussi à réduire considérablement les frais généraux du GC en utilisant ce correctif. Nous pouvons le confirmer en exécutant Spark localement et en attachant un profileur Java sur l'instance Spark. J'ai ouvert un billet et Pull-Demande sur le repo officiel Spark mais comme il n'est pas certain qu'il soit fusionné, j'ai pensé le partager ici avec vous et il fait maintenant partie de Spark 2.2.

Toutes les pensées, commentaires ou critiques sont les bienvenus! 🙂

Horodatage:

Plus de Boîte de données