Why does the Spark ML ALS algorithm print RMSE = NaN?

I am using ALS to predict ratings. This is my code:

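import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS

val als = new ALS()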
  .setMaxIter(5)
  .setRegParam(0.01)
  .setUserCol("user_id")
  .setItemCol("business_id")
  .setRatingCol("stars")
val model = als.fit(training)

// Evaluate the model by computing the RMSE on the test data
val predictions = model.transform(testing)
predictions.sort("user_id").show(1000)
val evaluator = new RegressionEvaluator()
  .setMetricName("rmse")
  .setLabelCol("stars")
  .setPredictionCol("prediction")
val rmse = evaluator.evaluate(predictions)
println(s"Root-mean-square error = $rmse")

      

But I get some negative predictions, and the RMSE is NaN:

+-------+-----------+---------+------------+
|user_id|business_id|    stars|  prediction|
+-------+-----------+---------+------------+
|      0|       2175|      4.0|   4.0388923|
|      0|       5753|      3.0|   2.6875196|
|      0|       9199|      4.0|   4.1753435|
|      0|      16416|      2.0|   -2.710618|
|      0|       6063|      3.0|         NaN|
|      0|      23076|      2.0|  -0.8930751|

Root-mean-square error = NaN

      

How can I get a valid result?

+3
scala machine-learning apache-spark




3 answers


Negative predictions are not the problem, since RMSE squares the errors first. You probably have missing (NaN) prediction values. You can drop them:

predictions.na.drop(Seq("prediction"))
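To see why negative predictions are harmless while a single NaN is not, here is a minimal plain-Scala sketch (no Spark involved). The rmse helper is hypothetical and written only for this illustration; the three rows are copied from the output in the question.

```scala
// Plain Scala (no Spark): RMSE over (label, prediction) pairs.
// `rmse` is a hypothetical helper for illustration only.
def rmse(pairs: Seq[(Double, Double)]): Double =
  math.sqrt(pairs.map { case (label, pred) => math.pow(label - pred, 2) }.sum / pairs.size)

// Three rows copied from the question's output.
val rows = Seq((2.0, -2.710618), (3.0, Double.NaN), (2.0, -0.8930751))

// One NaN prediction makes the whole sum, and therefore the RMSE, NaN.
val withNaN = rmse(rows)

// Negative predictions alone still give a finite RMSE.
val withoutNaN = rmse(rows.filterNot { case (_, pred) => pred.isNaN })

println(s"with NaN: $withNaN, without NaN: $withoutNaN")
```

This is the same effect the evaluator shows: any NaN in the prediction column propagates through the arithmetic.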

      

Alternatively, although it is somewhat crude, you can fill in these values with the lowest, highest, or average rating.
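The fill-in alternative could look roughly like the sketch below. It assumes the column names from the question (stars, prediction) and uses the mean training rating as the replacement; fillNaNWithMean is a hypothetical helper name, not part of Spark.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.avg

// Hypothetical helper: replace null/NaN predictions with the mean training rating.
// Assumes `training` is non-empty and has a numeric "stars" column.
def fillNaNWithMean(predictions: DataFrame, training: DataFrame): DataFrame = {
  val meanRating = training.agg(avg("stars")).first().getDouble(0)
  // na.fill replaces both null and NaN values in the listed numeric columns.
  predictions.na.fill(meanRating, Seq("prediction"))
}
```

Filling keeps every test row in the evaluation, at the cost of scoring cold-start users and items with a constant guess.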



I would also recommend clamping predictions with x < min_rating and x > max_rating to the lowest and highest ratings respectively, which will improve your RMSE.
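A minimal plain-Scala sketch of the clamping idea follows. The 1.0 to 5.0 rating range is an assumption (the question does not state it), and clamp and rmse are hypothetical helpers. The rows are the non-NaN rows from the question's output.

```scala
// Assumed rating scale; adjust to your data.
val minRating = 1.0
val maxRating = 5.0

// Clamp a prediction into [minRating, maxRating].
// math.min/math.max propagate NaN, so NaN predictions stay NaN here.
def clamp(pred: Double): Double = math.max(minRating, math.min(maxRating, pred))

// Hypothetical helper: RMSE over (label, prediction) pairs.
def rmse(pairs: Seq[(Double, Double)]): Double =
  math.sqrt(pairs.map { case (label, pred) => math.pow(label - pred, 2) }.sum / pairs.size)

// Non-NaN rows copied from the question's output.
val rows = Seq(
  (4.0, 4.0388923), (3.0, 2.6875196), (4.0, 4.1753435),
  (2.0, -2.710618), (2.0, -0.8930751)
)
val clamped = rows.map { case (label, pred) => (label, clamp(pred)) }

println(s"raw RMSE: ${rmse(rows)}, clamped RMSE: ${rmse(clamped)}")
```

As long as the true ratings lie inside the assumed range, clamping can only move an out-of-range prediction closer to its label, so the RMSE cannot get worse.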

EDIT:

More info here: https://issues.apache.org/jira/browse/SPARK-14489

+3




As of Spark 2.2.0, you can set the coldStartStrategy parameter to drop (before calling transform) to remove any rows in the predictions DataFrame that contain NaN values. The evaluation metric is then computed over the remaining non-NaN rows and will be valid.



model.setColdStartStrategy("drop")
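The same parameter can also be set on the ALS estimator, so that every model it fits already drops cold-start rows. The sketch below reuses the settings from the question and assumes Spark 2.2.0 or later.

```scala
import org.apache.spark.ml.recommendation.ALS

// Same configuration as in the question, plus the cold-start strategy.
// Supported values are "nan" (the default) and "drop".
val als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.01)
  .setUserCol("user_id")
  .setItemCol("business_id")
  .setRatingCol("stars")
  .setColdStartStrategy("drop") // requires Spark 2.2.0+
```

With this in place, model.transform(testing) returns only rows whose user and item were both seen during training, so the evaluator no longer receives NaN predictions.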

      

+1




A small change solves this problem. Note that without arguments this drops rows that have a null or NaN in any column:

predictions.na.drop()

0

