Перегляд алгоритму Spark ALS Recommendation PlatoBlockchain Data Intelligence. Вертикальний пошук. Ai.

Буріння в алгоритмі рекомендацій ALS від Spark

Алгоритм ALS, введений в Ху та ін., є дуже популярним методом, який використовується в проблемах системи рекомендацій, особливо коли ми маємо неявні набори даних (наприклад, кліки, лайки тощо). Він може досить добре обробляти великі обсяги даних, і ми можемо знайти багато хороших реалізацій у різних платформах машинного навчання. Spark містить алгоритм у компоненті MLlib, який нещодавно був перероблений для покращення читабельності та архітектури коду.

Реалізація Spark вимагає, щоб ідентифікатор елемента та користувача були числами в межах цілого діапазону (або типу Integer, або Long у межах цілого діапазону), що є розумним, оскільки це може допомогти прискорити операції та зменшити споживання пам’яті. Одне, що я помітив, коли читав код, це те, що ці стовпці id перетворюються на подвійні, а потім у цілі числа на початку методів fit/predict. Це здається трохи хакерським, і я бачив, що це створює непотрібне навантаження на збирач сміття. Ось рядки на ALS код що перетворює ідентифікатори на подвійні:
Перегляд алгоритму Spark ALS Recommendation PlatoBlockchain Data Intelligence. Вертикальний пошук. Ai.
Перегляд алгоритму Spark ALS Recommendation PlatoBlockchain Data Intelligence. Вертикальний пошук. Ai.

Щоб зрозуміти, чому це робиться, потрібно прочитати checkedCast():
Перегляд алгоритму Spark ALS Recommendation PlatoBlockchain Data Intelligence. Вертикальний пошук. Ai.

Цей UDF отримує Double і перевіряє його діапазон, а потім перетворює його на ціле число. Цей UDF використовується для перевірки схеми. Питання в тому, чи можемо ми досягти цього без використання потворних подвійних відливок? Я вірю, що так:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Наведений вище код показує змінений checkedCast(), який отримує вхідні дані, перевіряє, що значення є числовим, і створює винятки в іншому випадку. Оскільки вхід є будь-яким, ми можемо безпечно видалити всі оператори приведення до Double з решти коду. Крім того, розумно очікувати, що оскільки ALS вимагає ідентифікаторів у межах цілого діапазону, більшість людей насправді використовують цілі типи. У результаті в рядку 3 цей метод обробляє цілі числа явно, щоб уникнути приведення. Для всіх інших числових значень він перевіряє, чи введені дані знаходяться в діапазоні цілих чисел. Ця перевірка відбувається в рядку 7.

Можна написати це по-іншому і явно обробити всі дозволені типи. На жаль, це призведе до дублювання коду. Замість цього я перетворюю число в ціле число і порівнюю його з вихідним числом. Якщо значення ідентичні, вірно одне з наступного:

  1. Значення — Byte або Short.
  2. Значення — Long, але в межах цілого діапазону.
  3. Значення має значення Double або Float, але без дробової частини.

Щоб переконатися, що код працює добре, я перевірив його за допомогою стандартних модульних тестів Spark і вручну, перевіривши поведінку методу для різних легальних і незаконних значень. Щоб переконатися, що рішення є принаймні таким же швидким, як оригінал, я багато разів тестував, використовуючи фрагмент нижче. Це можна помістити в Клас ALSSuite в Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Після кількох тестів ми бачимо, що нове виправлення трохи швидше, ніж оригінал:

код

Кількість пробігів

Загальний час виконання

Середній час виконання за запуск

Оригінал 100 588.458s 5.88458s
Виправлено 100 566.722s 5.66722s

Я повторив експерименти кілька разів, щоб підтвердити, і результати були послідовними. Тут ви можете знайти детальний результат одного експерименту для оригінальний код і фіксувати. Різниця невелика для крихітного набору даних, але в минулому мені вдавалося досягти помітного зменшення накладних витрат GC за допомогою цього виправлення. Ми можемо підтвердити це, запустивши Spark локально та підключивши профайлер Java до екземпляра Spark. Я відкрив а квиток і Pull-Request на офіційному репо Spark але оскільки невідомо, чи буде він об’єднаний, я вирішив поділитися цим з вами і тепер він є частиною Spark 2.2.

Будь-які думки, коментарі чи критика вітаються! 🙂

Часова мітка:

Більше від Датабокс