Я пытаюсь запустить произвольную классификацию леса с помощью Spark ML api, но у меня возникают проблемы с созданием ввода правильного ввода данных в конвейер.Как создать правильный фрейм данных для классификации в Spark ML
Вот выборка данных:
age,hours_per_week,education,sex,salaryRange
38,40,"hs-grad","male","A"
28,40,"bachelors","female","A"
52,45,"hs-grad","male","B"
31,50,"masters","female","B"
42,40,"bachelors","male","B"
возраст и hours_per_week являются целыми числами, а другие функции, включая этикетки salaryRange категоричны (String)
Загрузка этого CSV-файл (назовем его sample.csv) может быть сделано Spark csv library следующим образом:
val data = sqlContext.csvFile("/home/dusan/sample.csv")
По умолчанию все столбцы импортируются в виде строки, поэтому мы должны изменить «возраст» и «hours_per_week» в Int:
val toInt = udf[Int, String](_.toInt)
val dataFixed = data.withColumn("age", toInt(data("age"))).withColumn("hours_per_week",toInt(data("hours_per_week")))
Просто чтобы проверить, как схема выглядит сейчас:
scala> dataFixed.printSchema
root
|-- age: integer (nullable = true)
|-- hours_per_week: integer (nullable = true)
|-- education: string (nullable = true)
|-- sex: string (nullable = true)
|-- salaryRange: string (nullable = true)
Тогда позволяет настроить кросс валидатора и трубопровод:
val rf = new RandomForestClassifier()
val pipeline = new Pipeline().setStages(Array(rf))
val cv = new CrossValidator().setNumFolds(10).setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator)
ошибки проявляющейся при выполнении этой строки:
val cmModel = cv.fit(dataFixed)
java.lang.IllegalArgumentException: Поле "функции" не существует.
Можно установить столбец меток и колонку функций в RandomForestClassifier, однако у меня есть 4 столбца в качестве предикторов (функций) не только один.
Как я должен организовать свой фрейм данных, чтобы он правильно маркировал столбцы и столбцы?
Для вашего удобства здесь полный код:
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
object SampleClassification {
def main(args: Array[String]): Unit = {
//set spark context
val conf = new SparkConf().setAppName("Simple Application").setMaster("local");
val sc = new SparkContext(conf)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._
import com.databricks.spark.csv._
//load data by using databricks "Spark CSV Library"
val data = sqlContext.csvFile("/home/dusan/sample.csv")
//by default all columns are imported as string so we need to change "age" and "hours_per_week" to Int
val toInt = udf[Int, String](_.toInt)
val dataFixed = data.withColumn("age", toInt(data("age"))).withColumn("hours_per_week",toInt(data("hours_per_week")))
val rf = new RandomForestClassifier()
val pipeline = new Pipeline().setStages(Array(rf))
val cv = new CrossValidator().setNumFolds(10).setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator)
// this fails with error
//java.lang.IllegalArgumentException: Field "features" does not exist.
val cmModel = cv.fit(dataFixed)
}
}
Спасибо за помощь!
Не известно о языке scala, но где вы устанавливаете метки и функции из набора данных, например LabeledPoint (метки, список (функции)), посмотрите пример в https://spark.apache.org/docs/latest/mllib -linear-methods.html –
@ABC, пожалуйста, проверьте мой комментарий в вопросе ниже. –
проверьте этот пример https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala где val model = pipeline.fit (обучение .toDF()) использует dataframe в конвейере –