Einblick in den ALS-Empfehlungsalgorithmus von Spark, PlatoBlockchain Data Intelligence. Vertikale Suche. Ai.

Bohren in den ALS-Empfehlungsalgorithmus von Spark

Der von eingeführte ALS-Algorithmus Hu et al.ist eine sehr beliebte Technik, die bei Problemen mit dem Empfehlungssystem verwendet wird, insbesondere wenn implizite Datensätze vorhanden sind (z. B. Klicks, Likes usw.). Es kann ziemlich gut mit großen Datenmengen umgehen und wir können viele gute Implementierungen in verschiedenen Frameworks für maschinelles Lernen finden. Spark enthält den Algorithmus in der MLlib-Komponente, der kürzlich überarbeitet wurde, um die Lesbarkeit und die Architektur des Codes zu verbessern.

Die Implementierung von Spark erfordert, dass die Element- und Benutzer-ID Zahlen im Ganzzahlbereich sind (entweder Integer-Typ oder Long im Ganzzahlbereich). Dies ist sinnvoll, da dies die Operationen beschleunigen und den Speicherverbrauch reduzieren kann. Beim Lesen des Codes ist mir jedoch aufgefallen, dass diese ID-Spalten zu Beginn der Fit / Predict-Methoden in Doubles und dann in Integers umgewandelt werden. Das scheint ein bisschen hackig zu sein und ich habe gesehen, dass es den Müllsammler unnötig belastet. Hier sind die Zeilen auf der ALS-Code das warf die IDs in Doppel:
Einblick in den ALS-Empfehlungsalgorithmus von Spark, PlatoBlockchain Data Intelligence. Vertikale Suche. Ai.
Einblick in den ALS-Empfehlungsalgorithmus von Spark, PlatoBlockchain Data Intelligence. Vertikale Suche. Ai.

Um zu verstehen, warum dies getan wird, muss man checkedCast () lesen:
Einblick in den ALS-Empfehlungsalgorithmus von Spark, PlatoBlockchain Data Intelligence. Vertikale Suche. Ai.

Diese UDF empfängt ein Double, überprüft seinen Bereich und wandelt ihn dann in eine Ganzzahl um. Diese UDF wird für die Schemaüberprüfung verwendet. Die Frage ist, können wir dies erreichen, ohne hässliche Doppelgussteile zu verwenden? Ich glaube ja:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Der obige Code zeigt eine modifizierte checkedCast (), die die Eingabe empfängt, prüft, ob der Wert numerisch ist, und ansonsten Ausnahmen auslöst. Da die Eingabe Beliebig ist, können wir alle in Double umgesetzten Anweisungen sicher aus dem Rest des Codes entfernen. Darüber hinaus ist zu erwarten, dass die Mehrheit der Benutzer tatsächlich ganzzahlige Typen verwendet, da der ALS IDs im ganzzahligen Bereich erfordert. Infolgedessen behandelt diese Methode in Zeile 3 Ganzzahlen explizit, um ein Casting zu vermeiden. Bei allen anderen numerischen Werten wird geprüft, ob die Eingabe im ganzzahligen Bereich liegt. Diese Überprüfung erfolgt in Zeile 7.

Man könnte dies anders schreiben und explizit alle erlaubten Typen behandeln. Leider würde dies zu doppeltem Code führen. Stattdessen konvertiere ich hier die Zahl in eine Ganzzahl und vergleiche sie mit der ursprünglichen Zahl. Wenn die Werte identisch sind, gilt eine der folgenden Bedingungen:

  1. Der Wert ist Byte oder Short.
  2. Der Wert ist Long, liegt jedoch im Integer-Bereich.
  3. Der Wert ist Double oder Float, jedoch ohne Bruchteil.

Um sicherzustellen, dass der Code gut funktioniert, habe ich ihn mit den Standard-Unit-Tests von Spark und manuell getestet, indem ich das Verhalten der Methode auf verschiedene legale und illegale Werte überprüft habe. Um sicherzustellen, dass die Lösung mindestens so schnell wie das Original ist, habe ich sie mehrfach mit dem folgenden Snippet getestet. Dies kann in der platziert werden ALSSuite Klasse im Funken:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Nach einigen Tests können wir feststellen, dass das neue Update etwas schneller als das Original ist:

Code

Anzahl der Läufe

Gesamtausführungszeit

Durchschnittliche Ausführungszeit pro Lauf

Original 100 588.458er-Jahre 5.88458er-Jahre
Behoben 100 566.722er-Jahre 5.66722er-Jahre

Ich habe die Experimente mehrmals wiederholt, um zu bestätigen, dass die Ergebnisse konsistent sind. Hier finden Sie die detaillierte Ausgabe eines Experiments für die ursprünglicher Code und für fixieren. Der Unterschied ist für einen winzigen Datensatz gering, aber in der Vergangenheit konnte ich mit diesem Fix eine spürbare Reduzierung des GC-Overheads erzielen. Wir können dies bestätigen, indem wir Spark lokal ausführen und einen Java-Profiler an die Spark-Instanz anhängen. Ich öffnete eine Ticket und einem Pull-Anfrage auf dem offiziellen Spark Repo aber weil es ungewiss ist, ob es zusammengeführt wird, dachte ich, es hier mit Ihnen zu teilen und es ist jetzt Teil von Spark 2.2.

Alle Gedanken, Kommentare oder Kritik sind willkommen! 🙂

Zeitstempel:

Mehr von Bezugsbox