在之前的文章中,我们介绍了如何使用 Codegen
实现自定义函数,但是自定义函数参数类型及返回值类型均为 Spark
原生的数据类型。
从本篇文章开始,我们介绍如何在 Spark
中自定义数据类型 (UDT) ,以及针对该 UDT
的自定义函数,最后,我们希望这些自定义函数也是 Codegen 实现的
UDT
自定义数据类型的要求如下:
UDT
的名字为my_point
my_point
包含两个double
类型的成员变量x
和y
UDT
的核心代码如下,完整代码见 https://github.com/adream307/SparkSQLWithCodegen/tree/master/code/udt_example
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
package org.apache.spark.sql.myfunctions {
@SQLUserDefinedType(udt = classOf[my_point_udt])
class my_point(val x: Double, val y: Double) extends Serializable {
override def hashCode(): Int = 31 * (31 * x.hashCode()) + y.hashCode()
override def equals(other: Any): Boolean = other match {
case that: my_point => this.x == that.x && this.y == that.y
case _ => false
}
override def toString(): String = s"($x, $y)"
}
class my_point_udt extends UserDefinedType[my_point] {
override def sqlType: DataType = ArrayType(DoubleType, false)
override def serialize(obj: my_point): GenericArrayData = {
val output = new Array[Double](2)
output(0) = obj.x
output(1) = obj.y
new GenericArrayData(output)
}
override def deserialize(datum: Any): my_point = {
datum match {
case values: ArrayData => new my_point(values.getDouble(0), values.getDouble(1))
}
}
override def userClass: Class[my_point] = classOf[my_point]
}
}
原始数据定义如下,数据类型即为 my_point
1
2
3
4
5
6
7
val raw_data = Seq(
Row(1, new my_point(1, 1), new my_point(10, 10)),
Row(2, new my_point(2, 2), new my_point(20, 20)),
Row(3, new my_point(3, 3), new my_point(30, 30)),
Row(4, new my_point(4, 4), new my_point(40, 40)),
Row(5, new my_point(5, 5), new my_point(50, 50))
)
表结构定义如下,定义了数据类型为 my_point_udt
1
2
3
4
val sch = StructType(Array(
StructField("idx", IntegerType, false),
StructField("point1", new my_point_udt, false),
StructField("point2", new my_point_udt, false)))
查询语句定义如下:
1
val test_sql = spark.sql("select idx, point1, point2 from data_test")
程序输出结果如下:
1
2
3
4
5
6
7
8
9
+---+----------+------------+
|idx| point1| point2|
+---+----------+------------+
| 1|(1.0, 1.0)|(10.0, 10.0)|
| 2|(2.0, 2.0)|(20.0, 20.0)|
| 3|(3.0, 3.0)|(30.0, 30.0)|
| 4|(4.0, 4.0)|(40.0, 40.0)|
| 5|(5.0, 5.0)|(50.0, 50.0)|
+---+----------+------------+