This article contains an example of a UDAF and how to register it for use in Apache Spark SQL. See User-defined aggregate functions (UDAFs) for more details.
Requirements
UserDefinedAggregateFunction
Scala
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
/**
 * User-defined aggregate function that computes the geometric mean of a column.
 *
 * The aggregation buffer tracks how many non-null values have been seen (`count`)
 * and their running product (`product`); the result is `product ^ (1 / count)`.
 *
 * Null inputs are skipped. When no non-null value was aggregated the result is
 * null (SQL NULL) rather than the outcome of dividing by a zero count.
 */
class GeometricMean extends UserDefinedAggregateFunction {

  /** One input column, declared as a double. */
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  /** Buffer layout: index 0 = count of aggregated values, index 1 = running product. */
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
      StructField("product", DoubleType) :: Nil
  )

  /** The aggregate result is a double. */
  override def dataType: DataType = DoubleType

  /** The same input always yields the same output. */
  override def deterministic: Boolean = true

  /** Starts with zero values seen and the multiplicative identity as the product. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer aggregation buffer laid out as described by `bufferSchema`
   * @param input  row whose column 0 holds the value to aggregate
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip nulls so a missing value neither counts towards the mean nor
    // corrupts the running product.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) * input.getDouble(0)
    }
  }

  /**
   * Combines two partial aggregations: counts are added, products are multiplied.
   *
   * @param buffer1 buffer that receives the merged state
   * @param buffer2 partial state to merge into `buffer1`
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) * buffer2.getDouble(1)
  }

  /**
   * Produces the final result from the buffer.
   *
   * @return the count-th root of the product, or null when nothing was aggregated
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    // A zero count would make the exponent 1 / 0; report "no value" as SQL NULL instead.
    if (count == 0L) null
    else math.pow(buffer.getDouble(1), 1.0 / count)
  }
}
Register the UDAF with Spark SQL
Scala
// Register an instance of GeometricMean under the name "gm" so it can be called by that name.
spark.udf.register("gm", new GeometricMean)
Use your UDAF
Scala
import org.apache.spark.sql.functions._
// Build a small sample dataset from spark.range(1, 20) and expose it as the temp view "ids".
val ids = spark.range(1, 20)
ids.createOrReplaceTempView("ids")
// Derive a grouping key (id % 3) for each id and expose the result as the temp view "simple".
val df = spark.sql("select id, id % 3 as group_id from ids")
df.createOrReplaceTempView("simple")
SQL
-- Apply gm to id within each group_id of the "simple" view.
SELECT group_id, gm(id)
FROM simple
GROUP BY group_id
Scala
// Create a UDAF instance that can be applied directly to columns.
val gm = new GeometricMean

// Apply the instance to the id column per group and show the result as "GeometricMean".
df.groupBy(col("group_id")).agg(
  gm(col("id")).alias("GeometricMean")
).show()

// Same aggregation, this time calling the function by the name "gm" inside an expression string.
df.groupBy(col("group_id")).agg(
  expr("gm(id) as GeometricMean")
).show()
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4