A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://docs.databricks.com/aws/en/udf/aggregate-scala below:

User-defined aggregate functions - Scala

User-defined aggregate functions - Scala

This article contains an example of a UDAF and how to register it for use in Apache Spark SQL. See User-defined aggregate functions (UDAFs) for more details.

Requirements

Implement a UserDefinedAggregateFunction

Scala

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/**
 * Geometric mean aggregate: the n-th root of the product of the n aggregated values.
 *
 * Null inputs are skipped; they contribute neither to the count nor to the product.
 * A group with no non-null inputs evaluates to SQL NULL rather than NaN.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in favour of
 * `org.apache.spark.sql.expressions.Aggregator` registered through `functions.udaf`. It is kept
 * here so that `spark.udf.register("gm", new GeometricMean)` keeps working unchanged.
 * NOTE(review): the running product can overflow or underflow a Double for long groups or
 * extreme values. Accumulating a sum of logarithms would be more stable for strictly positive
 * inputs.
 */
class GeometricMean extends UserDefinedAggregateFunction {

  // One input column, interpreted as a double.
  override def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // Intermediate state: number of values seen so far and their running product.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
      StructField("product", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  // Identity state: no values seen; the empty product is 1.0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip nulls: reading a null cell as a primitive Double yields 0.0,
    // which would silently zero the whole product.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) * input.getDouble(0)
    }
  }

  // Combine two partial aggregates: counts add, products multiply.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) * buffer2.getDouble(1)
  }

  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    // With no values, 1.0 / 0 is Infinity and pow(1.0, Infinity) is NaN; return NULL instead.
    if (count == 0L) null
    else math.pow(buffer.getDouble(1), 1.0 / count)
  }
}
Register the UDAF with Spark SQL

Scala

// Make the UDAF callable from SQL under the function name `gm`.
val geometricMean = new GeometricMean
spark.udf.register("gm", geometricMean)
Use your UDAF

Scala


import org.apache.spark.sql.functions._

val ids = spark.range(1, 20)
ids.createOrReplaceTempView("ids")

// Tag each id with one of three buckets (id % 3) and expose the result to SQL as `simple`.
val df = ids.selectExpr("id", "id % 3 as group_id")
df.createOrReplaceTempView("simple")

SQL


SELECT group_id, gm(id)
FROM simple
GROUP BY group_id

Scala




val gm = new GeometricMean
val byGroup = df.groupBy("group_id")

// Apply the UDAF instance directly as a column function.
byGroup.agg(gm(col("id")).alias("GeometricMean")).show()

// Call the SQL-registered `gm` function through a SQL expression string.
byGroup.agg(expr("gm(id) as GeometricMean")).show()

RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4