Γεωτρήσεις στον αλγόριθμο σύστασης ALS του Spark, PlatoBlockchain Data Intelligence. Κάθετη αναζήτηση. Ολα συμπεριλαμβάνονται.

Διάτρηση στον αλγόριθμο Συστάσεων ALS του Spark

Ο αλγόριθμος ALS που εισήγαγε ο Hu et al., είναι μια πολύ δημοφιλής τεχνική που χρησιμοποιείται σε προβλήματα συστημάτων προτάσεων, ειδικά όταν έχουμε σιωπηρά σύνολα δεδομένων (για παράδειγμα κλικ, συμπάθειες κ.λπ.). Μπορεί να χειριστεί αρκετά καλά μεγάλο όγκο δεδομένων και μπορούμε να βρούμε πολλές καλές εφαρμογές σε διάφορα πλαίσια μηχανικής μάθησης. Το Spark περιλαμβάνει τον αλγόριθμο στο στοιχείο MLlib, ο οποίος έχει αναδιαμορφωθεί πρόσφατα για να βελτιώσει την αναγνωσιμότητα και την αρχιτεκτονική του κώδικα.

Η εφαρμογή του Spark απαιτεί το Item και το User ID να είναι αριθμοί εντός ακέραιου εύρους (είτε τύπου Integer είτε Long εντός ακέραιου εύρους), κάτι που είναι λογικό καθώς αυτό μπορεί να βοηθήσει στην επιτάχυνση των λειτουργιών και στη μείωση της κατανάλωσης μνήμης. Ένα πράγμα που παρατήρησα όμως κατά την ανάγνωση του κώδικα είναι ότι αυτές οι στήλες ταυτότητας μεταφέρονται σε διπλά και στη συνέχεια σε ακέραιους αριθμούς στην αρχή των μεθόδων προσαρμογής/πρόβλεψης. Αυτό φαίνεται λίγο χακαρό και το έχω δει να ασκεί περιττή πίεση στον συλλέκτη σκουπιδιών. Εδώ είναι οι γραμμές στο Κωδικός ALS που ρίχνει τα αναγνωριστικά σε διπλά:
Γεωτρήσεις στον αλγόριθμο σύστασης ALS του Spark, PlatoBlockchain Data Intelligence. Κάθετη αναζήτηση. Ολα συμπεριλαμβάνονται.
Γεωτρήσεις στον αλγόριθμο σύστασης ALS του Spark, PlatoBlockchain Data Intelligence. Κάθετη αναζήτηση. Ολα συμπεριλαμβάνονται.

Για να καταλάβετε γιατί γίνεται αυτό, πρέπει να διαβάσετε το checkCast ():
Γεωτρήσεις στον αλγόριθμο σύστασης ALS του Spark, PlatoBlockchain Data Intelligence. Κάθετη αναζήτηση. Ολα συμπεριλαμβάνονται.

Αυτό το UDF λαμβάνει ένα διπλό και ελέγχει το εύρος του και στη συνέχεια το μεταφέρει σε ακέραιο. Αυτό το UDF χρησιμοποιείται για την επικύρωση σχήματος. Το ερώτημα είναι αν μπορούμε να το πετύχουμε αυτό χωρίς τη χρήση άσχημων διπλών χυτών; Πιστεύω ναι:

  protected val checkedCast = udf { (n: Any) =>
    n match {
      case v: Int => v // Avoid unnecessary casting
      case v: Number =>
        val intV = v.intValue()
        // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
        if (v.doubleValue == intV) {
          intV
        }
        else {
          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
        }
      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
    }
  }

Ο παραπάνω κώδικας δείχνει ένα τροποποιημένο checkCast () που λαμβάνει την είσοδο, ελέγχει ότι η τιμή είναι αριθμητική και εγείρει εξαιρέσεις διαφορετικά. Δεδομένου ότι η είσοδος είναι Οποιαδήποτε, μπορούμε να αφαιρέσουμε με ασφάλεια όλο το cast στο Double statement από τον υπόλοιπο κώδικα. Επιπλέον, είναι λογικό να αναμένεται ότι δεδομένου ότι το ALS απαιτεί αναγνωριστικά εντός ακέραιου εύρους, η πλειοψηφία των ανθρώπων χρησιμοποιεί πραγματικά ακέραιους τύπους. Ως αποτέλεσμα στη γραμμή 3, αυτή η μέθοδος χειρίζεται ακέραιους αριθμούς για να αποφύγει τη μετάδοση. Για όλες τις άλλες αριθμητικές τιμές ελέγχει αν η είσοδος βρίσκεται εντός ακέραιου εύρους. Αυτός ο έλεγχος γίνεται στη γραμμή 7.

