spark-sql Custom Functions
(1) Custom UDF
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{SQLContext, SparkSession}

object SparkSqlTest {
  def main(args: Array[String]): Unit = {
    // Suppress noisy logs
    Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.project-spark").setLevel(Level.WARN)

    // Build the entry point
    val conf: SparkConf = new SparkConf()
    conf.setAppName("SparkSqlTest")
      .setMaster("local[2]")
    val spark: SparkSession = SparkSession.builder().config(conf)
      .getOrCreate()
    // Create the SQLContext object
    val sqlContext: SQLContext = spark.sqlContext

    /**
     * Register the UDF.
     * In the type parameters [Int, String], the first is the return type;
     * the one or more that follow are the parameter types of the function.
     */
    sqlContext.udf.register[Int, String]("strLen", strLen)

    val sql =
      """
        |select strLen("zhangsan")
      """.stripMargin
    spark.sql(sql).show()
  }

  // Custom UDF: returns the length of a string
  def strLen(str: String): Int = {
    str.length
  }
}
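If you prefer the DataFrame API to SQL strings, the same function can be wrapped with functions.udf and applied to columns directly. The snippet below is a minimal sketch; the DataFrame and column names are illustrative and not part of the original example:

import org.apache.spark.sql.functions.{col, udf}

// Wrap the same strLen method as a column-level UDF (assumes the SparkSession
// `spark` and the strLen method from the example above are in scope).
import spark.implicits._
val strLenUdf = udf(strLen _)
val df = Seq("zhangsan", "lisi").toDF("name")
df.select(col("name"), strLenUdf(col("name")).as("name_len")).show()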
(2) Custom UDAF
The example here implements a count aggregation.
The custom UDAF class:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

class MyCountUDAF extends UserDefinedAggregateFunction {
  // Input data type of this UDAF
  override def inputSchema: StructType = {
    StructType(List(
      StructField("age", DataTypes.IntegerType)
    ))
  }

  // Type of the intermediate data aggregated inside this UDAF
  override def bufferSchema: StructType = {
    StructType(List(
      StructField("age", DataTypes.IntegerType)
    ))
  }

  // Output data type of this UDAF
  override def dataType: DataType = DataTypes.IntegerType

  // Deterministic: the same input always produces the same output
  override def deterministic: Boolean = true

  // buffer: the temporary buffer holding intermediate aggregation results
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0)
  }

  /**
   * Aggregate data within a partition.
   * @param buffer the temporary buffer initialized in initialize()
   * @param input  the new row passed into the aggregation
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val oldValue = buffer.getInt(0)
    buffer.update(0, oldValue + 1)
  }

  /**
   * Merge results across partitions.
   * @param buffer1 intermediate result of one partition
   * @param buffer2 intermediate result of another partition
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val p1 = buffer1.getInt(0)
    val p2 = buffer2.getInt(0)
    buffer1.update(0, p1 + p2)
  }

  // Final output value of the aggregation
  override def evaluate(buffer: Row): Any = {
    buffer.get(0)
  }
}
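Note that UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of the typed Aggregator API. The sketch below shows the same count written as an Aggregator; the object name MyCountAgg is illustrative and this migration is an assumption, not part of the original example:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Count of input rows, expressed with the typed Aggregator API.
object MyCountAgg extends Aggregator[Int, Int, Int] {
  def zero: Int = 0                                        // initial buffer value
  def reduce(buffer: Int, input: Int): Int = buffer + 1    // per-partition update
  def merge(b1: Int, b2: Int): Int = b1 + b2               // merge partition buffers
  def finish(reduction: Int): Int = reduction              // final result
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// On Spark 3.x it can be registered for SQL use like this:
// spark.udf.register("myCount", org.apache.spark.sql.functions.udaf(MyCountAgg))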
Invocation:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, SQLContext, SparkSession}

object SparkSqlTest {
  def main(args: Array[String]): Unit = {
    // Suppress noisy logs
    Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.project-spark").setLevel(Level.WARN)

    // Build the entry point
    val conf: SparkConf = new SparkConf()
    conf.setAppName("SparkSqlTest")
      .setMaster("local[2]")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerKryoClasses(Array(classOf[Student]))
    val spark: SparkSession = SparkSession.builder().config(conf)
      .getOrCreate()
    // Create the SQLContext object
    val sqlContext: SQLContext = spark.sqlContext

    // Register the UDAF
    sqlContext.udf.register("myCount", new MyCountUDAF())

    val stuList = List(
      Student("委xx", 18),
      Student("吴xx", 18),
      Student("戚xx", 18),
      Student("王xx", 19),
      Student("薛xx", 19)
    )
    import spark.implicits._
    val stuDS: Dataset[Student] = sqlContext.createDataset(stuList)
    stuDS.createTempView("student")

    val sql =
      """
        |select myCount(1) counts
        |from student
        |group by age
        |order by counts
      """.stripMargin
    spark.sql(sql).show()
  }
}

case class Student(name: String, age: Int)
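The same aggregation can also be invoked through the DataFrame API instead of a SQL string. A minimal sketch, assuming the stuDS Dataset and the "myCount" registration from the code above:

import org.apache.spark.sql.functions.{callUDF, lit}

// DataFrame-API equivalent of the SQL query above
// ($-syntax needs `import spark.implicits._`, which the example already does).
stuDS.groupBy($"age")
  .agg(callUDF("myCount", lit(1)).as("counts"))
  .orderBy($"counts")
  .show()

With the five sample students (three aged 18, two aged 19), both versions should produce counts of 2 and 3.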