Borer i Sparks ALS-anbefalingsalgoritme PlatoBlockchain Data Intelligence. Lodret søgning. Ai.

Borer i Sparks ALS-anbefalingsalgoritme

ALS-algoritmen introduceret af Hu et al., er en meget populær teknik, der bruges i Recommender System-problemer, især når vi har implicitte datasæt (for eksempel klik, likes osv.). Den kan håndtere store mængder data rimeligt godt, og vi kan finde mange gode implementeringer i forskellige Machine Learning rammer. Spark inkluderer algoritmen i MLlib-komponenten, som for nylig er blevet refaktoreret for at forbedre kodens læsbarhed og arkitektur.

Sparks implementering kræver, at varen og bruger-id'et er tal inden for heltalsområdet (enten heltalstype eller lang inden for heltalsområdet), hvilket er rimeligt, da dette kan hjælpe med at fremskynde operationerne og reducere hukommelsesforbruget. En ting, jeg dog bemærkede, mens jeg læste koden, er, at disse id-kolonner bliver castet ind i Doubles og derefter i Heltal i begyndelsen af ​​fit/predict-metoderne. Dette virker en smule hacket, og jeg har set, at det belaster skraldeopsamleren unødigt. Her er linjerne på ALS kode der kaster id'erne til doubler:
Borer i Sparks ALS-anbefalingsalgoritme PlatoBlockchain Data Intelligence. Lodret søgning. Ai.
Borer i Sparks ALS-anbefalingsalgoritme PlatoBlockchain Data Intelligence. Lodret søgning. Ai.

For at forstå, hvorfor dette gøres, skal man læse checkedCast():
Borer i Sparks ALS-anbefalingsalgoritme PlatoBlockchain Data Intelligence. Lodret søgning. Ai.

Denne UDF modtager en Double og kontrollerer dens rækkevidde og kaster den derefter til heltal. Denne UDF bruges til skemavalidering. Spørgsmålet er, om vi kan opnå dette uden at bruge grimme dobbeltstøbninger? Jeg tror ja:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Ovenstående kode viser en modificeret checkedCast(), som modtager input, checks hævder, at værdien er numerisk og ellers rejser undtagelser. Da inputtet er Any, kan vi sikkert fjerne alle cast to Double-sætninger fra resten af ​​koden. Desuden er det rimeligt at forvente, at da ALS kræver id'er inden for heltalsområde, bruger de fleste mennesker faktisk heltalstyper. Som et resultat på linje 3 håndterer denne metode heltal eksplicit for at undgå at foretage nogen casting. For alle andre numeriske værdier kontrollerer den, om inputtet er inden for et heltalsområde. Denne kontrol sker på linje 7.

Man kunne skrive dette anderledes og eksplicit håndtere alle de tilladte typer. Desværre ville dette føre til duplikatkode. Det, jeg gør her, er i stedet at konvertere tallet til heltal og sammenligne det med det originale tal. Hvis værdierne er identiske, er et af følgende sandt:

  1. Værdien er Byte eller Short.
  2. Værdien er lang, men inden for heltalsområdet.
  3. Værdien er Double eller Float, men uden nogen brøkdel.

For at sikre, at koden kører godt, testede jeg den med standard unit-testene fra Spark og manuelt ved at kontrollere metodens adfærd for forskellige lovlige og ulovlige værdier. For at sikre, at løsningen er mindst lige så hurtig som originalen, testede jeg adskillige gange ved hjælp af uddraget nedenfor. Dette kan placeres i ALSSuite klasse i Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Efter et par test kan vi se, at den nye rettelse er lidt hurtigere end den originale:

Kode

Antal kørsler

Samlet udførelsestid

Gennemsnitlig udførelsestid pr. kørsel

Original 100 588.458s 5.88458s
Fast 100 566.722s 5.66722s

Jeg gentog eksperimenterne flere gange for at bekræfte, og resultaterne er konsistente. Her kan du finde det detaljerede output fra et eksperiment for original kode og fastsætte. Forskellen er lille for et lille datasæt, men tidligere har jeg formået at opnå en mærkbar reduktion i GC-overhead ved hjælp af denne rettelse. Vi kan bekræfte dette ved at køre Spark lokalt og vedhæfte en Java-profiler på Spark-instansen. Jeg åbnede en billet og en Træk-anmodning på den officielle Spark-repo men fordi det er usikkert om det bliver lagt sammen, tænkte jeg at dele det her med jer og det er nu en del af Spark 2.2.

Alle tanker, kommentarer eller kritik er velkomne! 🙂

Tidsstempel:

Mere fra Datumboks