Boren in Spark's ALS Recommendation-algoritme PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.

Boren in het ALS-aanbevelingsalgoritme van Spark

Het ALS-algoritme geïntroduceerd door Hu et al., is een zeer populaire techniek die wordt gebruikt bij problemen met het recommender-systeem, vooral wanneer we impliciete datasets hebben (bijvoorbeeld klikken, likes enz.). Het kan redelijk goed omgaan met grote hoeveelheden gegevens en we kunnen veel goede implementaties vinden in verschillende Machine Learning-frameworks. Spark neemt het algoritme op in de MLlib-component die onlangs is aangepast om de leesbaarheid en de architectuur van de code te verbeteren.

De implementatie van Spark vereist dat het item en de gebruikers-ID getallen zijn binnen een geheel getalbereik (ofwel een geheel getal of lang binnen een geheel getalbereik), wat redelijk is, omdat dit de bewerkingen kan versnellen en het geheugenverbruik kan verminderen. Een ding dat ik echter opmerkte tijdens het lezen van de code, is dat die id-kolommen aan het begin van de fit / predict-methoden in Doubles worden gegoten en vervolgens in Integers. Dit lijkt een beetje hacky en ik heb gezien dat het de vuilnisman onnodig belast. Hier zijn de lijnen op de ALS-code die de ID's in het dubbel werpen:
Boren in Spark's ALS Recommendation-algoritme PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.
Boren in Spark's ALS Recommendation-algoritme PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.

Om te begrijpen waarom dit wordt gedaan, moet men checkCast () lezen:
Boren in Spark's ALS Recommendation-algoritme PlatoBlockchain Data Intelligence. Verticaal zoeken. Ai.

Deze UDF ontvangt een Double en controleert het bereik en werpt het vervolgens naar een geheel getal. Deze UDF wordt gebruikt voor schemavalidatie. De vraag is of we dit kunnen bereiken zonder lelijke dubbele gietstukken te gebruiken? Ik geloof van wel:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

De bovenstaande code toont een gewijzigde checkCast () die de invoer ontvangt, controleert of de waarde numeriek is en anders uitzonderingen oplevert. Aangezien de invoer Any is, kunnen we alle cast naar Double-instructies veilig verwijderen uit de rest van de code. Bovendien is het redelijk om te verwachten dat aangezien de ALS ids vereist binnen een integer bereik, de meerderheid van de mensen daadwerkelijk integer types gebruikt. Als resultaat op regel 3 behandelt deze methode Integers expliciet om te voorkomen dat er wordt gecast. Voor alle andere numerieke waarden wordt gecontroleerd of de invoer binnen het gehele bereik valt. Deze controle vindt plaats op regel 7.

Je zou dit anders kunnen schrijven en expliciet omgaan met alle toegestane typen. Helaas zou dit leiden tot dubbele code. Wat ik hier doe, is het getal converteren naar een geheel getal en het vergelijken met het oorspronkelijke nummer. Als de waarden identiek zijn, is een van de volgende voorwaarden van toepassing:

  1. De waarde is Byte of Short.
  2. De waarde is lang, maar binnen het geheel getal.
  3. De waarde is Double of Float, maar zonder een fractioneel deel.

Om ervoor te zorgen dat de code goed werkt, heb ik deze getest met de standaard unit-tests van Spark en handmatig door het gedrag van de methode te controleren op verschillende legale en illegale waarden. Om ervoor te zorgen dat de oplossing minstens zo snel is als het origineel, heb ik meerdere keren getest met behulp van het onderstaande fragment. Deze kan in de ALSSuite-klasse in vonk:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Na een paar tests kunnen we zien dat de nieuwe oplossing iets sneller is dan de originele:

Code

Aantal runs

Totale uitvoeringstijd

Gemiddelde uitvoeringstijd per run

ORIGINELE 100 588.458s 5.88458s
vast 100 566.722s 5.66722s

Ik herhaalde de experimenten meerdere keren om te bevestigen en de resultaten zijn consistent. Hier vindt u de gedetailleerde output van één experiment voor de originele code en repareren. Het verschil is klein voor een kleine dataset, maar in het verleden heb ik met deze oplossing een merkbare vermindering van GC-overhead bereikt. We kunnen dit bevestigen door Spark lokaal uit te voeren en een Java-profiler op de Spark-instantie te koppelen. Ik opende een ticket en Pull-verzoek op de officiële Spark-opslagplaats maar omdat het niet zeker is of het zal worden samengevoegd, dacht ik het hier met u te delen en het maakt nu deel uit van Spark 2.2.

Alle gedachten, opmerkingen of kritiek zijn welkom! 🙂

Tijdstempel:

Meer van Datumbox