Borra i Sparks ALS-rekommendationsalgoritm PlatoBlockchain Data Intelligence. Vertikal sökning. Ai.

Borrning i Sparks ALS-rekommendationsalgoritm

ALS-algoritmen introducerad av Hu et al., är en mycket populär teknik som används i Recommender System-problem, speciellt när vi har implicita datauppsättningar (till exempel klick, gilla-markeringar etc). Den kan hantera stora datamängder någorlunda bra och vi kan hitta många bra implementeringar i olika Machine Learning-ramverk. Spark inkluderar algoritmen i MLlib-komponenten som nyligen har omstrukturerats för att förbättra läsbarheten och arkitekturen för koden.

Sparks implementering kräver att objektet och användar-id:t är siffror inom heltalsintervallet (antingen heltalstyp eller långt inom heltalsintervallet), vilket är rimligt eftersom detta kan hjälpa till att påskynda operationerna och minska minnesförbrukningen. En sak jag dock märkte när jag läste koden är att dessa id-kolumner kastas in i Doubles och sedan till heltal i början av fit/predict-metoderna. Det här verkar lite hackigt och jag har sett att det belastar sophämtaren i onödan. Här är raderna på ALS-kod som kastar id:en till dubbel:
Borra i Sparks ALS-rekommendationsalgoritm PlatoBlockchain Data Intelligence. Vertikal sökning. Ai.
Borra i Sparks ALS-rekommendationsalgoritm PlatoBlockchain Data Intelligence. Vertikal sökning. Ai.

För att förstå varför detta görs måste man läsa checkedCast():
Borra i Sparks ALS-rekommendationsalgoritm PlatoBlockchain Data Intelligence. Vertikal sökning. Ai.

Denna UDF tar emot en dubbel och kontrollerar dess räckvidd och kastar den sedan till heltal. Denna UDF används för schemavalidering. Frågan är kan vi uppnå detta utan att använda fula dubbelgjutningar? Jag tror ja:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Koden ovan visar en modifierad checkedCast() som tar emot indata, kontroller hävdar att värdet är numeriskt och höjer undantag annars. Eftersom ingången är Any, kan vi säkert ta bort alla cast to Double-satser från resten av koden. Dessutom är det rimligt att förvänta sig att eftersom ALS kräver id inom heltalsintervallet, så använder majoriteten av människor faktiskt heltalstyper. Som ett resultat på rad 3 hanterar denna metod heltal uttryckligen för att undvika att göra någon casting. För alla andra numeriska värden kontrollerar den om ingången ligger inom heltalsområdet. Denna kontroll sker på rad 7.

Man skulle kunna skriva detta annorlunda och explicit hantera alla tillåtna typer. Tyvärr skulle detta leda till duplicerad kod. Det jag gör här är istället att konvertera talet till heltal och jämföra det med det ursprungliga talet. Om värdena är identiska är något av följande sant:

  1. Värdet är Byte eller Short.
  2. Värdet är långt men inom heltalsintervallet.
  3. Värdet är Double eller Float men utan någon bråkdel.

För att säkerställa att koden fungerar bra testade jag den med Sparks standardenhetstester och manuellt genom att kontrollera metodens beteende för olika lagliga och olagliga värden. För att säkerställa att lösningen är minst lika snabb som originalet testade jag flera gånger med hjälp av utdraget nedan. Denna kan placeras i ALSSvitklass i Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Efter några tester kan vi se att den nya fixen är något snabbare än originalet:

Koda

Antal körningar

Total utförandetid

Genomsnittlig exekveringstid per körning

Ursprungliga 100 588.458s 5.88458s
Fast 100 566.722s 5.66722s

Jag upprepade experimenten flera gånger för att bekräfta och resultaten är konsekventa. Här kan du hitta den detaljerade produktionen av ett experiment för ursprungskod och fast. Skillnaden är liten för en liten datamängd men tidigare har jag lyckats uppnå en märkbar minskning av GC-overhead med den här fixen. Vi kan bekräfta detta genom att köra Spark lokalt och bifoga en Java-profilerare på Spark-instansen. Jag öppnade en biljett och en Pull-Request på den officiella Spark-repo men eftersom det är osäkert om det kommer att slås samman tänkte jag dela det här med er och det är nu en del av Spark 2.2.

Alla tankar, kommentarer eller kritik är välkomna! 🙂

Tidsstämpel:

Mer från Datumbox