Vrtanje v Sparkov algoritem priporočila ALS PlatoBlockchain Data Intelligence. Navpično iskanje. Ai.

Vrtanje v algoritem priporočila ALS Spark

Algoritem ALS, ki ga je predstavil Hu in sod., je zelo priljubljena tehnika, ki se uporablja pri težavah s sistemom Recommender, zlasti kadar imamo implicitne nize podatkov (na primer kliki, všečki itd.). Razmerno dobro lahko obravnava velike količine podatkov in najdemo lahko veliko dobrih izvedb v različnih okvirih strojnega učenja. Spark vključuje algoritem v komponenti MLlib, ki je bil pred kratkim preoblikovan za izboljšanje berljivosti in arhitekture kode.

Implementacija Spark zahteva, da sta ID predmeta in uporabnika številki v obsegu celih števil (bodisi tipa Integer ali Long znotraj obsega celih števil), kar je razumno, saj lahko to pomaga pospešiti operacije in zmanjša porabo pomnilnika. Ena stvar, ki sem jo opazil med branjem kode, je, da se ti stolpci z id-ji na začetku metod prileganja/predvidevanja uvrščajo v Doubles in nato v Integers. To se zdi malce hecno in videl sem, da po nepotrebnem obremenjuje zbiralnik smeti. Tukaj so vrstice na koda ALS ki ID-je pretvori v dvojnike:
Vrtanje v Sparkov algoritem priporočila ALS PlatoBlockchain Data Intelligence. Navpično iskanje. Ai.
Vrtanje v Sparkov algoritem priporočila ALS PlatoBlockchain Data Intelligence. Navpično iskanje. Ai.

Če želite razumeti, zakaj je to storjeno, morate prebrati checkedCast():
Vrtanje v Sparkov algoritem priporočila ALS PlatoBlockchain Data Intelligence. Navpično iskanje. Ai.

Ta UDF prejme Double in preveri njegov obseg ter ga nato pretvori v celo število. Ta UDF se uporablja za preverjanje sheme. Vprašanje je, ali lahko to dosežemo brez uporabe grdih dvojnih odlitkov? Verjamem, da ja:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Zgornja koda prikazuje spremenjen checkedCast(), ki prejme vhod, preveri, ali je vrednost številska, in sproži izjeme v nasprotnem primeru. Ker je vnos Any, lahko varno odstranimo vse stavke za pretvorbo v Double iz preostale kode. Poleg tega je razumno pričakovati, da večina ljudi dejansko uporablja celoštevilske tipe, ker ALS zahteva ID-je v območju celih števil. Posledično v vrstici 3 ta metoda izrecno obravnava cela števila, da se izogne ​​kakršnemu koli prelivanju. Za vse druge številske vrednosti preveri, ali je vnos v območju celih števil. To preverjanje se izvede v vrstici 7.

To bi lahko napisali drugače in eksplicitno obravnavali vse dovoljene vrste. Na žalost bi to vodilo do podvojene kode. Namesto tega tukaj pretvorim število v celo število in ga primerjam z izvirnim številom. Če sta vrednosti enaki, velja eno od naslednjega:

  1. Vrednost je Byte ali Short.
  2. Vrednost je Long, vendar znotraj obsega Integer.
  3. Vrednost je Double ali Float, vendar brez ulomka.

Da bi zagotovil, da koda dobro deluje, sem jo preizkusil s standardnimi testi enote Spark in ročno s preverjanjem obnašanja metode za različne zakonite in nezakonite vrednosti. Da bi zagotovil, da je rešitev vsaj tako hitra kot izvirnik, sem večkrat preizkusil s spodnjim delčkom. To lahko postavite v ALSSuite razred v Sparku:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Po nekaj testih lahko vidimo, da je novi popravek nekoliko hitrejši od izvirnika:

Koda

Število tekov

Skupni čas izvedbe

Povprečni čas izvajanja na zagon

prvotni 100 588.458s 5.88458s
Določi 100 566.722s 5.66722s

Poskuse sem večkrat ponovil, da sem potrdil in rezultati so dosledni. Tukaj lahko najdete podroben rezultat enega poskusa za originalna koda in fiksna. Razlika je majhna za majhen nabor podatkov, vendar mi je v preteklosti s tem popravkom uspelo doseči opazno zmanjšanje režijskih stroškov GC. To lahko potrdimo tako, da lokalno zaženemo Spark in na primerek Spark pripnemo orodje za profiliranje Java. Odprl sem a Vstopnica in Pull-Request na uradnem repo Spark a ker ni gotovo, ali bo združeno, sem ga mislil deliti tukaj z vami in zdaj je del Spark 2.2.

Vsakršno razmišljanje, komentar ali kritika je dobrodošla! 🙂

Časovni žig:

Več od Datumbox