
6 - How to Handle Null

In the previous post we introduced a custom function implemented with Codegen. To keep the discussion simple, we assumed there that all input values were non-null. In this post we look at how to handle null values in Codegen.

A Custom Function with Codegen

The requirements for the custom function are the same as in the previous post:

  • a function named my_foo
  • it takes two parameters of type double as input
  • the parameters are named x and y
  • the function returns x*y+3

There are two differences from the custom function that ignores null: nullable and doGenCode. The core code is as follows:

package org.apache.spark.sql.myfunctions {

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions._
  import org.apache.spark.sql.catalyst.expressions.codegen._
  import org.apache.spark.sql.catalyst.expressions.codegen.Block._
  import org.apache.spark.sql.types._

  case class my_foo(inputExpressions: Seq[Expression]) extends Expression with ExpectsInputTypes {
    override def nullable: Boolean = inputExpressions(0).nullable || inputExpressions(1).nullable

    override def eval(input: InternalRow): Any = ???

    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
      val left_expr = inputExpressions(0)
      val right_expr = inputExpressions(1)
      val left_code = left_expr.genCode(ctx)
      val right_code = right_expr.genCode(ctx)

      if (nullable) {
        val nullSafeEval =
          left_code.code + ctx.nullSafeExec(left_expr.nullable, left_code.isNull) {
            right_code.code + ctx.nullSafeExec(right_expr.nullable, right_code.isNull) {
              s"""
                 | ${ev.isNull} = false; // resultCode could change nullability.
                 | ${ev.value} = ${left_code.value} * ${right_code.value};
                 | ${ev.value} = ${ev.value} + 3;
                 |""".stripMargin
            }
          }
        ev.copy(code =
          code"""
             boolean ${ev.isNull} = true;
             ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
             $nullSafeEval
              """)
      } else {
        ev.copy(code =
          code"""
             ${left_code.code}
             ${right_code.code}
             ${CodeGenerator.javaType(dataType)} ${ev.value} = ${left_code.value} * ${right_code.value};
             ${ev.value} = ${ev.value} + 3;
              """, FalseLiteral)
      }
    }

    override def dataType: DataType = DoubleType

    override def children: Seq[Expression] = inputExpressions

    override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
  }

}
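
In the listing above, eval is left as ??? because this post only exercises the codegen path. For completeness, a minimal null-safe interpreted-mode eval could look like the sketch below; this is my own addition, not part of the repository code.

    override def eval(input: InternalRow): Any = {
      val x = inputExpressions(0).eval(input)
      val y = inputExpressions(1).eval(input)
      // Mirror the codegen path: return null as soon as either input is null.
      if (x == null || y == null) null
      else x.asInstanceOf[Double] * y.asInstanceOf[Double] + 3
    }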

Code Walkthrough

  1. nullable is no longer hard-coded to true or false; it is derived from the inputs: if either input may be null, the output may be null
    
    override def nullable: Boolean = inputExpressions(0).nullable || inputExpressions(1).nullable
    
  2. Depending on the value of nullable, doGenCode splits into two branches
  3. When nullable is false, the code produced by doGenCode is exactly the same as in the previous post
  4. When nullable is true, the generated code first checks that both inputs are non-null before performing the my_foo computation (the sketch after this list shows what nullSafeExec expands to)
    
    left_code.code + ctx.nullSafeExec(left_expr.nullable, left_code.isNull) {
      right_code.code + ctx.nullSafeExec(right_expr.nullable, right_code.isNull) {
        s"""
           | ${ev.isNull} = false; // resultCode could change nullability.
           | ${ev.value} = ${left_code.value} * ${right_code.value};
           | ${ev.value} = ${ev.value} + 3;
           |""".stripMargin
      }
    }
    
  5. When nullable is true, the generated code also declares a boolean variable ${ev.isNull} and assigns it to indicate whether the result of this computation is null
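
To make item 4 concrete, here is a conceptual sketch of what CodegenContext.nullSafeExec does (simplified and paraphrased, not a verbatim copy of Spark's source): when the child expression is nullable it wraps the given code in an if-not-null guard, otherwise it emits the code unchanged.

// Simplified sketch of CodegenContext.nullSafeExec
def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
  if (nullable) {
    s"""
       |if (!$isNull) {
       |  $execute
       |}
     """.stripMargin
  } else {
    // The child can never be null, so no guard is needed.
    "\n" + execute
  }
}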

Running the Program

The complete code is located in the https://github.com/adream307/SparkSQLWithCodegen/tree/master/code/codegen_with_null directory. The sample code contains two queries, one where nullable is true and one where it is false.
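
The repository wires my_foo into DataFrame queries. One way to do that, sketched below, is to wrap the Catalyst expression in a Column; the helper name myFoo and the SparkSession setup are my own and may differ from the repository's actual driver code.

import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.myfunctions.my_foo

val spark = SparkSession.builder().appName("codegen_with_null").master("local[*]").getOrCreate()

// Wrap the Catalyst expression in a Column so it can be used inside select();
// the alias matches the column header shown in the outputs below.
def myFoo(x: Column, y: Column): Column =
  new Column(my_foo(Seq(x.expr, y.expr))).as("my_foo(x, y)")

With this helper, each query below reduces to df.select(col("x"), col("y"), myFoo(col("x"), col("y"))).show().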

First Query

The first query defines the schema below: both inputs x and y are allowed to be null, so the output of my_foo is also allowed to be null.

val sch = StructType(Array(StructField("x", DoubleType, true), StructField("y", DoubleType, true)))
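
A sketch of how the first query's input might be built, reusing the hypothetical spark and myFoo from above; the row values are taken from the output shown next, while the actual driver code lives in the repository.

import org.apache.spark.sql.Row

val data = Seq(
  Row(1.0, 2.0),
  Row(3.0, 4.0),
  Row(5.0, null),
  Row(null, 6.0),
  Row(null, null))
val df = spark.createDataFrame(spark.sparkContext.parallelize(data), sch)
df.select(col("x"), col("y"), myFoo(col("x"), col("y"))).show()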

The program output is as follows:

+----+----+------------+
|   x|   y|my_foo(x, y)|
+----+----+------------+
| 1.0| 2.0|         5.0|
| 3.0| 4.0|        15.0|
| 5.0|null|        null|
|null| 6.0|        null|
|null|null|        null|
+----+----+------------+

The complete Java code generated by Codegen is shown below; lines 30~38 are where my_foo handles null values.

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator rdd_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] rdd_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     rdd_input_0 = inputs[0];
/* 020 */     rdd_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 021 */     rdd_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 022 */     rdd_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 0);
/* 023 */
/* 024 */   }
/* 025 */
/* 026 */   private void project_doConsume_0(double project_expr_0_0, boolean project_exprIsNull_0_0, double project_expr_1_0, boolean project_exprIsNull_1_0) throws java.io.IOException {
/* 027 */     boolean project_isNull_4 = true;
/* 028 */     double project_value_4 = -1.0;
/* 029 */
/* 030 */     if (!project_exprIsNull_0_0) {
/* 031 */       if (!project_exprIsNull_1_0) {
/* 032 */         project_isNull_4 = false; // resultCode could change nullability.
/* 033 */         project_value_4 = project_expr_0_0 * project_expr_1_0;
/* 034 */         project_value_4 = project_value_4 + 3;
/* 035 */
/* 036 */       }
/* 037 */
/* 038 */     }
/* 039 */     rdd_mutableStateArray_0[2].reset();
/* 040 */
/* 041 */     rdd_mutableStateArray_0[2].zeroOutNullBytes();
/* 042 */
/* 043 */     if (project_exprIsNull_0_0) {
/* 044 */       rdd_mutableStateArray_0[2].setNullAt(0);
/* 045 */     } else {
/* 046 */       rdd_mutableStateArray_0[2].write(0, project_expr_0_0);
/* 047 */     }
/* 048 */
/* 049 */     if (project_exprIsNull_1_0) {
/* 050 */       rdd_mutableStateArray_0[2].setNullAt(1);
/* 051 */     } else {
/* 052 */       rdd_mutableStateArray_0[2].write(1, project_expr_1_0);
/* 053 */     }
/* 054 */
/* 055 */     if (project_isNull_4) {
/* 056 */       rdd_mutableStateArray_0[2].setNullAt(2);
/* 057 */     } else {
/* 058 */       rdd_mutableStateArray_0[2].write(2, project_value_4);
/* 059 */     }
/* 060 */     append((rdd_mutableStateArray_0[2].getRow()));
/* 061 */
/* 062 */   }
/* 063 */
/* 064 */   protected void processNext() throws java.io.IOException {
/* 065 */     while ( rdd_input_0.hasNext()) {
/* 066 */       InternalRow rdd_row_0 = (InternalRow) rdd_input_0.next();
/* 067 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 068 */       boolean rdd_isNull_0 = rdd_row_0.isNullAt(0);
/* 069 */       double rdd_value_0 = rdd_isNull_0 ?
/* 070 */       -1.0 : (rdd_row_0.getDouble(0));
/* 071 */       boolean rdd_isNull_1 = rdd_row_0.isNullAt(1);
/* 072 */       double rdd_value_1 = rdd_isNull_1 ?
/* 073 */       -1.0 : (rdd_row_0.getDouble(1));
/* 074 */
/* 075 */       project_doConsume_0(rdd_value_0, rdd_isNull_0, rdd_value_1, rdd_isNull_1);
/* 076 */       if (shouldStop()) return;
/* 077 */     }
/* 078 */   }
/* 079 */
/* 080 */ }

Second Query

The second query defines the schema below: neither input x nor y can be null, so the output of my_foo can never be null.

val sch2 = StructType(Array(StructField("x", DoubleType, false), StructField("y", DoubleType, false)))
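
Analogously, a sketch for the second query, again reusing the hypothetical helpers above; debugCodegen (from org.apache.spark.sql.execution.debug, to the best of my knowledge) prints the whole-stage generated Java source, similar to the listing further below.

import org.apache.spark.sql.execution.debug._

val data2 = Seq(Row(1.0, 2.0), Row(3.0, 4.0), Row(5.0, 6.0), Row(7.0, 8.0), Row(9.0, 10.0))
val df2 = spark.createDataFrame(spark.sparkContext.parallelize(data2), sch2)
val result2 = df2.select(col("x"), col("y"), myFoo(col("x"), col("y")))
result2.show()
result2.debugCodegen() // dumps the generated Java code for inspection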

The program output is as follows:

+---+----+------------+
|  x|   y|my_foo(x, y)|
+---+----+------------+
|1.0| 2.0|         5.0|
|3.0| 4.0|        15.0|
|5.0| 6.0|        33.0|
|7.0| 8.0|        59.0|
|9.0|10.0|        93.0|
+---+----+------------+

The complete Java code generated by Codegen is shown below; it is identical to the code in the previous post.

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator rdd_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] rdd_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     rdd_input_0 = inputs[0];
/* 020 */     rdd_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 021 */     rdd_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 022 */     rdd_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 0);
/* 023 */
/* 024 */   }
/* 025 */
/* 026 */   private void project_doConsume_0(double project_expr_0_0, double project_expr_1_0) throws java.io.IOException {
/* 027 */     double project_value_4 = project_expr_0_0 * project_expr_1_0;
/* 028 */     project_value_4 = project_value_4 + 3;
/* 029 */     rdd_mutableStateArray_0[2].reset();
/* 030 */
/* 031 */     rdd_mutableStateArray_0[2].write(0, project_expr_0_0);
/* 032 */
/* 033 */     rdd_mutableStateArray_0[2].write(1, project_expr_1_0);
/* 034 */
/* 035 */     rdd_mutableStateArray_0[2].write(2, project_value_4);
/* 036 */     append((rdd_mutableStateArray_0[2].getRow()));
/* 037 */
/* 038 */   }
/* 039 */
/* 040 */   protected void processNext() throws java.io.IOException {
/* 041 */     while ( rdd_input_0.hasNext()) {
/* 042 */       InternalRow rdd_row_0 = (InternalRow) rdd_input_0.next();
/* 043 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 044 */       double rdd_value_0 = rdd_row_0.getDouble(0);
/* 045 */       double rdd_value_1 = rdd_row_0.getDouble(1);
/* 046 */
/* 047 */       project_doConsume_0(rdd_value_0, rdd_value_1);
/* 048 */       if (shouldStop()) return;
/* 049 */     }
/* 050 */   }
/* 051 */
/* 052 */ }
This post is licensed under CC BY 4.0 by the author.