2016-04-14 6 views
5

Я пытаюсь определить функцию UserDefinedAggregateFunction (UDAF) в Spark, которая подсчитывает количество вхождений для каждого уникального значения в столбце группы.Почему Mutable map автоматически становится неизменным в UserDefinedAggregateFunction (UDAF) в Spark

Это пример: Предположим, у меня есть dataframe df, как это,

+----+----+ 
|col1|col2| 
+----+----+ 
| a| a1| 
| a| a1| 
| a| a2| 
| b| b1| 
| b| b2| 
| b| b3| 
| b| b1| 
| b| b1| 
+----+----+ 

у меня будет UDAF DistinctValues ​​

val func = new DistinctValues 

Тогда я применить его к dataframe ДФ

val agg_value = df.groupBy("col1").agg(func(col("col2")).as("DV")) 

Я ожидаю, что у меня будет что-то похожее е это:

+----+--------------------------+ 
|col1|DV      | 
+----+--------------------------+ 
| a| Map(a1->2, a2->1)  | 
| b| Map(b1->3, b2->1, b3->1)| 
+----+--------------------------+ 

Так что я вышел с UDAF, как это,

import org.apache.spark.sql.expressions.MutableAggregationBuffer 
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction 
import org.apache.spark.sql.Row 
import org.apache.spark.sql.types.StructType 
import org.apache.spark.sql.types.StructField 
import org.apache.spark.sql.types.DataType 
import org.apache.spark.sql.types.ArrayType 
import org.apache.spark.sql.types.StringType 
import org.apache.spark.sql.types.MapType 
import org.apache.spark.sql.types.LongType 
import Array._ 

class DistinctValues extends UserDefinedAggregateFunction { 
    def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil) 

    def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil) 

    def dataType: DataType = MapType(StringType, LongType) 
    def deterministic: Boolean = true 

    def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer(0) = scala.collection.mutable.Map() 
    } 

    def update(buffer: MutableAggregationBuffer, input: Row) : Unit = { 
    val str = input.getAs[String](0) 
    var mp = buffer.getAs[scala.collection.mutable.Map[String, Long]](0) 
    var c:Long = mp.getOrElse(str, 0) 
    c = c + 1 
    mp.put(str, c) 
    buffer(0) = mp 
    } 

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = { 
    var mp1 = buffer1.getAs[scala.collection.mutable.Map[String, Long]](0) 
    var mp2 = buffer2.getAs[scala.collection.mutable.Map[String, Long]](0) 
    mp2 foreach { 
     case (k ,v) => { 
      var c:Long = mp1.getOrElse(k, 0) 
      c = c + v 
      mp1.put(k ,c) 
     } 
    } 
    buffer1(0) = mp1 
    } 

    def evaluate(buffer: Row): Any = { 
     buffer.getAs[scala.collection.mutable.Map[String, LongType]](0) 
    } 
} 

Тогда у меня есть эта функция на моем dataframe,

val func = new DistinctValues 
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV")) 

Он дал такую ​​ошибку,

func: DistinctValues = [email protected] 
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 32.0 failed 4 times, most recent failure: Lost task 1.3 in stage 32.0 (TID 884, ip-172-31-22-166.ec2.internal): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map 
at $iwC$$iwC$DistinctValues.update(<console>:39) 
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:431) 
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:187) 
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:180) 
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:116) 
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:152) 
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29) 
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) 
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) 
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149) 
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73) 
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41) 
at org.apache.spark.scheduler.Task.run(Task.scala:89) 
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213) 
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) 
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) 
at java.lang.Thread.run(Thread.java:745) 

Похоже, что в update(buffer: MutableAggregationBuffer, input: Row) метод, переменная buffer является immutable.Map, программа усталой, чтобы бросить его mutable.Map,

Но я mutable.Map инициализировать переменную buffer в initialize(buffer: MutableAggregationBuffer, input:Row) методы. Это одна и та же переменная, переданная методу update? А также buffer - mutableAggregationBuffer, поэтому он должен быть изменчивым, не так ли?

Почему моя mutable.Map стала неизменной? Кто-нибудь знает, что произошло?

Мне действительно нужна измененная карта в этой функции для выполнения задачи. Я знаю, что есть временное решение для создания изменчивой карты с неизменяемой карты, а затем ее обновления. Но я действительно хочу знать, почему изменчивый превращается в неизменяемый в программе автоматически, для меня это не имеет смысла.

ответ

4

Верьте, что это MapType в вашем StructType. buffer поэтому содержит Map, что было бы неизменным.

Вы можете преобразовать его, но почему бы вам просто не оставить его неизменным и сделать это:

mp = mp + (k -> c) 

, чтобы добавить запись в незыблемый Map?

Рабочий пример ниже:

class DistinctValues extends UserDefinedAggregateFunction { 
    def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("_2", IntegerType) :: Nil) 

    def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil) 

    def dataType: DataType = MapType(StringType, LongType) 
    def deterministic: Boolean = true 

    def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer(0) = Map() 
    } 

    def update(buffer: MutableAggregationBuffer, input: Row) : Unit = { 
    val str = input.getAs[String](0) 
    var mp = buffer.getAs[Map[String, Long]](0) 
    var c:Long = mp.getOrElse(str, 0) 
    c = c + 1 
    mp = mp + (str -> c) 
    buffer(0) = mp 
    } 

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = { 
    var mp1 = buffer1.getAs[Map[String, Long]](0) 
    var mp2 = buffer2.getAs[Map[String, Long]](0) 
    mp2 foreach { 
     case (k ,v) => { 
      var c:Long = mp1.getOrElse(k, 0) 
      c = c + v 
      mp1 = mp1 + (k -> c) 
     } 
    } 
    buffer1(0) = mp1 
    } 

    def evaluate(buffer: Row): Any = { 
     buffer.getAs[Map[String, LongType]](0) 
    } 
} 
+0

Хороший улов! Хм, «MapyType» в «StructType» может быть так. Но в 'spark.sql.types' нет другого измененного типа карты, если я не определяю свой собственный. –

+0

Как я уже сказал, не используйте - просто используйте неизменяемую «Карту». 'mp = mp + (k -> c)' на неизменяемой «карте» предоставляет вам те же функции, что и «mp.put (k, c) 'на изменчивой' Map' –

+0

'mp = mp + (k -> c)' работает! Я новичок в scala, не знал, что вы можете манипулировать неизменным типом данных, подобным этому. Большое спасибо! –

 Смежные вопросы

  • Нет связанных вопросов^_^