5

Я пытаюсь найти Точность, используя 5-кратное перекрестное подтверждение с использованием модели Random Forest Classifier в SCALA. Но я получаю следующее сообщение об ошибке во время выполнения:RandomForestClassifier получил ввод с недопустимой ошибкой столбца столбца в Apache Spark

java.lang.IllegalArgumentException: RandomForestClassifier был дан вход с недопустимой подписью столбца метка, без числа классов, указанных. См. StringIndexer.

Получение выше ошибки в строке ---> Val cvModel = cv.fit (trainingData)

Код, который я использовал для перекрестной проверки набора данных с использованием случайных лес выглядит следующим образом:

import org.apache.spark.ml.Pipeline 
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} 
import org.apache.spark.ml.classification.RandomForestClassifier 
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator 
import org.apache.spark.mllib.linalg.Vectors 
import org.apache.spark.mllib.regression.LabeledPoint 

val data = sc.textFile("exprogram/dataset.txt") 
val parsedData = data.map { line => 
val parts = line.split(',') 
LabeledPoint(parts(41).toDouble, 
Vectors.dense(parts(0).split(',').map(_.toDouble))) 
} 


val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) 
val training = splits(0) 
val test = splits(1) 

val trainingData = training.toDF() 

val testData = test.toDF() 

val nFolds: Int = 5 
val NumTrees: Int = 5 

val rf = new  
RandomForestClassifier() 
     .setLabelCol("label") 
     .setFeaturesCol("features") 
     .setNumTrees(NumTrees) 

val pipeline = new Pipeline() 
     .setStages(Array(rf)) 

val paramGrid = new ParamGridBuilder() 
      .build() 

val evaluator = new MulticlassClassificationEvaluator() 
    .setLabelCol("label") 
    .setPredictionCol("prediction") 
    .setMetricName("precision") 

val cv = new CrossValidator() 
    .setEstimator(pipeline) 
    .setEvaluator(evaluator) 
    .setEstimatorParamMaps(paramGrid) 
    .setNumFolds(nFolds) 

val cvModel = cv.fit(trainingData) 

val results = cvModel.transform(testData) 
.select("label","prediction").collect 

val numCorrectPredictions = results.map(row => 
if (row.getDouble(0) == row.getDouble(1)) 1 else 0).foldLeft(0)(_ + _) 
val accuracy = 1.0D * numCorrectPredictions/results.size 

println("Test set accuracy: %.3f".format(accuracy)) 

Может кто-нибудь объяснить, что является ошибкой в ​​приведенном выше коде.

ответ

8

RandomForestClassifier, так же как и многие другие алгоритмы ML, требуют определенных метаданных, которые должны быть установлены в столбце меток, и значения меток, которые являются целыми значениями из [0, 1, 2 ..., #classes), представлены как двойные. Обычно это обрабатывается восходящим потоком Transformers, как StringIndexer. Поскольку вы конвертируете метки вручную, поля метаданных не заданы, и классификатор не может подтвердить, что эти требования выполнены.

val df = Seq(
    (0.0, Vectors.dense(1, 0, 0, 0)), 
    (1.0, Vectors.dense(0, 1, 0, 0)), 
    (2.0, Vectors.dense(0, 0, 1, 0)), 
    (2.0, Vectors.dense(0, 0, 0, 1)) 
).toDF("label", "features") 

val rf = new RandomForestClassifier() 
    .setFeaturesCol("features") 
    .setNumTrees(5) 

rf.setLabelCol("label").fit(df) 
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ... 

Вы можете перекодировать столбец этикетки с помощью StringIndexer:

import org.apache.spark.ml.feature.StringIndexer 

val indexer = new StringIndexer() 
    .setInputCol("label") 
    .setOutputCol("label_idx") 
    .fit(df) 

rf.setLabelCol("label_idx").fit(indexer.transform(df)) 

или set required metadata manually:

val meta = NominalAttribute 
    .defaultAttr 
    .withName("label") 
    .withValues("0.0", "1.0", "2.0") 
    .toMetadata 

rf.setLabelCol("label_meta").fit(
    df.withColumn("label_meta", $"label".as("", meta)) 
) 

Примечание:

этикетки, созданные с помощью StringIndexer зависит от частоты не значение:

indexer.labels 
// Array[String] = Array(2.0, 0.0, 1.0) 

PySpark:

В полей метаданных Python может быть установлен непосредственно на схеме:

from pyspark.sql.types import StructField, DoubleType 

StructField(
    "label", DoubleType(), False, 
    {"ml_attr": { 
     "name": "label", 
     "type": "nominal", 
     "vals": ["0.0", "1.0", "2.0"] 
    }} 
)