Python Pyspark: error when summing a SparseVector


Suppose I have a SparseVector and I want to sum its values, for example:

import numpy as np
from pyspark.ml.linalg import SparseVector  # pyspark.mllib.linalg also provides SparseVector

v = SparseVector(15557, [3, 40, 45, 103, 14356], np.ones(5))
v.values.sum()

5.0
That works. Now I want to do the same thing through a udf, because I have a DataFrame with a column of SparseVectors. But I get an error I don't understand:

from pyspark.sql import functions as f

def sum_vector(vector):
    return vector.values.sum()

sum_vector_udf = f.udf(lambda x: sum_vector(x))

sum_vector_udf(v)

----

AttributeError                            Traceback (most recent call last)
<ipython-input-38-b4d44c2ef561> in <module>()
      1 v = SparseVector(15557, [3, 40, 45, 103, 14356], np.ones(5))
      2 
----> 3 sum_vector_udf(v)
      4 #v.values.sum()

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/functions.py in wrapper(*args)
   1955         @functools.wraps(f)
   1956         def wrapper(*args):
-> 1957             return udf_obj(*args)
   1958 
   1959         wrapper.func = udf_obj.func

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/functions.py in __call__(self, *cols)
   1916         judf = self._judf
   1917         sc = SparkContext._active_spark_context
-> 1918         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
   1919 
   1920 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in _to_seq(sc, cols, converter)
     58     """
     59     if converter:
---> 60         cols = [converter(c) for c in cols]
     61     return sc._jvm.PythonUtils.toSeq(cols)
     62 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in <listcomp>(.0)
     58     """
     59     if converter:
---> 60         cols = [converter(c) for c in cols]
     61     return sc._jvm.PythonUtils.toSeq(cols)
     62 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in _to_java_column(col)
     46         jcol = col._jc
     47     else:
---> 48         jcol = _create_column_from_name(col)
     49     return jcol
     50 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in _create_column_from_name(name)
     39 def _create_column_from_name(name):
     40     sc = SparkContext._active_spark_context
---> 41     return sc._jvm.functions.col(name)
     42 
     43 

~/anaconda3/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1122 
   1123     def __call__(self, *args):
-> 1124         args_command, temp_args = self._build_args(*args)
   1125 
   1126         command = proto.CALL_COMMAND_NAME +\

~/anaconda3/lib/python3.6/site-packages/py4j/java_gateway.py in _build_args(self, *args)
   1092 
   1093         args_command = "".join(
-> 1094             [get_command_part(arg, self.pool) for arg in new_args])
   1095 
   1096         return args_command, temp_args

~/anaconda3/lib/python3.6/site-packages/py4j/java_gateway.py in <listcomp>(.0)
   1092 
   1093         args_command = "".join(
-> 1094             [get_command_part(arg, self.pool) for arg in new_args])
   1095 
   1096         return args_command, temp_args

~/anaconda3/lib/python3.6/site-packages/py4j/protocol.py in get_command_part(parameter, python_proxy_pool)
    287             command_part += ";" + interface
    288     else:
--> 289         command_part = REFERENCE_TYPE + parameter._get_object_id()
    290 
    291     command_part += "\n"

AttributeError: 'SparseVector' object has no attribute '_get_object_id'

I really don't get it: I wrote exactly the same thing in two different ways. Any hints?

This happens because udf does not support NumPy types as return types:

>>> type(v.values.sum())
<class 'numpy.float64'>
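
The fix shown in the original answer did not survive intact, so here is a minimal reconstruction based on the surviving text (the "both cases" below refers to the two variants). Convert the numpy.float64 to a plain Python float inside the udf, and apply the udf to a DataFrame column rather than to a local SparseVector: the traceback above arises because PySpark tries to convert the local object v into a Java column (_to_java_column), which fails on the missing _get_object_id attribute.

from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

@udf(returnType=DoubleType())
def sum_vector(vector):
    # Cast the numpy.float64 result to a plain Python float.
    return float(vector.values.sum())

An equivalent variant relies on NumPy itself, since calling .tolist() on a NumPy scalar also returns a plain Python float:

@udf(returnType=DoubleType())
def sum_vector(vector):
    return vector.values.sum().tolist()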

In both cases, you get the expected result:

df.select(sum_vector("v")).show()
+-------------+
|sum_vector(v)|
+-------------+
|          5.0|
+-------------+
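
For completeness, the answer never shows how df was built. A minimal sketch, assuming an active SparkSession named spark and reusing the vector from the question (the DataFrame and column name "v" are taken from the answer's example):

from pyspark.ml.linalg import SparseVector

# A one-row DataFrame with a single vector column named "v".
df = spark.createDataFrame(
    [(SparseVector(15557, [3, 40, 45, 103, 14356], [1.0] * 5),)],
    ["v"],
)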
