Serialize a custom transformer using Python to be used within a PySpark ML Pipeline


I found the same discussion in the comments section, but there is no clear answer there. There is also an unresolved JIRA that corresponds to this.


Given that the PySpark ML pipeline offers no option for saving a custom transformer written in Python, what other options are there to get it done? How can I implement a _to_java method in my Python class that returns a compatible Java object?

I'm not sure this is the best approach, but I also need the ability to save the custom Estimators, Transformers and Models I create in PySpark, and to support their persistence when used in the Pipeline API. Custom PySpark Estimators, Transformers and Models can be created and used in the Pipeline API, but they cannot be saved. This becomes a problem in production when model training takes longer than the event prediction cycle.

In general, PySpark Estimators, Transformers and Models are just wrappers around their Java or Scala equivalents, and the PySpark wrapper simply marshals parameters to and from Java via py4j. Any persisting of the model is then done on the Java side. Because of this structure, custom PySpark Estimators, Transformers and Models are currently limited to living only in the Python world.
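As a rough illustration of that py4j bridge (a sketch only, assuming an existing SparkContext named sc, as the test code later in this thread does), a Java-side stage can be constructed and its parameters set directly from Python:

# py4j exposes the JVM to the Python driver; construct a Java StopWordsRemover directly:
jvm_remover = sc._jvm.org.apache.spark.ml.feature.StopWordsRemover("demo_uid")
jvm_remover.setCaseSensitive(True)       # the Python call is marshalled to the Java object
print(jvm_remover.getCaseSensitive())    # True, read back across the gateway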

In a previous attempt I managed to save a PySpark model by using pickle/dill serialization. That worked, but it still did not allow such an object to be saved or loaded from within the Pipeline API. However, prompted by another SO post, I was directed to the OneVsRest classifier and looked at its _to_java and _from_java methods. They do all the heavy lifting on the PySpark side. After looking at them I figured that, if there were a way to save the pickle dump into an already-existing, saveable Java object, then it should be possible to save a custom PySpark Estimator, Transformer or Model with the Pipeline API.

To get this done I found StopWordsRemover to be ideal to hijack, because it has an attribute, stopwords, that is a list of strings. The dill.dumps method returns the pickled representation of an object as a string. The plan was to turn that string into a list and then set the stopwords parameter of a StopWordsRemover to this list. Although it is a list of strings, I found that some of the characters would not marshal to the Java object, so the characters are converted to integers and the integers then converted to strings. This works great for saving a single instance, and also when saving within a Pipeline, because the Pipeline dutifully calls the _to_java method of my Python class (we are still on the PySpark side, so this works). But coming back from Java to PySpark does not go through the Pipeline API.
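To make the encoding concrete, here is a minimal round-trip of that trick outside of Spark (a sketch only, written for Python 2, where dill.dumps returns a str; Foo is just a throwaway class):

import dill

class Foo(object):
    def __init__(self):
        self.n = 42

dmp = dill.dumps(Foo())                      # pickled representation as a string
encoded = [str(ord(c)) for c in dmp]         # one decimal string per character, safe to marshal
restored = dill.loads(''.join(chr(int(s)) for s in encoded))
print(restored.n)                            # 42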

Because I hide my Python object inside a StopWordsRemover instance, the Pipeline, when coming back to PySpark, knows nothing about my hidden class object; it only knows it has a StopWordsRemover instance. Ideally it would be nice to subclass Pipeline and PipelineModel, but sadly that brings us back to the problem of trying to serialize a Python object. To get around this, I created a PysparkPipelineWrapper that takes a Pipeline or PipelineModel and just scans the stages, looking in the stopwords list for a coded ID (remember, the stopwords are just the pickled bytes of the Python object) that tells it to unwrap the list back into my instance and store it back in the stage it came from. The code below shows how it all works.

For any custom PySpark Estimator, Transformer or Model, just inherit from Identifiable, PysparkReaderWriter, MLReadable and MLWritable. Then, when loading a Pipeline or PipelineModel, pass it through PysparkPipelineWrapper.unwrap(pipeline).

This method does not address using the PySpark code from Java or Scala, but at least we can save and load custom PySpark Estimators, Transformers and Models and work with the Pipeline API.

import dill
from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable, MLReadable, MLWritable, JavaMLReader, JavaMLWriter
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.wrapper import JavaParams
from pyspark.context import SparkContext
from pyspark.sql import Row

class PysparkObjId(object):
    """
    A class to specify constants used to identify and set up Python
    Estimators, Transformers and Models so they can be serialized on their
    own and from within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkObjId, self).__init__()

    @staticmethod
    def _getPyObjId():
        return '4c1740b00d3c4ff6806a1402321572cb'

    @staticmethod
    def _getCarrierClass(javaName=False):
        return 'org.apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover

class PysparkPipelineWrapper(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """
    def __init__(self):
        super(PysparkPipelineWrapper, self).__init__()

    @staticmethod
    def unwrap(pipeline):
        if not (isinstance(pipeline, Pipeline) or isinstance(pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
        for i, stage in enumerate(stages):
            if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
                stages[i] = PysparkPipelineWrapper.unwrap(stage)
            if isinstance(stage, PysparkObjId._getCarrierClass()) and stage.getStopWords()[-1] == PysparkObjId._getPyObjId():
                swords = stage.getStopWords()[:-1] # strip the id
                lst = [chr(int(d)) for d in swords]
                dmp = ''.join(lst)
                py_obj = dill.loads(dmp)
                stages[i] = py_obj

        if isinstance(pipeline, Pipeline):
            pipeline.setStages(stages)
        else:
            pipeline.stages = stages
        return pipeline

class PysparkReaderWriter(object):
    """
    A mixin class so custom pyspark Estimators, Transformers and Models may
    support saving and loading directly or be saved within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our clarrier class."""
        return JavaMLReader(PysparkObjId._getCarrierClass())

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)

    @classmethod
    def _from_java(cls, java_obj):
        """
        Get the dummy stopwords, which are the characters of the dill dump plus our GUID,
        and convert, via dill, back to our Python instance.
        """
        swords = java_obj.getStopWords()[:-1] # strip the id
        lst = [chr(int(d)) for d in swords] # convert from string integer list to bytes
        dmp = ''.join(lst)
        py_obj = dill.loads(dmp)
        return py_obj

    def _to_java(self):
        """
        Convert this instance to a dill dump, then to a list of strings with the unicode integer values of each character.
        Use this list as a set of dummy stopwords and store them in a StopWordsRemover instance.
        :return: Java object equivalent to this instance.
        """
        dmp = dill.dumps(self)
        pylist = [str(ord(d)) for d in dmp] # convert bytes to string integer list
        pylist.append(PysparkObjId._getPyObjId()) # add our id so PysparkPipelineWrapper can id us.
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        java_array = sc._gateway.new_array(java_class, len(pylist))
        for i in xrange(len(pylist)):
            java_array[i] = pylist[i]
        _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
        _java_obj.setStopWords(java_array)
        return _java_obj

class HasFake(Params):
    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        return self.getOrDefault(self.fake)

class MockTransformer(Transformer, HasFake, Identifiable):
    def __init__(self):
        super(MockTransformer, self).__init__()
        self.dataset_count = 0

    def _transform(self, dataset):
        self.dataset_count = dataset.count()
        return dataset

class MyTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    def __init__(self):
        super(MyTransformer, self).__init__()

def make_a_dataframe(sc):
    df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Alice', age=5, height=80), Row(name='Alice', age=10, height=80)]).toDF()
    return df

def test1():
    trA = MyTransformer()
    trA.dataset_count = 999
    print trA.dataset_count
    trA.save('test.trans')
    trB = MyTransformer.load('test.trans')
    print trB.dataset_count

def test2():
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA])
    print type(pipeA)
    pipeA.save('testA.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(Pipeline.load('testA.pipe'))
    stagesAA = pipeAA.getStages()
    trAA = stagesAA[0]
    print trAA.dataset_count

def test3():
    dfA = make_a_dataframe(sc)
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA]).fit(dfA)
    print type(pipeA)
    pipeA.save('testB.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(PipelineModel.load('testB.pipe'))
    stagesAA = pipeAA.stages
    trAA = stagesAA[0]
    print trAA.dataset_count
    dfB = pipeAA.transform(dfA)
    dfB.show()
Similar to @dmbaker's answer, I wrapped my custom transformer, called Aggregator, inside a built-in Spark transformer, in this case Binarizer, though I'm sure you could inherit from other transformers too. This allowed my custom transformer to inherit the methods needed for serialization.

from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, Binarizer
from pyspark.ml.regression import LinearRegression    

class Aggregator(Binarizer):
    """A huge hack to allow serialization of custom transformer."""

    def transform(self, input_df):
        agg_df = input_df\
            .groupBy('channel_id')\
            .agg({
                'foo': 'avg',
                'bar': 'avg',
            })\
            .withColumnRenamed('avg(foo)', 'avg_foo')\
            .withColumnRenamed('avg(bar)', 'avg_bar') 
        return agg_df

# Create pipeline stages.
aggregator = Aggregator()
vector_assembler = VectorAssembler(...)
linear_regression = LinearRegression()

# Create pipeline.
pipeline = Pipeline(stages=[aggregator, vector_assembler, linear_regression])

# Train.
pipeline_model = pipeline.fit(input_df)

# Save model file to S3.
pipeline_model.save('s3n://example')

@dmbaker's solution did not work for me. I believe that is down to the Python version (2.x vs. 3.x). I made some updates to his solution and it now works on Python 3 (a sketch of the byte-handling changes this implies follows the list below). My setup is:

  • Python: 3.6.3
  • Spark: 2.2.1
  • dill: 0.2.7.1
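The updated code itself is not reproduced here, but the Python 3 changes presumably boil down to handling bytes where the original handled str, roughly along these lines (a sketch under that assumption, not the answer's actual code):

import dill

def _encode(py_obj):
    dmp = dill.dumps(py_obj)                 # bytes on Python 3
    return [str(b) for b in dmp]             # iterating bytes yields ints directly

def _decode(string_ints):
    return dill.loads(bytes(int(s) for s in string_ints))

# (plus range() instead of xrange() and print() as a function in the test helpers)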

I could not get @dmbaker's ingenious solution to work using Python 2 on Spark 2.2.0; I kept getting pickling errors. After several dead ends, I got a working solution by modifying his (her?) idea of writing and reading the parameter values as strings directly into the stopWords of a StopWordsRemover.

If you want to save and load your own estimators or transformers, use this base class:

from pyspark import SparkContext
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.util import Identifiable, MLWritable, JavaMLWriter, MLReadable, JavaMLReader
from pyspark.ml.wrapper import JavaWrapper, JavaParams

class PysparkReaderWriter(Identifiable, MLReadable, MLWritable):
    """
    A base class for custom pyspark Estimators and Models to support saving and loading directly
    or within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    @staticmethod
    def _getPyObjIdPrefix():
        return "_ThisIsReallyA_"

    @classmethod
    def _getPyObjId(cls):
        return PysparkReaderWriter._getPyObjIdPrefix() + cls.__name__

    def getParamsAsListOfStrings(self):
        raise NotImplementedError("PysparkReaderWriter.getParamsAsListOfStrings() not implemented for instance: %r" % self)

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    def _to_java(self):
        # Convert all our parameters to strings:
        paramValuesAsStrings = self.getParamsAsListOfStrings()

        # Append our own type-specific id so PysparkPipelineLoader can detect this algorithm when unwrapping us.
        paramValuesAsStrings.append(self._getPyObjId())

        # Convert the parameter values to a Java array:
        sc = SparkContext._active_spark_context
        java_array = JavaWrapper._new_java_array(paramValuesAsStrings, sc._gateway.jvm.java.lang.String)

        # Create a Java (Scala) StopWordsRemover and give it the parameters as its stop words.
        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid)
        _java_obj.setStopWords(java_array)
        return _java_obj

    @classmethod
    def _from_java(cls, java_obj):
        # Get the stop words, ignoring the id at the end:
        stopWords = java_obj.getStopWords()[:-1]
        return cls.createAndInitialisePyObj(stopWords)

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        raise NotImplementedError("PysparkReaderWriter.createAndInitialisePyObj() not implemented for type: %r" % cls)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our clarrier class."""
        return JavaMLReader(StopWordsRemover)

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)
Your own PySpark algorithm must then inherit from PysparkReaderWriter and override the getParamsAsListOfStrings() method, which saves your parameters to a list of strings. Your algorithm must also override the createAndInitialisePyObj() method, which converts a list of strings back into your parameters. Behind the scenes, the parameters are converted to and from the stop words used by StopWordsRemover.

An example estimator with 3 parameters, deliberately of different types:

from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.base import Estimator

class MyEstimator(Estimator, PysparkReaderWriter):

    def __init__(self):
        super(MyEstimator, self).__init__()

    # 3 sample parameters, deliberately of different types:
    stringParam = Param(Params._dummy(), "stringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)

    def setStringParam(self, value):
        return self._set(stringParam=value)

    def getStringParam(self):
        return self.getOrDefault(self.stringParam)

    listOfStringsParam = Param(Params._dummy(), "listOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)

    def setListOfStringsParam(self, value):
        return self._set(listOfStringsParam=value)

    def getListOfStringsParam(self):
        return self.getOrDefault(self.listOfStringsParam)

    intParam = Param(Params._dummy(), "intParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)

    def setIntParam(self, value):
        return self._set(intParam=value)

    def getIntParam(self):
        return self.getOrDefault(self.intParam)

    def _fit(self, dataset):
        model = MyModel()
        # Just some changes to verify we can modify the model (and also it's something we can expect to see when restoring it later):
        model.setAnotherStringParam(self.getStringParam() + " World!")
        model.setAnotherListOfStringsParam(self.getListOfStringsParam() + ["E", "F"])
        model.setAnotherIntParam(self.getIntParam() + 10)
        return model

    def getParamsAsListOfStrings(self):
        paramValuesAsStrings = []
        paramValuesAsStrings.append(self.getStringParam()) # Parameter is already a string
        paramValuesAsStrings.append(','.join(self.getListOfStringsParam())) # ...convert from a list of strings
        paramValuesAsStrings.append(str(self.getIntParam())) # ...convert from an int
        return paramValuesAsStrings

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        # Convert back into our parameters. Make sure you do this in the same order you saved them!
        py_obj = cls()
        py_obj.setStringParam(paramsAsListOfStrings[0])
        py_obj.setListOfStringsParam(paramsAsListOfStrings[1].split(","))
        py_obj.setIntParam(int(paramsAsListOfStrings[2]))
        return py_obj
An example model (which is also a transformer), again with 3 parameters of different types:

from pyspark.ml.base import Model

class MyModel(Model, PysparkReaderWriter):

    def __init__(self):
        super(MyModel, self).__init__()

    # 3 sample parameters, deliberately of different types:
    anotherStringParam = Param(Params._dummy(), "anotherStringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)

    def setAnotherStringParam(self, value):
        return self._set(anotherStringParam=value)

    def getAnotherStringParam(self):
        return self.getOrDefault(self.anotherStringParam)

    anotherListOfStringsParam = Param(Params._dummy(), "anotherListOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)

    def setAnotherListOfStringsParam(self, value):
        return self._set(anotherListOfStringsParam=value)

    def getAnotherListOfStringsParam(self):
        return self.getOrDefault(self.anotherListOfStringsParam)

    anotherIntParam = Param(Params._dummy(), "anotherIntParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)

    def setAnotherIntParam(self, value):
        return self._set(anotherIntParam=value)

    def getAnotherIntParam(self):
        return self.getOrDefault(self.anotherIntParam)

    def _transform(self, dataset):
        # Dummy transform code:
        return dataset.withColumn('age2', dataset.age + self.getAnotherIntParam())

    def getParamsAsListOfStrings(self):
        paramValuesAsStrings = []
        paramValuesAsStrings.append(self.getAnotherStringParam()) # Parameter is already a string
        paramValuesAsStrings.append(','.join(self.getAnotherListOfStringsParam())) # ...convert from a list of strings
        paramValuesAsStrings.append(str(self.getAnotherIntParam())) # ...convert from an int
        return paramValuesAsStrings

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        # Convert back into our parameters. Make sure you do this in the same order you saved them!
        py_obj = cls()
        py_obj.setAnotherStringParam(paramsAsListOfStrings[0])
        py_obj.setAnotherListOfStringsParam(paramsAsListOfStrings[1].split(","))
        py_obj.setAnotherIntParam(int(paramsAsListOfStrings[2]))
        return py_obj
Here is a sample test case showing how to save and load your model. It is similar for the estimator, so I omit that for brevity.

def createAModel():
    m = MyModel()
    m.setAnotherStringParam("Boo!")
    m.setAnotherListOfStringsParam(["P", "Q", "R"])
    m.setAnotherIntParam(77)
    return m

def testSaveLoadModel():
    modA = createAModel()
    print(modA.explainParams())

    savePath = "/whatever/path/you/want"
    #modA.save(savePath) # Can't overwrite, so...
    modA.write().overwrite().save(savePath)

    modB = MyModel.load(savePath)
    print(modB.explainParams())

testSaveLoadModel()
Output:

anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: ['P', 'Q', 'R'])
anotherStringParam: A dummy string parameter (current: Boo!)
anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: [u'P', u'Q', u'R'])
anotherStringParam: A dummy string parameter (current: Boo!)
intParam: A dummy int parameter. (current: 42)
listOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D'])
stringParam: A dummy string parameter (current: Hello)
anotherIntParam: A dummy int parameter. (current: 52)
anotherListOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D', u'E', u'F'])
anotherStringParam: A dummy string parameter (current: Hello World!)
+---+------+-----+----+
|age|height| name|age2|
+---+------+-----+----+
|  5|    80|Alice|  57|
|  7|    85|  Bob|  59|
| 10|    90|Chris|  62|
+---+------+-----+----+
Notice how the parameters come back as unicode strings. This may or may not have an impact on the underlying algorithm you implement in _transform() (or _fit()).
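The pipeline test below also relies on a PysparkPipelineLoader class and a createAnEstimator() helper whose definitions are not included here. A hypothetical reconstruction, inferred from how they are called below and from the PysparkPipelineWrapper pattern in the earlier answer, might look roughly like this:

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import StopWordsRemover

class PysparkPipelineLoader(object):
    """Hypothetical reconstruction: swap carrier StopWordsRemover stages back into
    the custom Python classes listed in possible_classes."""

    @staticmethod
    def unwrap(pipeline, possible_classes):
        if not isinstance(pipeline, (Pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
        for i, stage in enumerate(stages):
            if isinstance(stage, (Pipeline, PipelineModel)):
                stages[i] = PysparkPipelineLoader.unwrap(stage, possible_classes)
            elif isinstance(stage, StopWordsRemover):
                objId = stage.getStopWords()[-1]  # the type-specific id appended by _to_java
                for cls in possible_classes:
                    if objId == cls._getPyObjId():
                        stages[i] = cls.createAndInitialisePyObj(stage.getStopWords()[:-1])

        if isinstance(pipeline, Pipeline):
            pipeline.setStages(stages)
        else:
            pipeline.stages = stages
        return pipeline

def createAnEstimator():
    # Hypothetical helper, with values chosen to match the parameter values printed above.
    est = MyEstimator()
    est.setStringParam("Hello")
    est.setListOfStringsParam(["A", "B", "C", "D"])
    est.setIntParam(42)
    return est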
from pyspark.sql import Row

def make_a_dataframe(sc):
    df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Bob', age=7, height=85), Row(name='Chris', age=10, height=90)]).toDF()
    return df

def testSaveAndLoadPipelineModel():
    dfA = make_a_dataframe(sc)
    estA = createAnEstimator()
    #print(estA.explainParams())
    pipelineModelA = Pipeline(stages=[estA]).fit(dfA)
    savePath = "/whatever/path/you/want"
    #pipelineModelA.save(savePath) # Can't overwrite, so...
    pipelineModelA.write().overwrite().save(savePath)

    pipelineModelReloaded = PysparkPipelineLoader.unwrap(PipelineModel.load(savePath), [MyModel])
    modB = pipelineModelReloaded.stages[0]
    print(modB.explainParams())

    dfB = pipelineModelReloaded.transform(dfA)
    dfB.show()

testSaveAndLoadPipelineModel()

Output:

anotherIntParam: A dummy int parameter. (current: 52)
anotherListOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D', u'E', u'F'])
anotherStringParam: A dummy string parameter (current: Hello World!)
+---+------+-----+----+
|age|height| name|age2|
+---+------+-----+----+
|  5|    80|Alice|  57|
|  7|    85|  Bob|  59|
| 10|    90|Chris|  62|
+---+------+-----+----+

Since Spark 2.3.0, pyspark.ml.util provides the DefaultParamsReadable and DefaultParamsWritable mixins, which let a pure-Python Transformer be saved and loaded (on its own or inside a Pipeline) without any carrier-object trick. The following example uses them:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit # for the dummy _transform

class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
    )

    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset
from pyspark.ml import Pipeline, PipelineModel

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+