Zagłębianie się w algorytm rekomendacji ALS Sparka PlatoBlockchain Data Intelligence. Wyszukiwanie pionowe. AI.

Wiercenie w algorytmie rekomendacji ALS Spark'a

Algorytm ALS wprowadzony przez Hu i in., jest bardzo popularną techniką stosowaną w problemach z systemem polecającym, zwłaszcza gdy mamy niejawne zbiory danych (na przykład kliknięcia, polubienia itp.). Dość dobrze radzi sobie z dużymi ilościami danych i możemy znaleźć wiele dobrych implementacji w różnych strukturach uczenia maszynowego. Spark zawiera algorytm w składniku MLlib, który został niedawno refaktoryzowany w celu poprawy czytelności i architektury kodu.

Implementacja platformy Spark wymaga, aby identyfikator elementu i użytkownika były liczbami z zakresu liczb całkowitych (typu Integer lub Long w zakresie liczb całkowitych), co jest rozsądne, ponieważ może to przyspieszyć operacje i zmniejszyć zużycie pamięci. Jedną rzeczą, którą zauważyłem podczas czytania kodu, jest to, że te kolumny id są rzutowane na podwójne, a następnie na liczby całkowite na początku metody fit / Predict. Wydaje się to trochę dziwaczne i widziałem, jak niepotrzebnie obciąża śmieciarkę. Oto linie na Kod ALS które rzucają identyfikatory na podwójne:
Zagłębianie się w algorytm rekomendacji ALS Sparka PlatoBlockchain Data Intelligence. Wyszukiwanie pionowe. AI.
Zagłębianie się w algorytm rekomendacji ALS Sparka PlatoBlockchain Data Intelligence. Wyszukiwanie pionowe. AI.

Aby zrozumieć, dlaczego tak się dzieje, należy przeczytać checkCast ():
Zagłębianie się w algorytm rekomendacji ALS Sparka PlatoBlockchain Data Intelligence. Wyszukiwanie pionowe. AI.

Ten UDF otrzymuje Double i sprawdza jego zakres, a następnie rzutuje go na liczbę całkowitą. Ten UDF jest używany do sprawdzania poprawności schematu. Pytanie brzmi, czy możemy to osiągnąć bez stosowania brzydkich podwójnych odlewów? Wierzę, że tak:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Powyższy kod przedstawia zmodyfikowaną metodę checkCast (), która otrzymuje dane wejściowe, sprawdza, czy wartość jest numeryczna, aw przeciwnym razie zgłasza wyjątki. Ponieważ dane wejściowe to Any, możemy bezpiecznie usunąć całe rzutowanie na instrukcje Double z reszty kodu. Ponadto rozsądnie jest oczekiwać, że skoro ALS wymaga identyfikatorów z zakresu liczb całkowitych, większość ludzi faktycznie używa typów całkowitych. W rezultacie w linii 3 ta metoda obsługuje liczby całkowite jawnie, aby uniknąć wykonywania rzutowania. Dla wszystkich innych wartości liczbowych sprawdza, czy dane wejściowe mieszczą się w zakresie liczb całkowitych. To sprawdzenie odbywa się w linii 7.

Można to napisać inaczej i jawnie obsłużyć wszystkie dozwolone typy. Niestety doprowadziłoby to do zduplikowania kodu. Zamiast tego robię tutaj konwersję liczby na liczbę całkowitą i porównywanie jej z oryginalną liczbą. Jeśli wartości są identyczne, spełniony jest jeden z poniższych warunków:

  1. Wartość to Byte lub Short.
  2. Wartość jest długa, ale mieści się w zakresie liczb całkowitych.
  3. Wartość to Double lub Float, ale bez części ułamkowej.

Aby upewnić się, że kod działa dobrze, przetestowałem go za pomocą standardowych testów jednostkowych Spark i ręcznie, sprawdzając zachowanie metody dla różnych legalnych i nielegalnych wartości. Aby upewnić się, że rozwiązanie jest co najmniej tak szybkie jak oryginał, testowałem wiele razy, używając poniższego fragmentu kodu. Można to umieścić w Klasa ALSSuite w Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Po kilku testach widać, że nowa poprawka jest nieco szybsza od oryginału:

Code

Liczba uruchomień

Całkowity czas wykonania

Średni czas wykonania na przebieg

Oryginalny 100 588.458s 5.88458s
Stały 100 566.722s 5.66722s

Powtarzałem eksperymenty wiele razy, aby potwierdzić, a wyniki są zgodne. Tutaj można znaleźć szczegółowe wyniki jednego eksperymentu dla oryginalny kod oraz stały. Różnica jest niewielka dla małego zestawu danych, ale w przeszłości udało mi się osiągnąć zauważalne zmniejszenie narzutu GC przy użyciu tej poprawki. Możemy to potwierdzić, uruchamiając Spark lokalnie i dołączając profiler Java do wystąpienia Spark. Otworzyłem bilet oraz Prośba o pociągnięcie w oficjalnym repozytorium Sparka ale ponieważ nie jest pewne, czy zostanie połączone, pomyślałem, że podzielę się nim z wami i jest teraz częścią Spark 2.2.

Wszelkie uwagi, komentarze lub krytyka są mile widziane! 🙂

Znak czasu:

Więcej z Skrzynka odniesienia