التعمق في خوارزمية توصية ALS من Spark وذكاء بيانات PlatoBlockchain. البحث العمودي. منظمة العفو الدولية.

الحفر في خوارزمية توصية ALS الخاصة بشركة Spark

تم تقديم خوارزمية ALS بواسطة هو وآخرون.، هي تقنية شائعة جدًا تستخدم في مشاكل نظام التوصية ، خاصةً عندما يكون لدينا مجموعات بيانات ضمنية (على سبيل المثال النقرات والإعجابات وما إلى ذلك). يمكنه التعامل مع كميات كبيرة من البيانات بشكل معقول ويمكننا العثور على العديد من التطبيقات الجيدة في مختلف أطر تعلم الآلة. يتضمن Spark الخوارزمية في مكون MLlib الذي تم إعادة تشكيله مؤخرًا لتحسين قابلية القراءة وبنية الكود.

يتطلب تطبيق Spark أن يكون العنصر ومعرف المستخدم أرقامًا ضمن نطاق عدد صحيح (إما نوع صحيح أو طويل ضمن نطاق عدد صحيح) ، وهو أمر معقول لأن هذا يمكن أن يساعد في تسريع العمليات وتقليل استهلاك الذاكرة. أحد الأشياء التي لاحظتها أثناء قراءة الكود هو أن أعمدة المعرف هذه يتم صبها في الزوجي ثم في الأعداد الصحيحة في بداية طرق الملاءمة / التنبؤ. يبدو هذا صعبًا بعض الشيء وقد رأيت أنه يضع ضغطًا غير ضروري على جامع القمامة. فيما يلي الخطوط الموجودة على كود ALS التي تحولت الهويات إلى أزواج:
التعمق في خوارزمية توصية ALS من Spark وذكاء بيانات PlatoBlockchain. البحث العمودي. منظمة العفو الدولية.
التعمق في خوارزمية توصية ALS من Spark وذكاء بيانات PlatoBlockchain. البحث العمودي. منظمة العفو الدولية.

لفهم سبب القيام بذلك ، يحتاج المرء إلى قراءة checkCast ():
التعمق في خوارزمية توصية ALS من Spark وذكاء بيانات PlatoBlockchain. البحث العمودي. منظمة العفو الدولية.

يتلقى هذا UDF مزدوجًا ويتحقق من نطاقه ثم يلقي به إلى عدد صحيح. يستخدم هذا UDF للتحقق من صحة المخطط. السؤال هو هل يمكننا تحقيق ذلك بدون استخدام المسبوكات المزدوجة القبيحة؟ أعتقد نعم:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

يُظهر الكود أعلاه checkCast () المعدل الذي يتلقى المدخلات ، ويؤكد الشيكات أن القيمة رقمية ويثير الاستثناءات بخلاف ذلك. نظرًا لأن الإدخال هو أي ، يمكننا بأمان إزالة جميع المصبوبات إلى العبارات المزدوجة من بقية الكود. علاوة على ذلك ، من المعقول أن نتوقع أنه نظرًا لأن ALS يتطلب معرفات داخل نطاق عدد صحيح ، فإن غالبية الأشخاص يستخدمون أنواعًا صحيحة بالفعل. نتيجة لذلك في السطر 3 ، تتعامل هذه الطريقة مع الأعداد الصحيحة بشكل صريح لتجنب القيام بأي عملية صب. بالنسبة لجميع القيم الرقمية الأخرى ، يتحقق ما إذا كان الإدخال ضمن نطاق عدد صحيح. يتم إجراء هذا الفحص في السطر 7.

يمكن للمرء أن يكتب هذا بشكل مختلف ويتعامل صراحة مع جميع الأنواع المسموح بها. لسوء الحظ ، سيؤدي هذا إلى رمز مكرر. بدلاً من ذلك ، ما أفعله هنا هو تحويل الرقم إلى عدد صحيح ومقارنته بالرقم الأصلي. إذا كانت القيم متطابقة ، فإن أحد الخيارات التالية يكون صحيحًا:

  1. القيمة هي بايت أو قصير.
  2. القيمة طويلة ولكن ضمن نطاق عدد صحيح.
  3. القيمة مزدوجة أو عائمة ولكن بدون أي جزء كسري.

للتأكد من أن الكود يعمل بشكل جيد ، قمت باختباره باستخدام اختبارات الوحدة القياسية لـ Spark يدويًا عن طريق التحقق من سلوك الطريقة لمختلف القيم القانونية وغير القانونية. للتأكد من أن الحل يكون على الأقل بنفس سرعة الحل الأصلي ، اختبرت عدة مرات باستخدام المقتطف أدناه. يمكن وضع هذا في ملف فئة ALSSuite في سبارك:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

بعد عدة اختبارات يمكننا أن نرى أن الإصلاح الجديد أسرع قليلاً من الإصلاح الأصلي:

رمز

عدد الأشواط

إجمالي وقت التنفيذ

متوسط ​​وقت التنفيذ لكل تشغيل

أصلي 100 588.458s 5.88458s
ثابت 100 566.722s 5.66722s

كررت التجارب عدة مرات للتأكيد وكانت النتائج متسقة. هنا يمكنك العثور على المخرجات التفصيلية لتجربة واحدة لـ الكود الأصلي و حل. الفرق صغير بالنسبة لمجموعة بيانات صغيرة ولكن في الماضي تمكنت من تحقيق انخفاض ملحوظ في حمل GC باستخدام هذا الإصلاح. يمكننا تأكيد ذلك عن طريق تشغيل Spark محليًا وإرفاق ملف تعريف Java في مثيل Spark. فتحت أ تذكرة و طلب سحب على Spark repo الرسمي ولكن نظرًا لأنه من غير المؤكد ما إذا كان سيتم دمجه ، فقد فكرت في مشاركته معك هنا وهو الآن جزء من Spark 2.2.

نرحب بأي أفكار أو تعليقات أو نقد! 🙂

الطابع الزمني:

اكثر من داتومبوكس