Spark Review Notes (5): Advanced SparkSQL Operations - Iterative Computation, Window Functions, Multiple Data Sources, and User-Defined Functions (UDFs)

Spark Study Notes

Preface: Today is day 5 of my Spark review! This post walks through advanced SparkSQL operations: using Spark with Hive for an offline data warehouse, working with MySQL and DataFrames, and the core iterative-computation building block, the UDF, along with several enterprise-style examples. I hope it helps!

Tips: "Sharing is a source of joy 💧. My blog offers not just an ocean of knowledge 🌊 but plenty of positive energy 💪 as well. Come share the fun with me 😊!

If you like my blog, remember to leave a like ❤️ and a follow! Your support keeps me writing!"


Table of Contents

  • Spark Study Notes
    • III. Advanced SparkSQL Operations
      • 1. Using Spark with Hive
        • (1) Start the Hive metastore service
        • (2) Verify that SparkSQL works with Hive
      • 2. Spark window functions
      • 3. Reading and writing MySQL with Spark
        • (1) Read from MySQL
        • (2) Write to MySQL
      • 4. UDFs on DataFrame data sources
        • (1) pandas_toDF
        • (2) UDF basics
        • (3) Mismatched types return null
        • (4) Wrapping a common function as a UDF
        • (5) List (array) return types
        • (6) Returning mixed types
        • (7) Combining several UDFs
      • 5. UDFs on Series data sources
        • (1) Series functions defined with pandas_udf
        • (2) Series combined with iterators
        • (3) Tuple for easier unpacking
        • (4) pandas_udf aggregation via applyInPandas
        • (5) Adding a number to a column
        • (6) Adding two columns together
        • (7) Series functions returning a single value

(All of the datasets for this section are in my resources download area; interested readers can download them there: the complete SparkSQL case-study dataset collection.)

III. Advanced SparkSQL Operations

1. Using Spark with Hive

(1) Start the Hive metastore service

    nohup /export/server/hive/bin/hive --service metastore 2>&1 >> /var/log.log &

After checking the running processes you will see a RunJar process; when you are finished with it, you can stop it with kill -9.

(2) Verify that SparkSQL works with Hive

• 1 - spark-sql

    cd /export/server/spark
    bin/spark-sql --master local[2] --executor-memory 512m --total-executor-cores 1
    # or:
    bin/spark-sql --master spark://node1.itcast.cn:7077 --executor-memory 512m --total-executor-cores 1

  Once inside, run SQL statements directly; exit() quits.

• 2 - spark-shell

    cd /export/server/spark
    bin/spark-shell --master local[3]
    spark.sql("show databases").show()

  This drops you into the Scala shell; :q quits.

• 3 - pyspark

    cd /export/server/spark
    bin/pyspark --master local[3]
    spark.sql("show databases").show()

  This drops you into the Python shell; exit() quits.

• 4 - beeline

    sbin/start-thriftserver.sh    # no need to start Hive separately
    /export/server/spark/bin/beeline
    Beeline version 1.2.1.spark2 by Apache Hive
    beeline> !connect jdbc:hive2://node1.itcast.cn:10000
    Connecting to jdbc:hive2://node1.itcast.cn:10000
    Enter username for jdbc:hive2://node1.itcast.cn:10000: root
    Enter password for jdbc:hive2://node1.itcast.cn:10000: ****

• 5 - connecting from PyCharm

  • Create the table and load the data
  • When building the connection, remember that port 9083 is the metastore port configured in hive-site.xml

  testHiveSQL.py
    # -*- coding: utf-8 -*-
    # Program function: test the integration with Hive
    '''
    * 1 - Prepare the SparkSession with enableHiveSupport
    #     It is recommended to point it at the Hive metastore address (the metadata service)
    #     and at the Hive warehouse directory on HDFS (where the table data is stored)
    * 2 - Load a data file the HiveSQL way: spark.sql("LOAD DATA LOCAL INPATH '/export/pyworkspace/pyspark_sparksql_chapter3/data/hive/student.csv' INTO TABLE person")
    * 3 - Query the loaded table: spark.sql("select * from table").show()
    * 4 - Run aggregations with HiveSQL
    * 5 - Stop
    '''
    from pyspark.sql import SparkSession

    if __name__ == '__main__':
        spark = SparkSession \
            .builder \
            .appName("testHive") \
            .master("local[*]") \
            .enableHiveSupport() \
            .config("spark.sql.warehouse.dir", "hdfs://node1:9820/user/hive/warehouse") \
            .config("hive.metastore.uris", "thrift://node1:9083") \
            .getOrCreate()
        spark.sql("show databases;").show()
        # Use an existing database
        spark.sql("use sparkhive2")
        # Create the student table (only needed once)
        # spark.sql(
        #     "create table if not exists person (id int, name string, age int) row format delimited fields terminated by ','")
        # Load the data (only needed once)
        # spark.sql(
        #     "LOAD DATA LOCAL INPATH '/export/data/pyspark_workspace/PySpark-SparkSQL_3.1.2/data/sql/hive/student.csv' INTO TABLE person")
        # Query the table; you can also check in bin/hive that the table exists
        spark.sql("select * from person").show()
        # +---+--------+---+
        # | id|    name|age|
        # +---+--------+---+
        # |  1|zhangsan| 30|
        # |  2|    lisi| 40|
        # |  3|  wangwu| 50|
        # +---+--------+---+
        print("===========================================================")
        import pyspark.sql.functions as fn
        # spark.read.table reads a Hive table as a DataFrame
        spark.read \
            .table("person") \
            .groupBy("name") \
            .agg(fn.round(fn.avg("age"), 2).alias("avg_age")) \
            .show(10, truncate=False)
        spark.stop()

2. Spark window functions

• 1 - row_number(): sequential ranks, no repeats
• 2 - rank(): repeated ranks followed by gaps ("jumping" ranks)
• 3 - dense_rank(): repeated ranks without gaps

  sparkWindows.py
    # -*- coding: utf-8 -*-
    # Program function: experiment with SparkSQL window functions
    # Two families: aggregate window functions and ranking window functions
    from pyspark.sql import SparkSession

    if __name__ == '__main__':
        # Create the SparkSession
        spark = SparkSession.builder \
            .appName('test') \
            .getOrCreate()
        sc = spark.sparkContext
        sc.setLogLevel("WARN")
        # Build the test data
        scoreDF = spark.sparkContext.parallelize([
            ("a1", 1, 80),
            ("a2", 1, 78),
            ("a3", 1, 95),
            ("a4", 2, 74),
            ("a5", 2, 92),
            ("a6", 3, 99),
            ("a7", 3, 99),
            ("a8", 3, 45),
            ("a9", 3, 55),
            ("a10", 3, 78),
            ("a11", 3, 100)]
        ).toDF(["name", "class", "score"])
        scoreDF.createOrReplaceTempView("scores")
        scoreDF.show()
        # +----+-----+-----+
        # |name|class|score|
        # +----+-----+-----+
        # |  a1|    1|   80|
        # |  a2|    1|   78|
        # |  a3|    1|   95|
        # |  a4|    2|   74|
        # |  a5|    2|   92|
        # |  a6|    3|   99|
        # |  a7|    3|   99|
        # |  a8|    3|   45|
        # |  a9|    3|   55|
        # | a10|    3|   78|
        # | a11|    3|  100|
        # +----+-----+-----+
        # Aggregate window functions
        spark.sql("select count(name) from scores").show()
        spark.sql("select name,class,score,count(name) over() name_count from scores").show()
        spark.sql("select name,class,score,count(name) over(partition by class) name_count from scores").show()
        # Ranking window functions
        # row_number: sequential ranking, no repeats
        spark.sql("select name,class,score,row_number() over(partition by class order by score) `rank` from scores").show()
        # +----+-----+-----+----+
        # |name|class|score|rank|
        # +----+-----+-----+----+
        # |  a2|    1|   78|   1|
        # |  a1|    1|   80|   2|
        # |  a3|    1|   95|   3|
        # |  a8|    3|   45|   1|
        # |  a9|    3|   55|   2|
        # | a10|    3|   78|   3|
        # |  a6|    3|   99|   4|
        # |  a7|    3|   99|   5|
        # | a11|    3|  100|   6|
        # |  a4|    2|   74|   1|
        # |  a5|    2|   92|   2|
        # +----+-----+-----+----+

        # rank: ranking with gaps after ties
        spark.sql("select name,class,score,rank() over(partition by class order by score) `rank` from scores").show()
        # +----+-----+-----+----+
        # |name|class|score|rank|
        # +----+-----+-----+----+
        # |  a2|    1|   78|   1|
        # |  a1|    1|   80|   2|
        # |  a3|    1|   95|   3|
        # |  a8|    3|   45|   1|
        # |  a9|    3|   55|   2|
        # | a10|    3|   78|   3|
        # |  a6|    3|   99|   4|
        # |  a7|    3|   99|   4|
        # | a11|    3|  100|   6|
        # |  a4|    2|   74|   1|
        # |  a5|    2|   92|   2|
        # +----+-----+-----+----+
        # dense_rank: ties share a rank, no gaps
        spark.sql("select name,class,score,dense_rank() over(partition by class order by score) `rank` from scores").show()
        # +----+-----+-----+----+
        # |name|class|score|rank|
        # +----+-----+-----+----+
        # |  a2|    1|   78|   1|
        # |  a1|    1|   80|   2|
        # |  a3|    1|   95|   3|
        # |  a8|    3|   45|   1|
        # |  a9|    3|   55|   2|
        # | a10|    3|   78|   3|
        # |  a6|    3|   99|   4|
        # |  a7|    3|   99|   4|
        # | a11|    3|  100|   5|
        # |  a4|    2|   74|   1|
        # |  a5|    2|   92|   2|
        # +----+-----+-----+----+
        # ntile: distribute the rows into 6 buckets
        spark.sql("select name,class,score,ntile(6) over(order by score) `rank` from scores").show()
        # +----+-----+-----+----+
        # |name|class|score|rank|
        # +----+-----+-----+----+
        # |  a8|    3|   45|   1|
        # |  a9|    3|   55|   1|
        # |  a4|    2|   74|   2|
        # |  a2|    1|   78|   2|
        # | a10|    3|   78|   3|
        # |  a1|    1|   80|   3|
        # |  a5|    2|   92|   4|
        # |  a3|    1|   95|   4|
        # |  a6|    3|   99|   5|
        # |  a7|    3|   99|   5|
        # | a11|    3|  100|   6|
        # +----+-----+-----+----+
        spark.stop()
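
The same rankings can also be expressed with the DataFrame API instead of SQL. Below is a minimal sketch using pyspark.sql.Window against the scoreDF defined above; only names from the example are used.

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    # Partition by class and order by score, mirroring the OVER() clauses above
    w = Window.partitionBy("class").orderBy("score")

    scoreDF.select(
        "name", "class", "score",
        F.row_number().over(w).alias("row_number"),   # sequential, no repeats
        F.rank().over(w).alias("rank"),               # repeats, then skips
        F.dense_rank().over(w).alias("dense_rank")    # repeats, no skips
    ).show()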
                              

3. Reading and writing MySQL with Spark

(1) Read from MySQL

• 1 - spark.read.format: choose "jdbc"
• 2 - four options: the MySQL connection URL, the database and table, the user name, and the password
• 3 - load: load the data
    # -*- coding: utf-8 -*-
    # Program function: read data from MySQL
    from pyspark.sql import functions as F
    from pyspark.sql.types import *
    from pyspark.sql import SparkSession
    import os

    os.environ['SPARK_HOME'] = '/export/server/spark'
    PYSPARK_PYTHON = "/root/anaconda3/bin/python3"
    # When several Python versions are installed, not pinning the interpreter is likely to cause errors
    os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
    os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .getOrCreate()
        sc = spark.sparkContext
        # Read the table through JDBC
        jdbcDF = spark.read \
            .format("jdbc") \
            .option("url", "jdbc:mysql://node1:3306/?serverTimezone=UTC&characterEncoding=utf8&useUnicode=true") \
            .option("dbtable", "bigdata.person") \
            .option("user", "root") \
            .option("password", "123456") \
            .option("user", "root") \
            .option("password", "123456") \
            .load()
        jdbcDF.show()
        jdbcDF.printSchema()
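
This reads the whole table through a single JDBC connection. If the table has a numeric key to split on, the read can be parallelized with the standard JDBC partitioning options; the sketch below assumes bigdata.person has a numeric id column whose values roughly span 1..1000, which the example above does not guarantee.

    # Hypothetical parallel read: split bigdata.person on its numeric id column
    jdbcDF_par = spark.read \
        .format("jdbc") \
        .option("url", "jdbc:mysql://node1:3306/?serverTimezone=UTC&characterEncoding=utf8&useUnicode=true") \
        .option("dbtable", "bigdata.person") \
        .option("user", "root") \
        .option("password", "123456") \
        .option("partitionColumn", "id") \
        .option("lowerBound", "1") \
        .option("upperBound", "1000") \
        .option("numPartitions", "4") \
        .load()
    jdbcDF_par.show()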
                                
(2) Write to MySQL

• 1 - coalesce: reduce the number of partitions
• 2 - write: start the write
• 3 - format: "jdbc"
• 4 - mode: "overwrite"
• 5 - MySQL: five options plus save; compared with the read, the extra option is the JDBC driver class
    # -*- coding: utf-8 -*-
    # Program function: write the result data to MySQL
    import os
    from pyspark.sql import SparkSession
    from pyspark.sql.types import *
    from pyspark.sql.types import Row

    os.environ['SPARK_HOME'] = '/export/server/spark'
    PYSPARK_PYTHON = "/root/anaconda3/bin/python3"
    # When several Python versions are installed, not pinning the interpreter is likely to cause errors
    os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
    os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .getOrCreate()
        sc = spark.sparkContext
        # Load a text file and convert each line to a Row
        lines = sc.textFile("file:///export/data/spark_practice/PySpark-SparkSQL_3.1.2/data/sql/people.txt")
        parts = lines.map(lambda l: l.split(","))
        people = parts.map(lambda p: Row(name=p[0], age=int(p[1])))
        # Infer the schema and register the DataFrame as a table
        personDF = spark.createDataFrame(people)
        personDF \
            .coalesce(1) \
            .write \
            .format("jdbc") \
            .mode("overwrite") \
            .option("driver", "com.mysql.jdbc.Driver") \
            .option("url", "jdbc:mysql://node1:3306/?serverTimezone=UTC&characterEncoding=utf8&useUnicode=true") \
            .option("dbtable", "bigdata.person") \
            .option("user", "root") \
            .option("password", "123456") \
            .save()
        print("save To MySQL finished..............")
        spark.stop()
                                  

4. UDFs on DataFrame data sources

Overview:

• A UDF is a user-defined function: one row in, one row out.
• A UDAF is a user-defined aggregation function: many rows in, one row out, usually combined with groupBy.
• A UDTF is a user-defined table-generating function: one row in, many rows out, often combined with flatMap (see the sketch right after this list).
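
This PySpark version has no dedicated UDTF API; a common way to get the one-row-in, many-rows-out effect is an ArrayType UDF combined with explode. The sketch below is only an illustration; the sentence column and the split logic are made up.

    # Hypothetical illustration of UDTF-like behaviour: one sentence row fans out into several word rows
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf, explode
    from pyspark.sql.types import ArrayType, StringType

    spark = SparkSession.builder.appName("udtf_sketch").master("local[*]").getOrCreate()
    sentences = spark.createDataFrame([("hello spark sql",), ("hello udf",)], ["sentence"])

    split_words = udf(lambda s: s.split(" "), returnType=ArrayType(StringType()))
    # explode turns each returned array into one output row per element
    sentences.select(explode(split_words("sentence")).alias("word")).show()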
(1) pandas_toDF

• 1 - Apache Arrow is an in-memory columnar data format that Spark uses to transfer data efficiently between the JVM and the Python process.
• 2 - Install it on Linux with: pip install pyspark[sql]
    # -*- coding: utf-8 -*-
    # Program function: convert a pandas DataFrame into a Spark DataFrame
    from pyspark.sql import SparkSession
    import pandas as pd

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .master("local[*]") \
            .getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        # Apache Arrow is an in-memory columnar format that Spark uses to move data efficiently
        # between the JVM and the Python process. It has to be installed first:
        # pip install pyspark[sql]   (run the pip install on Linux)
        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        df_pd = pd.DataFrame(
            data={'integers': [1, 2, 3],
                  'floats': [-1.0, 0.6, 2.6],
                  'integer_arrays': [[1, 2], [3, 4.6], [5, 6, 8, 9]]}
        )
        df = spark.createDataFrame(df_pd)
        df.printSchema()
        # root
        #  |-- integers: long (nullable = true)
        #  |-- floats: double (nullable = true)
        #  |-- integer_arrays: array (nullable = true)
        #  |    |-- element: double (containsNull = true)
        df.show()
        # +--------+------+--------------------+
        # |integers|floats|      integer_arrays|
        # +--------+------+--------------------+
        # |       1|  -1.0|          [1.0, 2.0]|
        # |       2|   0.6|          [3.0, 4.6]|
        # |       3|   2.6|[5.0, 6.0, 8.0, 9.0]|
        # +--------+------+--------------------+
                                      
(2) UDF basics

• 1 - udf(column-level Python function, returnType=type)
• 2 - DSL: df.select("integers", udf_integger("integers")).show(), where udf_integger is the wrapped function
• 3 - SQL: spark.udf.register("udf_integger", udf_integger) registers a temporary function first
    # -*- coding: utf-8 -*-
    # Program function: register a plain Python function as a Spark UDF
    import os
    from pyspark.sql import SparkSession
    import pandas as pd
    from pyspark.sql.functions import udf

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .master("local[*]") \
            .getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        # Enable Arrow for efficient JVM <-> Python data transfer (pip install pyspark[sql] first)
        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        df_pd = pd.DataFrame(
            data={'integers': [1, 2, 3],
                  'floats': [-1.0, 0.6, 2.6],
                  'integer_arrays': [[1, 2], [3, 4.6], [5, 6, 8, 9]]}
        )
        df = spark.createDataFrame(df_pd)
        df.printSchema()
        df.show()

        # Requirement: register a Python squaring function as a Spark UDF
        def square(x):
            return x ** 2

        from pyspark.sql.types import IntegerType
        udf_integger = udf(square, returnType=IntegerType())
        # DSL style
        df.select("integers", udf_integger("integers")).show()
        # SQL style
        df.createOrReplaceTempView("table")
        # SQL requires the UDF to be registered as a temporary function first
        spark.udf.register("udf_integger", udf_integger)
        spark.sql("select integers,udf_integger(integers) as integerX from table").show()
                                        
(3) Mismatched types return null

• Symptom: if the column type does not match the UDF's declared return type, the result comes back as null.
    # -*- coding: utf-8 -*-
    # Program function: show what happens when the UDF return type does not match the data
    from pyspark.sql import SparkSession
    import pandas as pd

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .master("local[*]") \
            .getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        # Enable Arrow for efficient JVM <-> Python data transfer (pip install pyspark[sql] first)
        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        df_pd = pd.DataFrame(
            data={'integers': [1, 2, 3],
                  'floats': [-1.0, 0.6, 2.6],
                  'integer_arrays': [[1, 2], [3, 4.6], [5, 6, 8, 9]]}
        )
        df = spark.createDataFrame(df_pd)
        df.printSchema()
        df.show()

        # Requirement: register a Python squaring function as a Spark UDF
        def square(x):
            return x ** 2

        from pyspark.sql.functions import udf
        from pyspark.sql.types import IntegerType, FloatType
        # Symptom: when the actual type does not match the declared return type, the result is null.
        # Workaround: declare a return type that matches the data (see also the list/ArrayType approach below).
        udf_integger = udf(square, returnType=FloatType())
        # DSL style
        df.select("integers", udf_integger("integers").alias("int"), udf_integger("floats").alias("flat")).show()
        # +--------+----+----+
        # |integers| int|flat|
        # +--------+----+----+
        # |       1|null| 1.0|
        # |       2|null|0.36|
        # |       3|null|6.76|
        # +--------+----+----+

        # The same UDF written as a lambda expression
        udf_integger_lambda = udf(lambda x: x ** 2, returnType=IntegerType())
        df.select("integers", udf_integger_lambda("integers").alias("int")).show()
        # +--------+---+
        # |integers|int|
        # +--------+---+
        # |       1|  1|
        # |       2|  4|
        # |       3|  9|
        # +--------+---+
                                          
(4) Wrapping a common function as a UDF

• Putting @udf(returnType=IntegerType()) above the function definition replaces the explicit wrapping udf_int = udf(squareTest, returnType=IntegerType())
    # -*- coding: utf-8 -*-
    # Program function: wrap an ordinary Python function as a UDF, with and without the decorator
    import os
    from pyspark.sql import SparkSession
    import pandas as pd
    from pyspark.sql.functions import udf

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .master("local[*]") \
            .getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        # Enable Arrow for efficient JVM <-> Python data transfer (pip install pyspark[sql] first)
        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        df_pd = pd.DataFrame(
            data={'integers': [1, 2, 3],
                  'floats': [-1.0, 0.6, 2.6],
                  'integer_arrays': [[1, 2], [3, 4.6], [5, 6, 8, 9]]}
        )
        df = spark.createDataFrame(df_pd)
        df.printSchema()
        df.show()
        from pyspark.sql.types import IntegerType, FloatType

        # Decorator form: the function becomes a UDF directly
        @udf(returnType=IntegerType())
        def square(x):
            return x ** 2

        df.select("integers", square("integers").alias("int")).show()

        # Explicit form: wrap a plain function with udf()
        def squareTest(x):
            return x ** 2

        udf_int = udf(squareTest, returnType=IntegerType())
        df.select("integers", udf_int("integers").alias("int")).show()
                                            
(5) List (array) return types

• ArrayType: declares a list return type, which handles returning an array of values without the null problem above

  _05_baseUDFList.py
    # -*- coding: utf-8 -*-
    # Program function: return a list from a UDF with ArrayType
    import os
    from pyspark.sql import SparkSession
    import pandas as pd
    from pyspark.sql.functions import udf

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .master("local[*]") \
            .getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        # Enable Arrow for efficient JVM <-> Python data transfer (pip install pyspark[sql] first)
        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        df_pd = pd.DataFrame(
            data={'integers': [1, 2, 3],
                  'floats': [-1.0, 0.6, 2.6],
                  'integer_arrays': [[1, 2], [3, 4.6], [5, 6, 8, 9]]}
        )
        df = spark.createDataFrame(df_pd)
        df.printSchema()
        df.show()
        # +--------+------+--------------------+
        # |integers|floats|      integer_arrays|
        # +--------+------+--------------------+
        # |       1|  -1.0|          [1.0, 2.0]|
        # |       2|   0.6|          [3.0, 4.6]|
        # |       3|   2.6|[5.0, 6.0, 8.0, 9.0]|
        # +--------+------+--------------------+

        # Requirement: square every element of the array column
        def square(x):
            return [val ** 2 for val in x]

        from pyspark.sql.types import FloatType, ArrayType, IntegerType
        # Declaring ArrayType(FloatType()) matches the array-of-double column, so no nulls appear
        udf_list = udf(square, returnType=ArrayType(FloatType()))
        # DSL style
        df.select("integer_arrays", udf_list("integer_arrays").alias("int_float_list")).show(truncate=False)
        # +--------------------+------------------------+
        # |integer_arrays      |int_float_list          |
        # +--------------------+------------------------+
        # |[1.0, 2.0]          |[1.0, 4.0]              |
        # |[3.0, 4.6]          |[9.0, 21.16]            |
        # |[5.0, 6.0, 8.0, 9.0]|[25.0, 36.0, 64.0, 81.0]|
        # +--------------------+------------------------+
                                              
(6) Returning mixed types

• StructType: a struct return type that allows a single UDF to return a mix of types
    # -*- coding: utf-8 -*-
    # Program function: return mixed types from a UDF with StructType
    import os
    from pyspark.sql import SparkSession
    import pandas as pd
    from pyspark.sql.functions import udf

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .master("local[*]") \
            .getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        # Enable Arrow for efficient JVM <-> Python data transfer (pip install pyspark[sql] first)
        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
        df_pd = pd.DataFrame(
            data={'integers': [1, 2, 3],
                  'floats': [-1.0, 0.6, 2.6],
                  'integer_arrays': [[1, 2], [3, 4.6], [5, 6, 8, 9]]}
        )
        df = spark.createDataFrame(df_pd)
        df.printSchema()
        df.show()

        # Decorator usage with a struct return type: one int field plus one string field
        from pyspark.sql.types import StructType, StructField, StringType, IntegerType
        struct_type = StructType([StructField("number", IntegerType(), True), StructField("letters", StringType(), False)])
        import string

        @udf(returnType=struct_type)
        def convert_ascii(number):
            return [number, string.ascii_letters[number]]

        # Apply the UDF
        df.select("integers", convert_ascii("integers").alias("num_letters")).show()
                                                
(7) Combining several UDFs

• Several functions, each wrapped in its own UDF; when returnType is omitted it defaults to string
    # -*- coding: utf-8 -*-
    # Program function: combine several UDFs in one select
    import os
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import *

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .master("local[*]") \
            .getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        # Requirement: on a mock dataset, compute the length of the user name,
        # upper-case the user name, and add 1 to the age field.
        # Three user-defined functions, one per requirement.
        def length(x):
            if x is not None:
                return len(x)
        udf_len = udf(length, returnType=IntegerType())

        def changeBigLetter(letter):
            if letter is not None:
                return letter.upper()
        udf_change = udf(changeBigLetter, returnType=StringType())

        def addAgg(age):
            if age is not None:
                return age + 1
        udf_addAge = udf(addAgg, returnType=IntegerType())

        df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
        df.select(udf_len("name").alias("username_length"), udf_change("name").alias("letter"), udf_addAge("age").alias("age+1")).show()
                                                  

5. UDFs on Series data sources

(1) Series functions defined with pandas_udf

• 1 - Define a function that operates on pandas Series
• 2 - Wrap it with pandas_udf (or the @pandas_udf decorator), declaring the return type
• 3 - Turn the pandas Series data into a Spark DataFrame
    # -*- coding: utf-8 -*-
    # Program function: define a Series-based UDF with pandas_udf
    import os
    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, pandas_udf
    from pyspark.sql.types import LongType

    os.environ['SPARK_HOME'] = '/export/server/spark'
    PYSPARK_PYTHON = "/root/anaconda3/bin/python3"
    # When several Python versions are installed, not pinning the interpreter is likely to cause errors
    os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
    os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .getOrCreate()
        sc = spark.sparkContext

        def multiply_mul(x: pd.Series, y: pd.Series) -> pd.Series:
            return x * y

        series1 = pd.Series([1, 2, 3])
        print("Plain single-machine Series multiplication:")
        # print(multiply_mul(series1, series1))
        # Problem: the plain function above only works on local, single-machine data;
        # it cannot be applied to a distributed dataset as-is.
        # Solution: pandas_udf bridges pandas Series/DataFrames with Spark's parallel execution.
        udf_pandas = pandas_udf(multiply_mul, returnType=LongType())
        # Create the distributed Spark DataFrame
        df = spark.createDataFrame(pd.DataFrame(series1, columns=["x"]))
        df.show()
        df.printSchema()
        import pyspark.sql.functions as F
        # Select the existing x column and multiply it with itself through the pandas_udf
        df.select("x", udf_pandas(F.col("x"), F.col("x"))).show()
        # +---+------------------+
        # |  x|multiply_mul(x, x)|
        # +---+------------------+
        # |  1|                 1|
        # |  2|                 4|
        # |  3|                 9|
        # +---+------------------+
                                                    
(2) Series combined with iterators

• 1 - Iterator: the function receives an iterator of pandas Series batches, even though the source is a DataFrame
• 2 - pandas_udf: declare the return type (or use @pandas_udf)
• 3 - The iterator produces results with yield
    # -*- coding: utf-8 -*-
    # Program function: Iterator-of-Series pandas_udf
    import os
    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, pandas_udf

    os.environ['SPARK_HOME'] = '/export/server/spark'
    PYSPARK_PYTHON = "/root/anaconda3/bin/python"
    # When several Python versions are installed, not pinning the interpreter is likely to cause errors
    os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
    os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .getOrCreate()
        sc = spark.sparkContext
        # Build the Spark DataFrame from a pandas DataFrame (Arrow handles the data transfer)
        pdf = pd.DataFrame([1, 2, 3], columns=["x"])
        df = spark.createDataFrame(pdf)
        # multiply_add adds 1 to every element; each batch arrives as a pandas Series
        from typing import Iterator
        from pyspark.sql.types import IntegerType

        def multiply_add(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
            for x in iterator:
                yield x + 1

        udf_pandas = pandas_udf(multiply_add, returnType=IntegerType())
        df.select(udf_pandas("x")).show()
        # +---------------+
        # |multiply_add(x)|
        # +---------------+
        # |              2|
        # |              3|
        # |              4|
        # +---------------+
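
The iterator form pays off when the UDF needs some expensive one-off setup that should be shared by all the batches in a task; a minimal sketch reusing the df with column x from above (the "expensive setup" here is a toy stand-in):

    from typing import Iterator
    import pandas as pd
    from pyspark.sql.functions import pandas_udf

    @pandas_udf("long")
    def add_offset(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
        # Stand-in for expensive one-off setup (e.g. loading a model or a large lookup table)
        offset = sum(range(1000))
        # The setup above runs once per task; every Series batch below reuses it
        for batch in batches:
            yield batch + offset

    df.select(add_offset("x")).show()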
                                                      
(3) Tuple for easier unpacking
    # -*- coding: utf-8 -*-
    # Program function: Iterator of multiple Series, unpacked as a tuple
    import os
    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, pandas_udf
    from pyspark.sql.types import LongType

    os.environ['SPARK_HOME'] = '/export/server/spark'
    PYSPARK_PYTHON = "/root/anaconda3/bin/python"
    # When several Python versions are installed, not pinning the interpreter is likely to cause errors
    os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
    os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON

    if __name__ == '__main__':
        spark = SparkSession.builder \
            .appName('test') \
            .getOrCreate()
        sc = spark.sparkContext
        # Build the Spark DataFrame from a pandas DataFrame (Arrow handles the data transfer)
        pdf = pd.DataFrame([1, 2, 3], columns=["x"])
        df = spark.createDataFrame(pdf)
        # multiply_add multiplies two columns element-wise; each batch arrives as a tuple of
        # Series, which the for loop unpacks into x and y
        from typing import Iterator, Tuple
        from pyspark.sql.types import IntegerType

        @pandas_udf(returnType=IntegerType())
        def multiply_add(iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
            for x, y in iterator:
                yield x * y

        df.select(multiply_add("x", "x")).show()
        # +------------------+
        # |multiply_add(x, x)|
        # +------------------+
        # |                 1|
        # |                 4|
        # |                 9|
        # +------------------+
                                                      
(4) pandas_udf aggregation via applyInPandas

• 1 - agg: the simple way to apply an aggregating pandas_udf, but it cannot reshape the output columns
• 2 - applyInPandas: higher performance and full control over the output columns
                                                        # -*- coding: utf-8 -*-
                                                        # Program function:
                                                        # -*- coding: utf-8 -*-
                                                        # Program function:
                                                        import os
                                                        import pandas as pd
                                                        from pyspark.sql import SparkSession
                                                        from pyspark.sql.functions import col, pandas_udf
                                                        from pyspark.sql.types import LongType
                                                        # Import data types
                                                        os.environ['SPARK_HOME'] = '/export/server/spark'
                                                        PYSPARK_PYTHON = "/root/anaconda3/bin/python"
                                                        # 当存在多个版本时,不指定很可能会导致出错
                                                        os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
                                                        os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON
                                                        if __name__ == '__main__':
                                                            spark = SparkSession.builder \
                                                                .appName('test') \
                                                                .getOrCreate()
                                                            sc = spark.sparkContext
                                                            df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
                                                            # Declare the function and create the UDF
                                                            @pandas_udf("double")
                                                            def mean_udf(v: pd.Series) -> float:
                                                                return v.mean()
                                                            # 计算平均值
                                                            df.select(mean_udf("v")).show()
                                                            #+-----------+
                                                            # |mean_udf(v)|
                                                            # +-----------+
                                                            # |        4.2|
                                                            # +-----------+
                                                            # 根据id分组后求解平均值
                                                            df.groupby("id").agg(mean_udf("v")).show()
                                                            # +---+-----------+
                                                            # | id|mean_udf(v)|
                                                            # +---+-----------+
                                                            # |  1|        1.5|
                                                            # |  2|        6.0|
                                                            # +---+-----------+
                                                            # 应用
                                                            def subtract_mean(pdf):
                                                                # pdf is a pandas.DataFrame
                                                                v = pdf.v
                                                                return pdf.assign(v=v - v.mean())
                                                            # pdf.assign(v=...) builds a new pandas DataFrame in which column v is replaced by v - v.mean()
                                                            # schema="id long, v double" declares the names and types of the returned columns (id | v below)
                                                            df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
                                                            # +---+----+
                                                            # | id|   v|
                                                            # +---+----+
                                                            # |  1|-0.5|
                                                            # |  1| 0.5|
                                                            # |  2|-3.0|
                                                            # |  2|-1.0|
                                                            # |  2| 4.0|
                                                            # +---+----+
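                                                        The example above keeps the input schema unchanged. To illustrate the schema-changing ability mentioned in the bullet list, here is a minimal sketch, assuming the same df as above; the column names mean_v and group_size are my own:
                                                            def summarize(pdf):
                                                                # pdf is the pandas.DataFrame holding one id group
                                                                return pd.DataFrame({"id": [pdf["id"].iloc[0]],
                                                                                     "mean_v": [pdf["v"].mean()],
                                                                                     "group_size": [len(pdf)]})
                                                            df.groupby("id").applyInPandas(summarize, schema="id long, mean_v double, group_size long").show()
                                                            # expected: one row per id, e.g. id=1 -> mean_v=1.5, group_size=2; id=2 -> mean_v=6.0, group_size=3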
                                                        
                                                        (5) 某一列加数问题
                                                        • 1-plain Series UDF: the whole column arrives as one pd.Series and the constant is added in a single vectorized step
                                                        • 2-iterator UDF: batches arrive through an Iterator[pd.Series] and each shifted batch is yielded back
                                                          # -*- coding: utf-8 -*-
                                                          # Program function:通过NBA的数据集完成几个需求
                                                          from pyspark.sql import SparkSession
                                                          import os
                                                          # Import data types
                                                          os.environ['SPARK_HOME'] = '/export/server/spark'
                                                          PYSPARK_PYTHON = "/root/anaconda3/bin/python3"
                                                          # 当存在多个版本时,不指定很可能会导致出错
                                                          os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
                                                          os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON
                                                          if __name__ == '__main__':
                                                              # 1-准备环境
                                                              spark = SparkSession.builder.appName("bucket").master("local[*]").getOrCreate()
                                                              sc = spark.sparkContext
                                                              sc.setLogLevel("WARN")
                                                              # 2-读取数据
                                                              bucketData = spark.read \
                                                                  .format("csv") \
                                                                  .option("header", True) \
                                                                  .option("sep", ",") \
                                                                  .option("inferSchema", True) \
                                                                  .load("file:///export/data/spark_practice/PySpark-SparkSQL_3.1.2/data/bucket/bucket.csv")
                                                              bucketData.show()
                                                              # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                              # |_c0|对手|胜负|主客场|命中|投篮数|投篮命中率|3分命中率|篮板|助攻|得分|
                                                              # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                              # |  0|勇士|  胜|    客|  10|    23|     0.435|    0.444|   6|  11|  27|
                                                              # |  1|国王|  胜|    客|   8|    21|     0.381|    0.286|   3|   9|  28|
                                                              # |  2|小牛|  胜|    主|  10|    19|     0.526|    0.462|   3|   7|  29|
                                                              # |  3|火箭|  负|    客|   8|    19|     0.526|    0.462|   7|   9|  20|
                                                              # |  4|快船|  胜|    主|   8|    21|     0.526|    0.462|   7|   9|  28|
                                                              # |  5|热火|  负|    客|   8|    19|     0.435|    0.444|   6|  11|  18|
                                                              # |  6|骑士|  负|    客|   8|    21|     0.435|    0.444|   6|  11|  28|
                                                              # |  7|灰熊|  负|    主|  10|    20|     0.435|    0.444|   6|  11|  27|
                                                              # |  8|活塞|  胜|    主|   8|    19|     0.526|    0.462|   7|   9|  16|
                                                              # |  9|76人|  胜|    主|  10|    21|     0.526|    0.462|   7|   9|  28|
                                                              # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                              bucketData.printSchema()
                                                              # root
                                                              # | -- _c0: integer(nullable=true)
                                                              # | -- 对手: string(nullable=true)
                                                              # | -- 胜负: string(nullable=true)
                                                              # | -- 主客场: string(nullable=true)
                                                              # | -- 命中: integer(nullable=true)
                                                              # | -- 投篮数: integer(nullable=true)
                                                              # | -- 投篮命中率: double(nullable=true)
                                                              # | -- 3分命中率: double(nullable=true)
                                                              # | -- 篮板: integer(nullable=true)
                                                              # | -- 助攻: integer(nullable=true)
                                                              # | -- 得分: integer(nullable=true)
                                                              print("助攻这一列需要加10,如何实现?思考使用哪种?")
                                                              import pandas as pd
                                                              from pyspark.sql.functions import pandas_udf
                                                              from pyspark.sql.types import IntegerType
                                                              def total_shoot(x: pd.Series) -> pd.Series:
                                                                  return x + 10
                                                              udf_add = pandas_udf(total_shoot, returnType=IntegerType())
                                                              bucketData.select("助攻", udf_add("助攻")).show()
                                                              # +----+-----------------+
                                                              # |助攻|total_shoot(助攻)|
                                                              # +----+-----------------+
                                                              # |  11|               21|
                                                              # |   9|               19|
                                                              # |   7|               17|
                                                              # |   9|               19|
                                                              # |   9|               19|
                                                              # |  11|               21|
                                                              # |  11|               21|
                                                              # |  11|               21|
                                                              # |   9|               19|
                                                              # |   9|               19|
                                                              # +----+-----------------+
                                                              from typing import Iterator
                                                              def total_shoot_iter(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
                                                                  for x in iterator:
                                                                      yield x + 10
                                                              udf_add_iter = pandas_udf(total_shoot_iter, returnType=IntegerType())
                                                              bucketData.select("助攻", udf_add_iter("助攻")).show()
                                                              # +----+----------------------+
                                                              # |助攻|total_shoot_iter(助攻)|
                                                              # +----+----------------------+
                                                              # |  11|                    21|
                                                              # |   9|                    19|
                                                              # |   7|                    17|
                                                              # |   9|                    19|
                                                              # |   9|                    19|
                                                              # |  11|                    21|
                                                              # |  11|                    21|
                                                              # |  11|                    21|
                                                              # |   9|                    19|
                                                              # |   9|                    19|
                                                              # +----+----------------------+
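                                                          Both variants produce the same result here. The iterator form mainly pays off when there is expensive one-time setup per batch iterator (loading a model, a lookup table, and so on). A rough sketch of that pattern, reusing the imports already made in this script; add_with_setup and offset are names of my own:
                                                              @pandas_udf(IntegerType())
                                                              def add_with_setup(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
                                                                  offset = 10  # stand-in for costly initialisation that should run only once
                                                                  for batch in batches:
                                                                      yield batch + offset
                                                              bucketData.select("助攻", add_with_setup("助攻")).show()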
                                                          
                                                          (6) 某两列相加问题
                                                          • 1-plain Series UDF: the two columns arrive as two separate pd.Series arguments and are added directly
                                                          • 2-iterator UDF: each batch arrives as a Tuple of two pd.Series, which is unpacked before adding
                                                            # -*- coding: utf-8 -*-
                                                            # Program function:通过NBA的数据集完成几个需求
                                                            from pyspark.sql import SparkSession
                                                            import os
                                                            # Import data types
                                                            os.environ['SPARK_HOME'] = '/export/server/spark'
                                                            PYSPARK_PYTHON = "/root/anaconda3/bin/python"
                                                            # 当存在多个版本时,不指定很可能会导致出错
                                                            os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
                                                            os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON
                                                            if __name__ == '__main__':
                                                                # 1-准备环境
                                                                spark = SparkSession.builder.appName("bucket").master("local[*]").getOrCreate()
                                                                sc = spark.sparkContext
                                                                sc.setLogLevel("WARN")
                                                                # 2-读取数据
                                                                bucketData = spark.read \
                                                                    .format("csv") \
                                                                    .option("header", True) \
                                                                    .option("sep", ",") \
                                                                    .option("inferSchema", True) \
                                                                    .load("file:///export/data/spark_practice/PySpark-SparkSQL_3.1.2/data/bucket/bucket.csv")
                                                                bucketData.show()
                                                                # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                                # |_c0|对手|胜负|主客场|命中|投篮数|投篮命中率|3分命中率|篮板|助攻|得分|
                                                                # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                                # |  0|勇士|  胜|    客|  10|    23|     0.435|    0.444|   6|  11|  27|
                                                                # |  1|国王|  胜|    客|   8|    21|     0.381|    0.286|   3|   9|  28|
                                                                # |  2|小牛|  胜|    主|  10|    19|     0.526|    0.462|   3|   7|  29|
                                                                # |  3|火箭|  负|    客|   8|    19|     0.526|    0.462|   7|   9|  20|
                                                                # |  4|快船|  胜|    主|   8|    21|     0.526|    0.462|   7|   9|  28|
                                                                # |  5|热火|  负|    客|   8|    19|     0.435|    0.444|   6|  11|  18|
                                                                # |  6|骑士|  负|    客|   8|    21|     0.435|    0.444|   6|  11|  28|
                                                                # |  7|灰熊|  负|    主|  10|    20|     0.435|    0.444|   6|  11|  27|
                                                                # |  8|活塞|  胜|    主|   8|    19|     0.526|    0.462|   7|   9|  16|
                                                                # |  9|76人|  胜|    主|  10|    21|     0.526|    0.462|   7|   9|  28|
                                                                # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                                bucketData.printSchema()
                                                                # root
                                                                # | -- _c0: integer(nullable=true)
                                                                # | -- 对手: string(nullable=true)
                                                                # | -- 胜负: string(nullable=true)
                                                                # | -- 主客场: string(nullable=true)
                                                                # | -- 命中: integer(nullable=true)
                                                                # | -- 投篮数: integer(nullable=true)
                                                                # | -- 投篮命中率: double(nullable=true)
                                                                # | -- 3分命中率: double(nullable=true)
                                                                # | -- 篮板: integer(nullable=true)
                                                                # | -- 助攻: integer(nullable=true)
                                                                # | -- 得分: integer(nullable=true)
                                                                print("助攻这一列需要加10,如何实现?思考使用哪种?")
                                                                import pandas as pd
                                                                from pyspark.sql.functions import pandas_udf,col
                                                                from pyspark.sql.types import IntegerType
                                                                # 方法1: the return type can be written as a Spark type object (IntegerType()) or as a DDL string such as "int"
                                                                # @pandas_udf(returnType=IntegerType())
                                                                # @pandas_udf(returnType="int")
                                                                @pandas_udf("int")
                                                                def total_shoot(x: pd.Series) -> pd.Series:
                                                                    return x + 10
                                                                bucketData.select("助攻", total_shoot("助攻")).show()
                                                                # 方法2
                                                                from typing import Iterator,Tuple
                                                                @pandas_udf("long")
                                                                def total_shoot_iter(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
                                                                    for x in iterator:
                                                                        yield x + 10
                                                                bucketData.select("助攻", total_shoot_iter("助攻")).show()
                                                                # TODO 1-Scalar UDF定义了一个转换,函数输入一个或多个pd.Series,输出一个pd.Series,函数的输出和输入有相同的长度
                                                                print("篮板+助攻的次数,思考使用哪种?")
                                                                # 方法1
                                                                @pandas_udf("long")
                                                                def total_bucket(x: pd.Series,y:pd.Series) -> pd.Series:
                                                                    return x + y
                                                                bucketData.select("助攻","篮板", total_bucket(col("助攻"),col("篮板"))).show()
                                                                # 方法2: iterator of Tuple; note that despite its name, multiply_two_cols adds the two columns element-wise
                                                                @pandas_udf("long")
                                                                def multiply_two_cols(iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
                                                                    for x,y in iterator:
                                                                        yield x+y
                                                                bucketData.select("助攻","篮板", multiply_two_cols(col("助攻"),col("篮板"))).show()
                                                                
                                                                # 篮板+助攻的次数,思考使用哪种?
                                                                # +----+----+------------------------+
                                                                # |助攻|篮板|total_bucket(助攻, 篮板)|
                                                                # +----+----+------------------------+
                                                                # |  11|   6|                      17|
                                                                # |   9|   3|                      12|
                                                                # |   7|   3|                      10|
                                                                # |   9|   7|                      16|
                                                                # |   9|   7|                      16|
                                                                # |  11|   6|                      17|
                                                                # |  11|   6|                      17|
                                                                # |  11|   6|                      17|
                                                                # |   9|   7|                      16|
                                                                # |   9|   7|                      16|
                                                                # +----+----+------------------------+
                                                                # 
                                                                # +----+----+-----------------------------+
                                                                # |助攻|篮板|multiply_two_cols(助攻, 篮板)|
                                                                # +----+----+-----------------------------+
                                                                # |  11|   6|                           17|
                                                                # |   9|   3|                           12|
                                                                # |   7|   3|                           10|
                                                                # |   9|   7|                           16|
                                                                # |   9|   7|                           16|
                                                                # |  11|   6|                           17|
                                                                # |  11|   6|                           17|
                                                                # |  11|   6|                           17|
                                                                # |   9|   7|                           16|
                                                                # |   9|   7|                           16|
                                                                # +----+----+-----------------------------+
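                                                            As a usage note, the same UDF can also be appended as an extra column with withColumn instead of projecting only three columns; a small sketch, where the new column name 篮板助攻合计 is made up here:
                                                                bucketData.withColumn("篮板助攻合计", total_bucket(col("助攻"), col("篮板"))).show()
                                                                # keeps every original column and adds the per-row sum of 篮板 and 助攻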
                                                            
                                                            (7) series函数返回单值类型
                                                            • def count_num(v: pd.Series) -> float: a Series-to-scalar (grouped aggregate) pandas UDF that reduces each group's Series to a single value
                                                              # -*- coding: utf-8 -*-
                                                              # Program function:通过NBA的数据集完成几个需求
                                                              from pyspark.sql import SparkSession
                                                              import os
                                                              # Import data types
                                                              os.environ['SPARK_HOME'] = '/export/server/spark'
                                                              PYSPARK_PYTHON = "/root/anaconda3/bin/python"
                                                              # 当存在多个版本时,不指定很可能会导致出错
                                                              os.environ["PYSPARK_PYTHON"] = PYSPARK_PYTHON
                                                              os.environ["PYSPARK_DRIVER_PYTHON"] = PYSPARK_PYTHON
                                                              if __name__ == '__main__':
                                                                  # 1-准备环境
                                                                  spark = SparkSession.builder.appName("bucket").master("local[*]").getOrCreate()
                                                                  sc = spark.sparkContext
                                                                  sc.setLogLevel("WARN")
                                                                  # 2-读取数据
                                                                  bucketData = spark.read \
                                                                      .format("csv") \
                                                                      .option("header", True) \
                                                                      .option("sep", ",") \
                                                                      .option("inferSchema", True) \
                                                                      .load("file:///export/data/pyspark_workspace/PySpark-SparkSQL_3.1.2/data/bucket/bucket.csv")
                                                                  bucketData.show()
                                                                  # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                                  # |_c0|对手|胜负|主客场|命中|投篮数|投篮命中率|3分命中率|篮板|助攻|得分|
                                                                  # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                                  # |  0|勇士|  胜|    客|  10|    23|     0.435|    0.444|   6|  11|  27|
                                                                  # |  1|国王|  胜|    客|   8|    21|     0.381|    0.286|   3|   9|  28|
                                                                  # |  2|小牛|  胜|    主|  10|    19|     0.526|    0.462|   3|   7|  29|
                                                                  # |  3|火箭|  负|    客|   8|    19|     0.526|    0.462|   7|   9|  20|
                                                                  # |  4|快船|  胜|    主|   8|    21|     0.526|    0.462|   7|   9|  28|
                                                                  # |  5|热火|  负|    客|   8|    19|     0.435|    0.444|   6|  11|  18|
                                                                  # |  6|骑士|  负|    客|   8|    21|     0.435|    0.444|   6|  11|  28|
                                                                  # |  7|灰熊|  负|    主|  10|    20|     0.435|    0.444|   6|  11|  27|
                                                                  # |  8|活塞|  胜|    主|   8|    19|     0.526|    0.462|   7|   9|  16|
                                                                  # |  9|76人|  胜|    主|  10|    21|     0.526|    0.462|   7|   9|  28|
                                                                  # +---+----+----+------+----+------+----------+---------+----+----+----+
                                                                  bucketData.printSchema()
                                                                  # root
                                                                  # | -- _c0: integer(nullable=true)
                                                                  # | -- 对手: string(nullable=true)
                                                                  # | -- 胜负: string(nullable=true)
                                                                  # | -- 主客场: string(nullable=true)
                                                                  # | -- 命中: integer(nullable=true)
                                                                  # | -- 投篮数: integer(nullable=true)
                                                                  # | -- 投篮命中率: double(nullable=true)
                                                                  # | -- 3分命中率: double(nullable=true)
                                                                  # | -- 篮板: integer(nullable=true)
                                                                  # | -- 助攻: integer(nullable=true)
                                                                  # | -- 得分: integer(nullable=true)
                                                                  print("助攻这一列需要加10,如何实现?思考使用哪种?")
                                                                  import pandas as pd
                                                                  from pyspark.sql.functions import pandas_udf,col
                                                                  from pyspark.sql.types import IntegerType
                                                                  # 方法1: the return type can be written as a Spark type object (IntegerType()) or as a DDL string such as "int"
                                                                  # @pandas_udf(returnType=IntegerType())
                                                                  # @pandas_udf(returnType="int")
                                                                  @pandas_udf("int")
                                                                  def total_shoot(x: pd.Series) -> pd.Series:
                                                                      return x + 10
                                                                  bucketData.select("助攻", total_shoot("助攻")).show()
                                                                  # TODO 1-Scalar UDF定义了一个转换,函数输入一个或多个pd.Series,输出一个pd.Series,函数的输出和输入有相同的长度
                                                                  print("篮板+助攻的次数,思考使用哪种?")
                                                                  # 方法1
                                                                  @pandas_udf("long")
                                                                  def total_bucket(x: pd.Series,y:pd.Series) -> pd.Series:
                                                                      return x + y
                                                                  bucketData.select("助攻","篮板", total_bucket(col("助攻"),col("篮板"))).show()
                                                                  # TODO 3-GROUPED_AGG定义了一个或多个pandas.Series -> 一个scalar,scalar的返回值类型(returnType)应该是原始数据类型
                                                                  print("统计胜 和 负的平均分")
                                                                  #UDAF
                                                                  @pandas_udf('int')
                                                                  def count_num(v: pd.Series) -> float:
                                                                      return v.mean()
                                                                  # the aggregated column gets an English alias (avg_score); a Chinese alias reportedly caused errors here
                                                                  bucketData.groupby("胜负").agg(count_num(bucketData['得分']).alias('avg_score')).show(2)
                                                                  #+----+---------+
                                                                  # |胜负|avg_score|
                                                                  # +----+---------+
                                                                  # |  负|       23|
                                                                  # |  胜|       26|
                                                                  # +----+---------+
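                                                              Because count_num declares 'int' as its return type, the group means are truncated (the exact averages are 23.25 for 负 and 26.0 for 胜). A minimal variant that keeps the decimals, assuming the same DataFrame; avg_score is a name of my own:
                                                                  @pandas_udf('double')
                                                                  def avg_score(v: pd.Series) -> float:
                                                                      return v.mean()
                                                                  bucketData.groupby("胜负").agg(avg_score(bucketData['得分']).alias('avg_score')).show()
                                                                  # expected: 负 -> 23.25, 胜 -> 26.0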
                                                              
                                                               