Детализация алгоритма рекомендаций ALS от Spark PlatoBlockchain Data Intelligence. Вертикальный поиск. Ай.

Детализация алгоритма ALS Рекомендации Spark

Алгоритм ALS, представленный Ху и соавт., это очень популярный метод, используемый в проблемах с системой Recommender System, особенно когда у нас есть неявные наборы данных (например, клики, лайки и т. д.). Он может достаточно хорошо обрабатывать большие объемы данных, и мы можем найти много хороших реализаций в различных системах машинного обучения. Spark включает алгоритм в компонент MLlib, который недавно был подвергнут рефакторингу для улучшения читабельности и архитектуры кода.

Реализация Spark требует, чтобы Item и User id были числами в целочисленном диапазоне (либо целочисленного типа, либо Long в целочисленном диапазоне), что является разумным, поскольку это может помочь ускорить операции и уменьшить потребление памяти. Однако при чтении кода я заметил одну вещь: эти столбцы идентификаторов преобразуются в Doubles, а затем в Integer в начале методов соответствия / предсказания. Это кажется немного странным, и я видел, как это создает ненужную нагрузку на сборщик мусора. Вот строки на Код ALS которые приводят идентификаторы в двойники:
Детализация алгоритма рекомендаций ALS от Spark PlatoBlockchain Data Intelligence. Вертикальный поиск. Ай.
Детализация алгоритма рекомендаций ALS от Spark PlatoBlockchain Data Intelligence. Вертикальный поиск. Ай.

Чтобы понять, почему это делается, нужно прочитать checkCast ():
Детализация алгоритма рекомендаций ALS от Spark PlatoBlockchain Data Intelligence. Вертикальный поиск. Ай.

Этот UDF получает Double и проверяет его диапазон, а затем переводит его в целое число. Этот UDF используется для проверки схемы. Вопрос в том, можем ли мы достичь этого, не используя уродливые двойные отливки? Я верю да:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Приведенный выше код показывает измененный selectedCast (), который получает входные данные, проверяет, что значение является числовым, и в противном случае вызывает исключения. Поскольку входное значение Any, мы можем безопасно удалить все операторы приведения к Double из остальной части кода. Более того, разумно ожидать, что, поскольку ALS требует идентификаторы в диапазоне целых чисел, большинство людей фактически используют целочисленные типы. В результате в строке 3 этот метод явно обрабатывает целые числа, чтобы избежать приведения. Для всех других числовых значений он проверяет, находится ли ввод в целочисленном диапазоне. Эта проверка происходит в строке 7.

Можно написать это по-другому и явно обрабатывать все разрешенные типы. К сожалению, это приведет к дублированию кода. Вместо этого я конвертирую число в целое число и сравниваю его с исходным числом. Если значения идентичны, одно из следующих условий:

  1. Значение является байтовым или коротким.
  2. Значение Long, но в диапазоне Integer.
  3. Значение Double или Float, но без дробной части.

Чтобы убедиться, что код работает хорошо, я протестировал его с помощью стандартных модульных тестов Spark и вручную, проверив поведение метода на наличие различных допустимых и недопустимых значений. Чтобы убедиться, что решение, по крайней мере, так же быстро, как и оригинал, я много раз тестировал, используя приведенный ниже фрагмент кода. Это можно поместить в ALSSuite класс в Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

После нескольких тестов мы видим, что новое исправление немного быстрее оригинального:

Code

Количество прогонов

Общее время выполнения

Среднее время выполнения за прогон

Оригинал 100 588.458s 5.88458s
Исправлена 100 566.722s 5.66722s

Я повторил эксперименты несколько раз, чтобы подтвердить, и результаты согласуются. Здесь вы можете найти подробный результат одного эксперимента для исходный код и фиксировать, Разница небольшая для небольшого набора данных, но в прошлом мне удалось добиться заметного снижения накладных расходов GC с помощью этого исправления. Мы можем подтвердить это, запустив Spark локально и подключив Java-профилировщик к экземпляру Spark. Я открыл билет и еще один Pull-запрос на официальном репо Spark но так как неясно, будет ли он объединен, я решил поделиться этим здесь с вами и теперь это часть Spark 2.2.

Любые мысли, комментарии или критика приветствуются! 🙂

Отметка времени:

Больше от Датумбокс