Θα μπορούσε κανείς να το γράψει διαφορετικά και να χειριστεί ρητά όλους τους επιτρεπόμενους τύπους. Δυστυχώς, αυτό θα οδηγούσε σε διπλό κώδικα. Αντ 'αυτού, αυτό που κάνω εδώ είναι να μετατρέψω τον αριθμό σε ακέραιο και να τον συγκρίνω με τον αρχικό αριθμό. Εάν οι τιμές είναι πανομοιότυπες, ισχύει μία από τις ακόλουθες:

  1. Η τιμή είναι Byte ή Short.
  2. Η τιμή είναι Long αλλά εντός του εύρους του ακέραιου.
  3. Η τιμή είναι Double ή Float αλλά χωρίς κλασματικό μέρος.

Για να διασφαλίσω ότι ο κώδικας λειτουργεί καλά, τον δοκίμασα με τις τυπικές μονάδες-δοκιμές του Spark και χειροκίνητα ελέγχοντας τη συμπεριφορά της μεθόδου για διάφορες νόμιμες και παράνομες τιμές. Για να διασφαλίσω ότι η λύση είναι τουλάχιστον τόσο γρήγορη όσο η αρχική, δοκίμασα πολλές φορές χρησιμοποιώντας το παρακάτω απόσπασμα. Αυτό μπορεί να τοποθετηθεί στο ALSSuite τάξη στο Spark:

  test("Speed difference") {
    val (training, test) =
      genExplicitTestData(numUsers = 200, numItems = 400, rank = 2, noiseStd = 0.01)

    val runs = 100
    var totalTime = 0.0
    println("Performing "+runs+" runs")
    for(i <- 0 until runs) {
      val t0 = System.currentTimeMillis
      testALS(training, test, maxIter = 1, rank = 2, regParam = 0.01, targetRMSE = 0.1)
      val secs = (System.currentTimeMillis - t0)/1000.0
      println("Run "+i+" executed in "+secs+"s")
      totalTime += secs
    }
    println("AVG Execution Time: "+(totalTime/runs)+"s")

  }

Μετά από μερικές δοκιμές μπορούμε να δούμε ότι η νέα επιδιόρθωση είναι ελαφρώς ταχύτερη από την αρχική:

Κώδικας

Αριθμός εκτελέσεων

Συνολικός χρόνος εκτέλεσης

Μέσος Χρόνος Εκτέλεσης ανά Εκτέλεση

Πρωτότυπο 100 588.458s 5.88458s
Σταθερό 100 566.722s 5.66722s

Επανέλαβα τα πειράματα πολλές φορές για επιβεβαίωση και τα αποτελέσματα είναι συνεπή. Εδώ μπορείτε να βρείτε τη λεπτομερή έξοδο ενός πειράματος για το αρχικός κώδικας και την σταθερόςΤο Η διαφορά είναι μικρή για ένα μικροσκοπικό σύνολο δεδομένων, αλλά στο παρελθόν κατάφερα να επιτύχω μια αισθητή μείωση στα γενικά έξοδα GC χρησιμοποιώντας αυτήν την επιδιόρθωση. Μπορούμε να το επιβεβαιώσουμε εκτελώντας το Spark τοπικά και επισυνάπτοντας έναν προφίλ Java στην παρουσία Spark. Άνοιξα ένα εισιτήριο και σε έναν Τράβηγμα-Αίτημα στο επίσημο repo του Spark αλλά επειδή είναι αβέβαιο αν θα συγχωνευτεί, σκέφτηκα να το μοιραστώ εδώ μαζί σας και είναι πλέον μέρος του Spark 2.2.

Οποιαδήποτε σκέψη, σχόλιο ή κριτική είναι ευπρόσδεκτη! 🙂

Σφραγίδα ώρας:

Περισσότερα από Databox