
spark-sql custom functions

Published: 2024-11-24

(1) Custom UDF

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{SQLContext, SparkSession}

object SparkSqlTest {
    def main(args: Array[String]): Unit = {
        // Suppress noisy logging
        Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.project-spark").setLevel(Level.WARN)
        // Build the entry point
        val conf: SparkConf = new SparkConf()
        conf.setAppName("SparkSqlTest")
            .setMaster("local[2]")
        val spark: SparkSession = SparkSession.builder().config(conf)
            .getOrCreate()
        // Create the SQLContext object
        val sqlContext: SQLContext = spark.sqlContext
        /**
          * Register the UDF.
          * In the type parameters [Int, String], the first one is the return
          * type; the following one or more are the parameter types of the method.
          */
        sqlContext.udf.register[Int, String]("strLen", strLen)
        val sql =
            """
              |select strLen("zhangsan")
            """.stripMargin
        spark.sql(sql).show()
    }

    // The custom UDF method
    def strLen(str: String): Int = {
        str.length
    }
}
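The SQL string is not the only way to use such a function. As a minimal sketch that is not part of the original example (the object name, column name, and sample data below are made up for illustration), the same length logic can be wrapped with functions.udf and applied directly to a DataFrame column:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object UdfColumnSketch {
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
            .appName("UdfColumnSketch")
            .master("local[2]")
            .getOrCreate()
        import spark.implicits._

        // Hypothetical sample data, just for illustration
        val names = Seq("zhangsan", "lisi").toDF("name")

        // Wrap an ordinary Scala function as a Column-level UDF
        val strLenUdf = udf((s: String) => s.length)

        // Apply it through the DataFrame API instead of a SQL string
        names.select(col("name"), strLenUdf(col("name")).as("name_len")).show()

        spark.stop()
    }
}

Registering through sqlContext.udf.register (as above) and wrapping with functions.udf are independent: the former makes the function visible to SQL text, the latter gives a Column expression for the DataFrame API.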

(2) Custom UDAF

The example here implements a count.
The custom UDAF class:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

class MyCountUDAF extends UserDefinedAggregateFunction {
    // Input data type of this UDAF
    override def inputSchema: StructType = {
        StructType(List(
            StructField("age", DataTypes.IntegerType)
        ))
    }
    // Data type of the aggregation buffer
    override def bufferSchema: StructType = {
        StructType(List(
            StructField("age", DataTypes.IntegerType)
        ))
    }
    // Output data type of this UDAF
    override def dataType: DataType = DataTypes.IntegerType
    // Deterministic: the same input always produces the same output
    override def deterministic: Boolean = true
    // buffer: the temporary buffer holding the intermediate aggregation result
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer.update(0, 0)
    }
    /**
      * Aggregation within a partition
      * @param buffer the temporary buffer initialized in initialize
      * @param input  the new row passed into the aggregation
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        val oldValue = buffer.getInt(0)
        buffer.update(0, oldValue + 1)
    }
    /**
      * Aggregation across partitions
      * @param buffer1 intermediate result of one partition
      * @param buffer2 intermediate result of another partition
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        val p1 = buffer1.getInt(0)
        val p2 = buffer2.getInt(0)
        buffer1.update(0, p1 + p2)
    }
    // The final output value of this aggregate function
    override def evaluate(buffer: Row): Any = {
        buffer.get(0)
    }
}
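Note that UserDefinedAggregateFunction has been deprecated since Spark 3.0 in favor of the typed Aggregator API. The following is only a sketch under that assumption (the object name MyCountAggregator is illustrative, not from this article), showing the same count expressed as an Aggregator and registered through functions.udaf:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

// Sketch only: the same counting logic as a typed Aggregator (Spark 3.0+)
object MyCountAggregator extends Aggregator[Int, Int, Int] {
    // Initial value of the aggregation buffer
    def zero: Int = 0
    // Fold one input value into the buffer (counting, so the value itself is ignored)
    def reduce(buffer: Int, value: Int): Int = buffer + 1
    // Combine the buffers of two partitions
    def merge(b1: Int, b2: Int): Int = b1 + b2
    // Produce the final result from the buffer
    def finish(reduction: Int): Int = reduction
    // Encoders for the buffer and the output
    def bufferEncoder: Encoder[Int] = Encoders.scalaInt
    def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// Registration for use from SQL, e.g. inside main:
//   spark.udf.register("myCount", functions.udaf(MyCountAggregator))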

Invocation:

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, SQLContext, SparkSession}

object SparkSqlTest {
    def main(args: Array[String]): Unit = {
        // Suppress noisy logging
        Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.project-spark").setLevel(Level.WARN)
        // Build the entry point
        val conf: SparkConf = new SparkConf()
        conf.setAppName("SparkSqlTest")
            .setMaster("local[2]")
            .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            .registerKryoClasses(Array(classOf[Student]))
        val spark: SparkSession = SparkSession.builder().config(conf)
            .getOrCreate()
        // Create the SQLContext object
        val sqlContext: SQLContext = spark.sqlContext
        // Register the UDAF
        sqlContext.udf.register("myCount", new MyCountUDAF())
        val stuList = List(
            new Student("委xx", 18),
            new Student("吴xx", 18),
            new Student("戚xx", 18),
            new Student("王xx", 19),
            new Student("薛xx", 19)
        )
        import spark.implicits._
        val stuDS: Dataset[Student] = sqlContext.createDataset(stuList)
        stuDS.createTempView("student")
        val sql =
            """
              |select myCount(1) counts
              |from student
              |group by age
              |order by counts
            """.stripMargin
        spark.sql(sql).show()
    }
}

case class Student(name: String, age: Int)
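As an optional sanity check that is not part of the original article, the custom UDAF can be compared against Spark's built-in count on the same grouping. The snippet below assumes it is appended to the end of main above, after the student view has been created and myCount has been registered:

        // Hypothetical cross-check: myCount should agree with the built-in count
        val checkSql =
            """
              |select age, myCount(1) as my_counts, count(1) as builtin_counts
              |from student
              |group by age
            """.stripMargin
        spark.sql(checkSql).show()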