Fúrás a Spark ALS ajánlási algoritmusába, a PlatoBlockchain Data Intelligence-be. Függőleges keresés. Ai.

Fúrás a Spark ALS ajánlási algoritmusába

által bevezetett ALS algoritmus Hu és mtsai., egy nagyon népszerű technika Recommender System problémák esetén, különösen akkor, ha implicit adatkészleteink vannak (például kattintások, kedvelések stb.). Meglehetősen jól tud kezelni nagy mennyiségű adatot, és számos jó implementációt találunk a különféle Machine Learning keretrendszerekben. A Spark az MLlib komponensben tartalmazza az algoritmust, amelyet a közelmúltban átdolgoztak a kód olvashatóságának és architektúrájának javítása érdekében.

A Spark megvalósítása megköveteli, hogy az elem és a felhasználói azonosító egész tartományon belüli számok legyenek (Integer típusú vagy hosszú egész számok), ami ésszerű, mivel ez segíthet felgyorsítani a műveleteket és csökkenteni a memóriafelhasználást. Egy dolgot azonban észrevettem a kód olvasása közben, hogy ezek az id oszlopok az illesztési/előrejelzési metódusok elején duplákba, majd egész számokba kerülnek. Ez egy kicsit furcsának tűnik, és láttam, hogy feleslegesen megterheli a szemétszállítót. Itt vannak a vonalak a ALS kód amelyek az azonosítókat duplájára öntik:
Fúrás a Spark ALS ajánlási algoritmusába, a PlatoBlockchain Data Intelligence-be. Függőleges keresés. Ai.
Fúrás a Spark ALS ajánlási algoritmusába, a PlatoBlockchain Data Intelligence-be. Függőleges keresés. Ai.

Annak megértéséhez, hogy ez miért történik, el kell olvasni a checkedCast():
Fúrás a Spark ALS ajánlási algoritmusába, a PlatoBlockchain Data Intelligence-be. Függőleges keresés. Ai.

Ez az UDF duplát kap, és ellenőrzi a tartományát, majd egész számra adja. Ezt az UDF-et a séma ellenőrzésére használják. A kérdés az, hogy el tudjuk ezt érni csúnya dupla öntvények használata nélkül? azt hiszem, igen:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

A fenti kód egy módosított checkedCast()-et mutat, amely fogadja a bemenetet, ellenőrzi, hogy az érték numerikus-e, és egyébként kivételeket vet fel. Mivel a bemenet Any, biztonságosan eltávolíthatjuk az összes cast to Double utasítást a kód többi részéből. Ezen túlmenően joggal feltételezhető, hogy mivel az ALS-hez egész számok tartományon belüli azonosítók szükségesek, az emberek többsége valójában egész típusú típusokat használ. Ennek eredményeként a 3. sorban ez a módszer kifejezetten az egész számokat kezeli, hogy elkerülje az öntést. Minden más numerikus értéknél ellenőrzi, hogy a bemenet egész számok tartományon belül van-e. Ez az ellenőrzés a 7-es sorban történik.

Lehet ezt másként is írni, és kifejezetten kezelni az összes engedélyezett típust. Sajnos ez duplikált kódhoz vezet. Ehelyett azt csinálom, hogy a számot egész számmá alakítom, és összehasonlítom az eredeti számmal. Ha az értékek azonosak, a következők egyike igaz:

  1. Az érték Byte vagy Short.
  2. Az érték Long, de az Integer tartományon belül van.
  3. Az érték Dupla vagy Float, de törtrész nélkül.

A kód megfelelő működésének biztosítása érdekében a Spark szabványos egységtesztjeivel és manuálisan teszteltem a módszer viselkedését különféle legális és illegális értékek tekintetében. Annak érdekében, hogy a megoldás legalább olyan gyors legyen, mint az eredeti, többször teszteltem az alábbi részlet segítségével. Ez elhelyezhető a ALSUite osztály a Sparkban:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Néhány teszt után láthatjuk, hogy az új javítás valamivel gyorsabb, mint az eredeti:

Kód

Futások száma

Teljes végrehajtási idő

Futásonkénti átlagos végrehajtási idő

eredeti 100 Ötvenes évek Ötvenes évek
Rögzített 100 Ötvenes évek Ötvenes évek

A kísérleteket többször megismételtem, hogy megerősítsem, és az eredmények konzisztensek. Itt megtalálhatja egy kísérlet részletes kimenetét a eredeti kód és a erősít. A különbség kicsi egy kis adathalmazhoz képest, de a múltban ezzel a javítással sikerült észrevehetően csökkenteni a GC többletköltségét. Ezt úgy tudjuk megerősíteni, ha helyileg futtatjuk a Sparkot, és csatolunk egy Java-profilozót a Spark-példányhoz. Kinyitottam a jegy és egy Pull-Request a hivatalos Spark repón de mivel bizonytalan, hogy összevonják-e, gondoltam megosztom itt veletek és most a Spark 2.2 része.

Bármilyen gondolatot, véleményt, kritikát szívesen fogadunk! 🙂

Időbélyeg:

Még több Datumbox