java.lang.StackOverflowError when fitting a Spark TrainValidationSplit in Python


I searched many other posts for this problem but found no solution, so I am opening this question.

I am trying to tune my model's hyperparameters with Spark.

I launch the code in standalone local mode on a 16 GB, 8-core PC, with the following setting in
spark-defaults.conf
:
spark.driver.memory 14g
. This is the only configuration I changed.
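
For context, the same driver-memory setting can also be passed at submit time instead of through the config file (a sketch; the script name is a placeholder):

```shell
# Equivalent to the spark-defaults.conf line above: standalone local mode
# on 8 cores, with 14g of driver memory.
spark-submit \
  --master "local[8]" \
  --conf spark.driver.memory=14g \
  my_tuning_script.py   # placeholder for the actual script
```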

I launch the fit with a gradient-boosted tree classifier, as follows:

tvs = TrainValidationSplit(estimator=classifier,
                           estimatorParamMaps=param_grid,
                           evaluator=evaluator,
                           trainRatio=0.8,
                           parallelism=4)
model = tvs.fit(training)
I get a very long stack trace (25,000 lines!) with no useful information; it looks like this and appears to repeat:

16:55:57 ERROR SparkUncaughtExceptionHandler: Uncaught exception in thread Thread[Executor task launch worker for task 101611,5,main]
java.lang.StackOverflowError
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2331)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
    at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2233)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
    at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2233)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2342)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2266)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2124)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
    at sun.reflect.GeneratedMethodAccessor14.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    ...

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/.../venv/lib/python3.6/site-packages/py4j/java_gateway.py", line 1067, in start
    self.socket.connect((self.address, self.port))
ConnectionRefusedError: [Errno 111] Connection refused
...
When I train other models, such as a random forest, it works. So I suspect the cause is that the train-validation split tests more hyperparameter combinations for the GBT (36) than for the RF (4), as shown by the lists used to build the param grid:

RF_MAX_DEPTH = [5, 7]
RF_MAX_BINS = [32]
RF_NUM_TREES = [10, 5]
RF_IMPURITY = ['entropy']
RF_FEATURE_SUBSET_STRATEGY = ['auto']
RF_MIN_INSTANCES_PER_NODE = [1]
RF_MIN_INFO_GAIN = [0.0]
RF_SUBSAMPLING_RATE = [0.8]
RF_MAX_MEMORY_IN_MB = [8192]
RF_CACHE_NODE_IDS = [False]

GBT_MAX_DEPTH = [5, 7, 15, 20]
GBT_MAX_BINS = [32]
GBT_MAX_ITER = [200, 300, 1000]
GBT_STEP_SIZE = [0.05, 0.03, 0.2]
GBT_LOSS_TYPE = ['logistic']
GBT_MIN_INSTANCES_PER_NODE = [1]
GBT_MIN_INFO_GAIN = [0.0]
GBT_SUBSAMPLING_RATE = [0.8]
GBT_MAX_MEMORY_IN_MB = [8192]
GBT_CACHE_NODE_IDS = [True]
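
The combination counts quoted above (4 for RF, 36 for GBT) can be sanity-checked in plain Python: the grid size is simply the product of the lengths of the value lists (the actual grids would be built from these lists with pyspark.ml.tuning.ParamGridBuilder).

```python
from math import prod

# Number of values tried per hyperparameter, taken from the RF lists above.
rf_grid_sizes = [len(v) for v in (
    [5, 7],        # RF_MAX_DEPTH
    [32],          # RF_MAX_BINS
    [10, 5],       # RF_NUM_TREES
    ['entropy'],   # RF_IMPURITY
    ['auto'],      # RF_FEATURE_SUBSET_STRATEGY
    [1],           # RF_MIN_INSTANCES_PER_NODE
    [0.0],         # RF_MIN_INFO_GAIN
    [0.8],         # RF_SUBSAMPLING_RATE
    [8192],        # RF_MAX_MEMORY_IN_MB
    [False],       # RF_CACHE_NODE_IDS
)]

# Same for the GBT lists above.
gbt_grid_sizes = [len(v) for v in (
    [5, 7, 15, 20],     # GBT_MAX_DEPTH
    [32],               # GBT_MAX_BINS
    [200, 300, 1000],   # GBT_MAX_ITER
    [0.05, 0.03, 0.2],  # GBT_STEP_SIZE
    ['logistic'],       # GBT_LOSS_TYPE
    [1],                # GBT_MIN_INSTANCES_PER_NODE
    [0.0],              # GBT_MIN_INFO_GAIN
    [0.8],              # GBT_SUBSAMPLING_RATE
    [8192],             # GBT_MAX_MEMORY_IN_MB
    [True],             # GBT_CACHE_NODE_IDS
)]

print(prod(rf_grid_sizes))   # → 4
print(prod(gbt_grid_sizes))  # → 36
```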
Why does TrainValidationSplit throw this error, and how can I fix it (other than removing some hyperparameter possibilities)?

Thanks in advance for your help :)

Edit: I removed some of the hyperparameter possibilities (leaving only one possible combination) and it works. So the problem does lie there, but I would really like to make the full grid work.

Edit 2: I tried adding
spark.driver.extraJavaOptions -XX:ThreadStackSize=81920
, since it might be related to JVM settings (it is, after all, a Java stack overflow error), but that only made it crash later on.
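
For reference, in spark-defaults.conf the property name and the JVM flag are separate, whitespace-delimited tokens (a sketch of the file as described, with both values quoted in this post):

```
spark.driver.memory            14g
spark.driver.extraJavaOptions  -XX:ThreadStackSize=81920
```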