Apache spark 将kmeans模型注册为UDF

Apache spark 将kmeans模型注册为UDF,apache-spark,apache-spark-sql,spark-streaming,apache-spark-mllib,Apache Spark,Apache Spark Sql,Spark Streaming,Apache Spark Mllib,嗨,我正在尝试使用Spark kmeans模型来预测集群数量。但是当我注册它并在SQL中使用它时,它给了我一个 java.lang.reflect.InvocationTargetException def findCluster(s:String):Int={ model.predict(feautarize(s)) } 我正在使用下面的 %sql select findCluster((text)) from tweets 如果我直接使用它,同样有效 findCluster("h

嗨,我正在尝试使用Spark kmeans模型来预测集群数量。但是当我注册它并在SQL中使用它时,它给了我一个

java.lang.reflect.InvocationTargetException

def findCluster(s:String):Int={
    model.predict(feautarize(s))
}
我正在使用下面的

%sql select findCluster((text)) from tweets
如果我直接使用它,同样有效

findCluster("hello am vishnu")

输出1

不可能用您提供的代码重现问题。假设
model
org.apache.spark.mllib.clustering.KMeansModel
,下面是逐步解决方案

首先,允许导入所需的库并设置RNG种子:

import scala.util.Random
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

Random.setSeed(0L)
生成随机列车集:

// Generate random training set
val trainData = sc.parallelize((1 to 1000).map { _ =>
    val off = if(Random.nextFloat > 0.5) 0.5 else -0.5
    Vectors.dense(Random.nextFloat + off, Random.nextFloat + off)
})
运行KMeans

// Train KMeans with 2 clusters

val numClusters = 2
val numIterations = 20

val clusters = KMeans.train(trainData, numClusters, numIterations)
创建自定义项

// Create broadcast variable with model and prediction function 
val model = sc.broadcast(clusters)
def findCluster(v: org.apache.spark.mllib.linalg.Vector):Int={
    model.value.predict(v)
}

// Register UDF
sqlContext.udf.register("findCluster", findCluster _)
准备测试集

// Create test set
case class Coord(v: org.apache.spark.mllib.linalg.Vector)
val testData = sqlContext.createDataFrame(sc.parallelize((1 to 100).map { _ =>
    val off = if(Random.nextFloat > 0.5) 0.5 else -0.5
    Coord(Vectors.dense(Random.nextFloat + off, Random.nextFloat + off))
}))

// Register test set df
testData.registerTempTable("testData")

// Check if it works
sqlContext.sql("SELECT findCluster(v) FROM testData").take(1)
结果:

res3: Array[org.apache.spark.sql.Row] = Array([1])

嘿,谢谢,现在开始工作了。错误在于我使用齐柏林飞艇的方式。很高兴听到这个。如果您提供一些解释作为单独的答案,以防将来有人遇到类似的问题,这可能会很有用。