Aprofundando o algoritmo de recomendação ALS do Spark, PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

Analisando o algoritmo de recomendação ALS do Spark

O algoritmo ALS introduzido por Hu et ai., é uma técnica muito popular usada em problemas de sistema de recomendação, especialmente quando temos conjuntos de dados implícitos (por exemplo, cliques, curtidas etc.). Ele pode lidar com grandes volumes de dados razoavelmente bem e podemos encontrar muitas boas implementações em várias estruturas de aprendizado de máquina. O Spark inclui o algoritmo no componente MLlib que foi recentemente refatorado para melhorar a legibilidade e a arquitetura do código.

A implementação do Spark requer que o item e a id do usuário sejam números dentro do intervalo inteiro (seja do tipo inteiro ou longo dentro do intervalo inteiro), o que é razoável, pois pode ajudar a acelerar as operações e reduzir o consumo de memória. Uma coisa que notei enquanto lia o código é que essas colunas de id estão sendo convertidas em Doubles e, em seguida, em Integers no início dos métodos de ajuste / previsão. Isso parece um pouco maluco e já vi colocar uma pressão desnecessária no coletor de lixo. Aqui estão as linhas no Código ALS que transforma os ids em duplos:
Aprofundando o algoritmo de recomendação ALS do Spark, PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.
Aprofundando o algoritmo de recomendação ALS do Spark, PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

Para entender por que isso é feito, é necessário ler o checkCast ():
Aprofundando o algoritmo de recomendação ALS do Spark, PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

Este UDF recebe um Double e verifica seu intervalo e então o converte em um inteiro. Este UDF é usado para validação de esquema. A questão é podemos conseguir isso sem usar dupla fundição feia? Eu acredito que sim:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

O código acima mostra um CHECKCast () modificado que recebe a entrada, verifica se o valor é numérico e levanta exceções caso contrário. Como a entrada é Any, podemos remover com segurança todo o elenco para instruções Double do resto do código. Além disso, é razoável esperar que, uma vez que o ALS requer ids dentro do intervalo de inteiros, a maioria das pessoas realmente usa tipos inteiros. Como resultado, na linha 3, esse método lida com inteiros explicitamente para evitar qualquer conversão. Para todos os outros valores numéricos, ele verifica se a entrada está dentro do intervalo inteiro. Essa verificação acontece na linha 7.

Pode-se escrever isso de forma diferente e lidar explicitamente com todos os tipos permitidos. Infelizmente, isso levaria à duplicação do código. Em vez disso, o que faço aqui é converter o número em inteiro e compará-lo com o número original. Se os valores forem idênticos, um dos seguintes será verdadeiro:

  1. O valor é Byte ou Short.
  2. O valor é longo, mas está dentro do intervalo de números inteiros.
  3. O valor é Double ou Float, mas sem nenhuma parte fracionária.

Para garantir que o código funcione bem, eu o testei com os testes de unidade padrão do Spark e manualmente, verificando o comportamento do método para vários valores legais e ilegais. Para garantir que a solução seja pelo menos tão rápida quanto a original, testei várias vezes usando o trecho abaixo. Isso pode ser colocado no Classe ALS Suite no Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Depois de alguns testes, podemos ver que a nova correção é um pouco mais rápida do que a original:

Code

Número de execuções

Tempo Total de Execução

Tempo médio de execução por execução

Óptimo estado. Original 100 588.458s 5.88458s
Fixo 100 566.722s 5.66722s

Repeti os experimentos várias vezes para confirmar e os resultados são consistentes. Aqui você pode encontrar a saída detalhada de um experimento para o código original e os votos de fixo. A diferença é pequena para um conjunto de dados minúsculo, mas no passado consegui atingir uma redução perceptível na sobrecarga de GC usando essa correção. Podemos confirmar isso executando o Spark localmente e anexando um criador de perfil Java na instância do Spark. Eu abri um bilhete e de um Solicitação de envio no repositório oficial do Spark mas porque não há certeza se será mesclado, pensei em compartilhá-lo aqui com você e agora faz parte do Spark 2.2.

Quaisquer pensamentos, comentários ou críticas são bem-vindos! 🙂

Carimbo de hora:

Mais de Caixa de dados