- 2017 年 2 月 25 日
- 瓦西利斯·弗里尼奥提斯(Vasilis Vryniotis)
- 。 3条留言
ALS算法由 Hu等。, 是推荐系统问题中使用的一种非常流行的技术,尤其是当我们有隐式数据集时(例如点击、喜欢等)。 它可以很好地处理大量数据,我们可以在各种机器学习框架中找到许多好的实现。 Spark 将算法包含在 MLlib 组件中,该组件最近已被重构以提高代码的可读性和架构。
Spark 的实现要求 Item 和 User id 为整数范围内的数字(整数类型或整数范围内的 Long),这是合理的,因为这有助于加快操作并减少内存消耗。 我在阅读代码时注意到的一件事是,这些 id 列在 fit/predict 方法的开头被转换为 Doubles 然后转换为 Integers。 这似乎有点hacky,我已经看到它给垃圾收集器带来了不必要的压力。 这是上面的行 ALS 代码 将ID转换为双打:
要了解这样做的原因,需要阅读 checkedCast():
此 UDF 接收 Double 并检查其范围,然后将其转换为整数。 此 UDF 用于模式验证。 问题是我们可以在不使用丑陋的双重铸件的情况下实现这一目标吗? 我相信是的:
protected val checkedCast = udf { (n: Any) => n match { case v: Int => v // Avoid unnecessary casting case v: Number => val intV = v.intValue() // True for Byte/Short, Long within the Int range and Double/Float with no fractional part. if (v.doubleValue == intV) { intV } else { throw new IllegalArgumentException(s"ALS only supports values in Integer range " + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") } case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.") } }
上面的代码显示了一个修改后的checkedCast(),它接收输入,检查断言该值是数字,否则引发异常。 由于输入是 Any,我们可以安全地从其余代码中删除所有强制转换为 Double 的语句。 此外,可以合理地预期,由于 ALS 需要整数范围内的 id,因此大多数人实际上使用整数类型。 因此,在第 3 行,此方法显式处理整数以避免进行任何强制转换。 对于所有其他数值,它会检查输入是否在整数范围内。 此检查发生在第 7 行。
可以用不同的方式编写它并显式处理所有允许的类型。 不幸的是,这会导致重复代码。 相反,我在这里所做的是将数字转换为整数并将其与原始数字进行比较。 如果值相同,则以下情况之一为真:
- 该值为字节或短。
- 该值为 Long 但在 Integer 范围内。
- 该值是 Double 或 Float,但没有任何小数部分。
为了确保代码运行良好,我使用 Spark 的标准单元测试并通过检查方法的行为来手动检查各种合法和非法值的行为。 为了确保解决方案至少与原始解决方案一样快,我使用下面的代码片段进行了多次测试。 这个可以放在 ALSSuite 类 在Spark中:
test("Speed difference") { val (training, test) = genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01) val runs = 100 var totalTime = 0.0 println("Performing "+runs+" runs") for(i <- 0 until runs) { val t0 = System.currentTimeMillis testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1) val secs = (System.currentTimeMillis - t0)/1000.0 println("Run "+i+" executed in "+secs+"s") totalTime += secs } println("AVG Execution Time: "+(totalTime/runs)+"s") }
经过几次测试,我们可以看到新的修复比原来的要快一些:
代码 |
运行次数 |
总执行时间 |
每次运行的平均执行时间 |
原版 | 100 | 588.458s | 5.88458s |
固定 | 100 | 566.722s | 5.66722s |
我多次重复实验以确认,结果是一致的。 在这里你可以找到一个实验的详细输出 原始代码 和 固定. 对于一个很小的数据集来说,差异很小,但在过去,我已经成功地使用这个修复实现了 GC 开销的显着减少。 我们可以通过在本地运行 Spark 并在 Spark 实例上附加 Java 分析器来确认这一点。 我开了一个 票 的网络 拉取请求 在官方 Spark 存储库中 但是因为不确定会不会合并,所以想在这里分享给大家 它现在是 Spark 2.2 的一部分。
欢迎任何想法,评论或批评! 🙂