2017-01-21 7 views
2

Этот вопрос похож на this one. Я хотел бы напечатать лучшие параметры модели после выполнения TrainValidationSplit в pyspark. Я не могу найти кусок текста другой пользователь использует, чтобы ответить на этот вопрос, потому что я работаю над jupyter и бревенчатых пропадает из терминала ...Как распечатать лучшие параметры модели в трубопроводе pyspark

Часть кода:

pca = PCA(inputCol = 'features') 
dt = DecisionTreeRegressor(featuresCol=pca.getOutputCol(), 
          labelCol="energy") 
pipe = Pipeline(stages=[pca,dt]) 

paramgrid = ParamGridBuilder().addGrid(pca.k, range(1,50,2)).addGrid(dt.maxDepth, range(1,10,1)).build() 

tvs = TrainValidationSplit(estimator = pipe, evaluator = RegressionEvaluator(
labelCol="energy", predictionCol="prediction", metricName="mae"), estimatorParamMaps = paramgrid, trainRatio = 0.66) 

model = tvs.fit(wind_tr_va); 

Спасибо заранее.

ответ

4

Из этого следует то же рассуждение, описанное в ответе о How to get the maxDepth from a Spark RandomForestRegressionModel, данное @ user6910411.

Вам нужно пропатчить TrainValidationSplitModel, PCAModel и DecisionTreeRegressionModel как следует:

TrainValidationSplitModel.bestModel = (
    lambda self: self._java_obj.bestModel 
) 

PCAModel.getK = (
    lambda self: self._java_obj.getK() 
) 

DecisionTreeRegressionModel.getMaxDepth = (
    lambda self: self._java_obj.getMaxDepth() 
) 

Теперь вы можете использовать его, чтобы получить лучшую модель и извлечь k и maxDepth

bestModel = model.bestModel 

bestModelK = bestModel.stages[0].getK() 
bestModelMaxDepth = bestModel.stages[1].getMaxDepth() 

PS : Вы можете патч-модели для получения определенных параметров так же, как описано выше.

1

Еще проще (1 линия), просто обратитесь к JVM объект модели

cvModel.bestModel.stages[-1]._java_obj.getMaxDepth() 

Здесь вы берете свой bestModel после перекрестной проверки, вызовите объект JVM этой модели и извлечь параметр maxDepth с помощью getMaxDepth() - метод из объекта JVM.

Список всех оригинальной JVM GET-параметры можно найти здесь https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/RandomForestClassificationModel.html

Кроме того, вы можете просматривать другие GET-параметры для других моделей и извлечь их со ссылкой на оригинальный JVM объект любой модели

<yourModel>.stages[<yourModelStage>]._java_obj.<getParameter>() 

Надеюсь, это поможет.

 Смежные вопросы

  • Нет связанных вопросов^_^