Serving a Spark ML model with Scala and Play


I have built a Spark (2.2.0) ML pipeline whose output is a CrossValidatorModel, and wrote the pipeline to disk with its save method. I would like to serve this pre-trained model using the Play (2.6.0) framework with Scala (2.11.11), but I am having some trouble figuring out the best way to use Spark inside Play and/or load the model.

Regarding my Play setup, the relevant parts of my file structure are fairly simple:

app/
  controllers/
    HomeController.scala
    ModelScorer.scala
  models/
    Passenger.scala
    Prediction.scala
conf/
  routes
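For reference, a minimal sketch of what those two case classes might look like. The actual field lists are not shown in the question; the names below are assumptions based on the columns selected in ModelScorer (name, probability, survives) and a Titanic-style input, and the real type of probability is likely a Spark ML Vector rather than a Double:

```scala
package models

// Hypothetical model input row; field names and types are assumptions
case class Passenger(name: String, age: Double, fare: Double)

// Hypothetical model output row, mirroring the columns selected in
// ModelScorer.predict; "probability" is simplified to Double here
case class Prediction(name: String, probability: Double, survives: Double)
```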
where Passenger and Prediction are case classes representing my model's input and output, respectively. HomeController contains the logic that receives JSON-formatted POST requests, parses the body into a Seq[Passenger], and feeds it to ModelScorer.predict(data), as follows:

// HomeController.scala
package controllers

import javax.inject._
import models.{Passenger, Prediction}
import play.api.mvc._
import play.api.libs.json._
import play.api.libs.functional.syntax._

@Singleton
class HomeController @Inject()(cc: ControllerComponents) extends AbstractController(cc) {

  implicit val passengerReads: Reads[Passenger] = (
    ... // Various mappings
  )(Passenger.apply _)

  implicit val predictionWrites: Writes[Prediction] = (
    ... // Various mappings
  )(unlift(Prediction.unapply))

  def myEndpoint() = Action { implicit request: Request[AnyContent] =>
    val inputData: JsValue = request.body.asJson.get
    val passengers: Seq[Passenger] = inputData.validate[Seq[Passenger]].get
    val predictions: Seq[Prediction] = ModelScorer.predict(passengers)
    val outputData: JsValue = Json.toJson(predictions)

    Ok(outputData)
  }
}
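As an aside on the elided mappings: when the case-class field names line up exactly with the JSON keys, play-json's macros can generate the Reads and Writes in one line each. A sketch, assuming Passenger and Prediction are plain case classes whose fields mirror the JSON (if the names differ, the explicit functional-syntax mappings above are still needed):

```scala
// Inside HomeController, replacing the hand-written mappings
import play.api.libs.json._
import models.{Passenger, Prediction}

// Macro-derived (de)serializers; valid only if field names match JSON keys
implicit val passengerReads: Reads[Passenger] = Json.reads[Passenger]
implicit val predictionWrites: Writes[Prediction] = Json.writes[Prediction]
```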
To score predictions, the ModelScorer object initializes a SparkSession, loads the model through a Guava cache, and then runs the logic in its predict method to return predictions to the HomeController. As far as I can tell, the offending line is val ds: Dataset[Passenger] = passengers.toDS, which tells me something is wrong with the Spark initialization, but I am not sure how to proceed.

// ModelScorer.scala
package controllers

import com.google.common.cache.{CacheBuilder, CacheLoader}
import models.{Passenger, Prediction}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.tuning.CrossValidatorModel
import org.apache.spark.sql.{Dataset, SparkSession}

object ModelScorer {

  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  val spark = SparkSession.builder
    .master("local[*]")
    .appName("ml-server")
    .getOrCreate()

  import spark.implicits._

  val modelCache = CacheBuilder.newBuilder()
    .build(
      new CacheLoader[String, CrossValidatorModel] {
        def load(path: String): CrossValidatorModel = {
          CrossValidatorModel.load(path)
        }
      }
    )

  val model: CrossValidatorModel = modelCache.get("trained-cv-pipeline")

  def predict(passengers: Seq[Passenger]): Seq[Prediction] = {

    val ds: Dataset[Passenger] = passengers.toDS
    val predictions: Seq[Prediction] = model.transform(ds)
      .select("name","probability","prediction")
      .withColumnRenamed("prediction","survives")
      .as[Prediction]
      .collect
      .toSeq
    predictions
  }

}
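For what it's worth, the ScalaReflectionException shown further down is frequently reported when Spark's encoder derivation runs under Play's application classloader, which in dev mode is not the loader that Spark's reflection machinery expects. A workaround sometimes suggested for this situation, offered here as a sketch rather than a confirmed fix, is to align the thread's context classloader before the first toDS / encoder call in ModelScorer.predict:

```scala
// Sketch of a commonly suggested workaround (not from the accepted answer):
// make Spark's reflection resolve models.Passenger against the same
// classloader that loaded the application classes.
object ClassLoaderFix {
  // Call once before the first toDS / encoder derivation
  def apply(): Unit =
    Thread.currentThread().setContextClassLoader(getClass.getClassLoader)
}
```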
The relevant dependencies in my build.sbt are:

libraryDependencies ++= Seq(
  guice
  , "org.scalatestplus.play" %% "scalatestplus-play" % "3.1.0" % Test
  , "org.apache.spark" %% "spark-core" % "2.2.0"
  , "org.apache.spark" %% "spark-sql" % "2.2.0"
  , "org.apache.spark" %% "spark-mllib" % "2.2.0"
  , "org.apache.hadoop" % "hadoop-client" % "2.7.2"
)

dependencyOverrides ++= Set(
  "com.fasterxml.jackson.core" % "jackson-databind" % "2.6.5"
  , "com.google.guava" % "guava" % "19.0"
)
The stack trace, after a POST request to http://localhost:9000/myEndpoint with the required JSON, is:

@752mgi3ib - Internal server error, for (POST) [/myEndpoint] ->

play.api.http.HttpErrorHandlerExceptions$$anon$1: Execution 
exception[[ScalaReflectionException: class models.Passenger in 
JavaMirror with 
DependencyClassLoader{file:/Users/XXXX/.ivy2/cache/org.scala-
lang/scala-library/jars/scala-library-2.11.11.jar, 
  ...
  ... // Many, many lines
  ...
  ... :/Library/Java/JavaVirtualMachines/jdk1.8.0_131.jdk/Contents/Home/jre/classes] not found.
    at play.api.http.HttpErrorHandlerExceptions$.throwableToUsefulException(HttpErrorHandler.scala:255)
    at play.api.http.DefaultHttpErrorHandler.onServerError(HttpErrorHandler.scala:180)
    at play.core.server.AkkaHttpServer$$anonfun$13$$anonfun$apply$1.applyOrElse(AkkaHttpServer.scala:252)
    at play.core.server.AkkaHttpServer$$anonfun$13$$anonfun$apply$1.applyOrElse(AkkaHttpServer.scala:251)
    at scala.concurrent.Future$$anonfun$recoverWith$1.apply(Future.scala:346)
    at scala.concurrent.Future$$anonfun$recoverWith$1.apply(Future.scala:345)
    at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:36)
    at play.api.libs.streams.Execution$trampoline$.execute(Execution.scala:70)
    at scala.concurrent.impl.CallbackRunnable.executeWithValue(Promise.scala:44)
    at scala.concurrent.impl.Promise$DefaultPromise.scala$concurrent$impl$Promise$
DefaultPromise$$dispatchOrAddCallback(Promise.scala:284) Caused by: scala.ScalaReflectionException: class models.Passenger in JavaMirror with DependencyClassLoader{file:
  ...
  ... // Many, many lines
  ...
  ... :/Library/Java/JavaVirtualMachines/jdk1.8.0_131.jdk/Contents/Home/jre/classes] not found.
    at scala.reflect.internal.Mirrors$RootsBase.staticClass(Mirrors.scala:123)
    at scala.reflect.internal.Mirrors$RootsBase.staticClass(Mirrors.scala:22)
    at controllers.ModelScorer$$typecreator3$1.apply(ModelScorer.scala:34)
    at scala.reflect.api.TypeTags$WeakTypeTagImpl.tpe$lzycompute(TypeTags.scala:232)
    at scala.reflect.api.TypeTags$WeakTypeTagImpl.tpe(TypeTags.scala:232)
    at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:49)
    at org.apache.spark.sql.Encoders$.product(Encoders.scala:275)
    at org.apache.spark.sql.LowPrioritySQLImplicits$class.newProductEncoder(SQLImplicits.scala:233)
    at org.apache.spark.sql.SQLImplicits.newProductEncoder(SQLImplicits.scala:33)
    at controllers.ModelScorer$.predict(ModelScorer.scala:34)

The best I can do is trace the problem to the creation of the Dataset in ModelScorer.predict(passengers), and more specifically to the line val ds: Dataset[Passenger] = passengers.toDS, even though I can run that same line in a REPL via the sbt console. This makes me think the problem is with integrating Spark into Play. I am somewhat at a loss as to how to proceed; any and all guidance is appreciated.

You have to initialize the Spark context from within the Play framework. Alternatively, if you are submitting a fat JAR via spark-submit, you do not need to initialize the context yourself.
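One way to initialize it inside Play, sketched here with assumed names (SparkModule and SparkProvider are illustrations, not from the answer), is to hand the SparkSession's lifecycle to Play's Guice injector instead of building it in a standalone object:

```scala
// app/modules/SparkModule.scala -- hypothetical module
package modules

import javax.inject.Singleton
import com.google.inject.{AbstractModule, Provider}
import org.apache.spark.sql.SparkSession

@Singleton
class SparkProvider extends Provider[SparkSession] {
  // Built lazily, inside Play's injector, on first request
  lazy val get: SparkSession = SparkSession.builder
    .master("local[*]")
    .appName("ml-server")
    .getOrCreate()
}

class SparkModule extends AbstractModule {
  override def configure(): Unit =
    bind(classOf[SparkSession]).toProvider(classOf[SparkProvider])
}
```

Enable the module with play.modules.enabled += "modules.SparkModule" in application.conf, and then inject the SparkSession (e.g. into a class-based ModelScorer) rather than creating it in a top-level object.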