Spark High-Level API¶
Overview of Spark SQL¶
Creating DataFrames
- RDD: createDataFrame()
- Text file: read.text()
- JSON file: read.json(), read.json(RDD)
- Parquet file: read.parquet()
- Table in a relational database
- Temporary table in Spark

DataFrame to RDD
- rdd()

Schemas
- Inferring schemas, and why it is not optimal practice
- Specifying schemas
- Using StructType and StructField
- Using a DDL string (schema = "author STRING, title STRING, pages INT")
- Metadata: printSchema(), columns, dtypes

Actions
- show()

Transforms
- select() and alias()
- drop()
- filter() / where()
- distinct(), dropDuplicates()
- sample(), sampleBy()
- limit()
- orderBy() / sort()
- groupBy()

Operations that return an RDD
- rdd.map(), rdd.flatMap()

pyspark.sql.functions module
- String functions
- Math functions
- Statistics functions
- Date functions
- Hashing functions
- Algorithms (soundex, levenshtein)
- Windowing functions

User defined functions
- udf()
- pandas_udf()

Multiple DataFrames
- join(other, on, how)
- union(), unionAll()
- intersect()
- subtract()

Persistence (see the sketch at the end of this notebook)
- cache(), persist(), unpersist()
- cacheTable(), clearCache()
- repartition(), coalesce()

Output (see the sketch at the end of this notebook)
- write.csv(), write.parquet(), write.json()

Spark SQL
- df.createOrReplaceTempView()
- sql()
- table()
[1]:
from pyspark.sql import SparkSession
[2]:
import pyspark.sql.functions as F
import pyspark.sql.types as T
[3]:
spark = (
SparkSession.builder
.master("local")
.appName("BIOS-823")
.config("spark.executor.cores", 4)
.getOrCreate()
)
[4]:
spark.version
[4]:
'3.0.1'
[5]:
spark.conf.get('spark.executor.cores')
[5]:
'4'
Create a Spark DataFrame¶
[6]:
df = spark.range(3)
[7]:
df.show(3)
+---+
| id|
+---+
| 0|
| 1|
| 2|
+---+
[8]:
%%file data/test.csv
number,letter
0,a
1,c
2,b
3,a
4,b
5,c
6,a
7,a
8,a
9,b
10,b
11,c
12,c
13,b
14,b
Overwriting data/test.csv
Implicit schema¶
[9]:
df = (
spark.read.
format('csv').
option('header', 'true').
option('inferSchema', 'true').
load('data/test.csv')
)
[10]:
df.show(3)
+------+------+
|number|letter|
+------+------+
| 0| a|
| 1| c|
| 2| b|
+------+------+
only showing top 3 rows
[11]:
df.printSchema()
root
|-- number: integer (nullable = true)
|-- letter: string (nullable = true)
Explicit schema¶
For production use, you should provide an explicit schema to reduce risk of error.
[12]:
schema = T.StructType([
T.StructField("number", T.DoubleType()),
T.StructField("letter", T.StringType()),
])
[13]:
df = (
spark.read.
format('csv').
option('header', 'true').
schema(schema).
load('data/test.csv')
)
[14]:
df.show(3)
+------+------+
|number|letter|
+------+------+
| 0.0| a|
| 1.0| c|
| 2.0| b|
+------+------+
only showing top 3 rows
[15]:
df.printSchema()
root
|-- number: double (nullable = true)
|-- letter: string (nullable = true)
Alternative way to specify schema¶
You can use SQL DDL syntax to specify a schema as well.
[16]:
schema = 'number DOUBLE, letter STRING'
[17]:
df_altschema = (
spark.read.
format('csv').
option('header', 'true').
schema(schema=schema).
load('data/test.csv')
)
[18]:
df_altschema.take(3)
[18]:
[Row(number=0.0, letter='a'),
Row(number=1.0, letter='c'),
Row(number=2.0, letter='b')]
[19]:
df_altschema.printSchema()
root
|-- number: double (nullable = true)
|-- letter: string (nullable = true)
Data manipulation¶
[21]:
df.select('number').show(3)
+------+
|number|
+------+
| 0.0|
| 1.0|
| 2.0|
+------+
only showing top 3 rows
[22]:
from pyspark.sql.functions import col, expr
[23]:
df.select(col('number').alias('index')).show(3)
+-----+
|index|
+-----+
| 0.0|
| 1.0|
| 2.0|
+-----+
only showing top 3 rows
[24]:
df.select(expr('number as x')).show(3)
+---+
| x|
+---+
|0.0|
|1.0|
|2.0|
+---+
only showing top 3 rows
[25]:
df.withColumnRenamed('number', 'x').show(3)
+---+------+
| x|letter|
+---+------+
|0.0| a|
|1.0| c|
|2.0| b|
+---+------+
only showing top 3 rows
Filter¶
[26]:
df.filter('number % 2 == 0').show(3)
+------+------+
|number|letter|
+------+------+
| 0.0| a|
| 2.0| b|
| 4.0| b|
+------+------+
only showing top 3 rows
[27]:
df.filter("number % 2 == 0 AND letter == 'a'").show(3)
+------+------+
|number|letter|
+------+------+
| 0.0| a|
| 6.0| a|
| 8.0| a|
+------+------+
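The overview also lists distinct(), dropDuplicates(), sample(), sampleBy(), and limit() as row-selection transforms. They are not run in this notebook, but a minimal sketch using the same df might look like this (the fractions and seed are illustrative):
df.distinct()                                   # drop exact duplicate rows
df.dropDuplicates(['letter'])                   # keep one row per distinct letter
df.sample(fraction=0.3, seed=42)                # random ~30% sample of rows
df.sampleBy('letter', fractions={'a': 0.5, 'b': 0.5}, seed=42)  # stratified sample
df.limit(5)                                     # first 5 rows as a new DataFrame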
Sort¶
[28]:
df.sort(df.number.desc()).show(3)
+------+------+
|number|letter|
+------+------+
| 14.0| b|
| 13.0| b|
| 12.0| c|
+------+------+
only showing top 3 rows
[29]:
df.sort('number', ascending=False).show(3)
+------+------+
|number|letter|
+------+------+
| 14.0| b|
| 13.0| b|
| 12.0| c|
+------+------+
only showing top 3 rows
[30]:
df.orderBy(df.letter.desc()).show(3)
+------+------+
|number|letter|
+------+------+
| 1.0| c|
| 5.0| c|
| 11.0| c|
+------+------+
only showing top 3 rows
Transform¶
[31]:
df.selectExpr('number*2 as x').show(3)
+---+
| x|
+---+
|0.0|
|2.0|
|4.0|
+---+
only showing top 3 rows
[32]:
df.selectExpr('number', 'letter', 'number*2 as x').show(3)
+------+------+---+
|number|letter| x|
+------+------+---+
| 0.0| a|0.0|
| 1.0| c|2.0|
| 2.0| b|4.0|
+------+------+---+
only showing top 3 rows
[33]:
df.withColumn('x', expr('number*2')).show(3)
+------+------+---+
|number|letter| x|
+------+------+---+
| 0.0| a|0.0|
| 1.0| c|2.0|
| 2.0| b|4.0|
+------+------+---+
only showing top 3 rows
Summarize¶
[34]:
import pyspark.sql.functions as F
[35]:
df.agg(F.min('number'),
F.max('number'),
F.min('letter'),
F.max('letter')).show()
+-----------+-----------+-----------+-----------+
|min(number)|max(number)|min(letter)|max(letter)|
+-----------+-----------+-----------+-----------+
| 0.0| 14.0| a| c|
+-----------+-----------+-----------+-----------+
Group by¶
[36]:
(
df.groupby('letter').
agg(F.mean('number'), F.stddev_samp('number')).show()
)
+------+-----------------+-------------------+
|letter| avg(number)|stddev_samp(number)|
+------+-----------------+-------------------+
| c| 7.25| 5.188127472091127|
| b|8.666666666666666| 4.802776974487434|
| a| 4.8| 3.271085446759225|
+------+-----------------+-------------------+
Window functions¶
[37]:
from pyspark.sql.window import Window
[38]:
ws = (
Window.partitionBy('letter').
orderBy(F.desc('number')).
rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
[39]:
df.groupby('letter').agg(F.sum('number')).show()
+------+-----------+
|letter|sum(number)|
+------+-----------+
| c| 29.0|
| b| 52.0|
| a| 24.0|
+------+-----------+
[40]:
df.show()
+------+------+
|number|letter|
+------+------+
| 0.0| a|
| 1.0| c|
| 2.0| b|
| 3.0| a|
| 4.0| b|
| 5.0| c|
| 6.0| a|
| 7.0| a|
| 8.0| a|
| 9.0| b|
| 10.0| b|
| 11.0| c|
| 12.0| c|
| 13.0| b|
| 14.0| b|
+------+------+
[41]:
(
df.select('letter', F.sum('number').
over(ws).
alias('rank')).show()
)
+------+----+
|letter|rank|
+------+----+
| c|12.0|
| c|23.0|
| c|28.0|
| c|29.0|
| b|14.0|
| b|27.0|
| b|37.0|
| b|46.0|
| b|50.0|
| b|52.0|
| a| 8.0|
| a|15.0|
| a|21.0|
| a|24.0|
| a|24.0|
+------+----+
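Note that, despite the alias rank, the cell above computes a running sum within each letter group. A true rank comes from one of the ranking window functions; a minimal sketch (not run here) uses a window with only a partition and an ordering, since ranking functions use the default frame:
ws_rank = Window.partitionBy('letter').orderBy(F.desc('number'))
df.select('letter', 'number', F.rank().over(ws_rank).alias('rank')).show(3)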
SQL¶
[42]:
df.createOrReplaceTempView('df_table')
[43]:
spark.sql('''SELECT * FROM df_table''').show(3)
+------+------+
|number|letter|
+------+------+
| 0.0| a|
| 1.0| c|
| 2.0| b|
+------+------+
only showing top 3 rows
[44]:
spark.sql('''
SELECT letter, mean(number) AS mean,
stddev_samp(number) AS sd from df_table
WHERE number % 2 = 0
GROUP BY letter
ORDER BY letter DESC
''').show()
+------+-----------------+-----------------+
|letter| mean| sd|
+------+-----------------+-----------------+
| c| 12.0| NaN|
| b| 7.5|5.507570547286102|
| a|4.666666666666667|4.163331998932265|
+------+-----------------+-----------------+
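The overview also lists table(), which returns a registered view or table as a DataFrame. A minimal sketch (not run here):
spark.table('df_table').show(3)   # same rows as SELECT * FROM df_table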
String operations¶
[45]:
from pyspark.sql.functions import split, lower, explode
[46]:
import pandas as pd
[47]:
s = spark.createDataFrame(
pd.DataFrame(
dict(sents=('Thing 1 and Thing 2',
'The Quick Brown Fox'))))
[48]:
s.show()
+-------------------+
| sents|
+-------------------+
|Thing 1 and Thing 2|
|The Quick Brown Fox|
+-------------------+
[49]:
from pyspark.sql.functions import regexp_replace
[50]:
s1 = (
s.select(explode(split(lower(expr('sents')), ' '))).
sort('col')
)
[51]:
s1.show()
+-----+
| col|
+-----+
| 1|
| 2|
| and|
|brown|
| fox|
|quick|
| the|
|thing|
|thing|
+-----+
[52]:
s1.groupBy('col').count().show()
+-----+-----+
| col|count|
+-----+-----+
|thing| 2|
| fox| 1|
| the| 1|
| and| 1|
| 1| 1|
|brown| 1|
|quick| 1|
| 2| 1|
+-----+-----+
[53]:
s.createOrReplaceTempView('s_table')
[54]:
spark.sql('''
SELECT regexp_replace(sents, 'T.*?g', 'FOO')
FROM s_table
''').show()
+---------------------------------+
|regexp_replace(sents, T.*?g, FOO)|
+---------------------------------+
| FOO 1 and FOO 2|
| The Quick Brown Fox|
+---------------------------------+
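The overview also mentions soundex and levenshtein among the string algorithms. Neither is run in this notebook, but a minimal sketch on the same s DataFrame, using functions that exist in pyspark.sql.functions, might look like this (the comparison string is illustrative):
s.select(
    'sents',
    F.soundex('sents').alias('soundex'),
    F.levenshtein('sents', F.lit('The Quick Brown Fox')).alias('edit_distance'),
).show()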
[55]:
from pyspark.sql.functions import log1p, randn
[56]:
df.selectExpr('number', 'log1p(number)', 'letter').show(3)
+------+------------------+------+
|number| LOG1P(number)|letter|
+------+------------------+------+
| 0.0| 0.0| a|
| 1.0|0.6931471805599453| c|
| 2.0|1.0986122886681096| b|
+------+------------------+------+
only showing top 3 rows
[57]:
(
df.selectExpr('number', 'randn() as random').
stat.corr('number', 'random')
)
[57]:
0.09472194952788052
[58]:
dt = (
spark.range(3).
withColumn('today', F.current_date()).
withColumn('tomorrow', F.date_add('today', 1)).
withColumn('time', F.current_timestamp())
)
[59]:
dt.show()
+---+----------+----------+--------------------+
| id| today| tomorrow| time|
+---+----------+----------+--------------------+
| 0|2020-10-28|2020-10-29|2020-10-28 12:44:...|
| 1|2020-10-28|2020-10-29|2020-10-28 12:44:...|
| 2|2020-10-28|2020-10-29|2020-10-28 12:44:...|
+---+----------+----------+--------------------+
[60]:
dt.printSchema()
root
|-- id: long (nullable = false)
|-- today: date (nullable = false)
|-- tomorrow: date (nullable = false)
|-- time: timestamp (nullable = false)
[61]:
%%file data/test_null.csv
number,letter
0,a
1,
2,b
3,a
4,b
5,
6,a
7,a
8,
9,b
10,b
11,c
12,
13,b
14,b
Overwriting data/test_null.csv
[62]:
dn = (
spark.read.
format('csv').
option('header', 'true').
option('inferSchema', 'true').
load('data/test_null.csv')
)
[63]:
dn.printSchema()
root
|-- number: integer (nullable = true)
|-- letter: string (nullable = true)
[64]:
dn.show()
+------+------+
|number|letter|
+------+------+
| 0| a|
| 1| null|
| 2| b|
| 3| a|
| 4| b|
| 5| null|
| 6| a|
| 7| a|
| 8| null|
| 9| b|
| 10| b|
| 11| c|
| 12| null|
| 13| b|
| 14| b|
+------+------+
[65]:
dn.na.drop().show()
+------+------+
|number|letter|
+------+------+
| 0| a|
| 2| b|
| 3| a|
| 4| b|
| 6| a|
| 7| a|
| 9| b|
| 10| b|
| 11| c|
| 13| b|
| 14| b|
+------+------+
[66]:
dn.na.fill('Missing').show()
+------+-------+
|number| letter|
+------+-------+
| 0| a|
| 1|Missing|
| 2| b|
| 3| a|
| 4| b|
| 5|Missing|
| 6| a|
| 7| a|
| 8|Missing|
| 9| b|
| 10| b|
| 11| c|
| 12|Missing|
| 13| b|
| 14| b|
+------+-------+
UDF¶
To avoid degrading performance, avoid UDFs whenever the built-in functions in pyspark.sql.functions can do the job. If you must use a UDF, prefer pandas_udf to udf where possible.
[67]:
from pyspark.sql.functions import udf, pandas_udf
[68]:
@udf('double')
def square(x):
return x**2
[69]:
df.select('number', square('number')).show(3)
+------+--------------+
|number|square(number)|
+------+--------------+
| 0.0| 0.0|
| 1.0| 1.0|
| 2.0| 4.0|
+------+--------------+
only showing top 3 rows
Getting pandas_udf to work can be tricky to set up. I use Oracle Java SDK v11 and set the following environment variables.
export JAVA_HOME=$(/usr/libexec/java_home -v 11)
export JAVA_TOOL_OPTIONS="-Dio.netty.tryReflectionSetAccessible=true"
[70]:
@pandas_udf('double')
def scale(x):
return (x - x.mean())/x.std()
[71]:
df.select(scale('number')).show(3)
+-------------------+
| scale(number)|
+-------------------+
|-1.5652475842498528|
|-1.3416407864998738|
| -1.118033988749895|
+-------------------+
only showing top 3 rows
Grouped agg¶
[72]:
import warnings
warnings.simplefilter('ignore', UserWarning)
[73]:
@pandas_udf('double', F.PandasUDFType.GROUPED_AGG)
def gmean(x):
return x.mean()
[74]:
df.groupby('letter').agg(gmean('number')).show()
+------+-----------------+
|letter| gmean(number)|
+------+-----------------+
| c| 7.25|
| b|8.666666666666666|
| a| 4.8|
+------+-----------------+
Spark 3¶
Use Python type hints rather than specifying the pandas UDF type (see blog).
[75]:
@pandas_udf('double')
def gmean1(x: pd.Series) -> float:
return x.mean()
[76]:
df.groupby('letter').agg(gmean1('number')).show()
+------+-----------------+
|letter| gmean1(number)|
+------+-----------------+
| c| 7.25|
| b|8.666666666666666|
| a| 4.8|
+------+-----------------+
Grouped map¶
[77]:
@pandas_udf(df.schema, F.PandasUDFType.GROUPED_MAP)
def gscale(pdf):
return pdf.assign(
number = (pdf.number - pdf.number.mean()) /
pdf.number.std())
[78]:
df.groupby('letter').apply(gscale).show()
+--------------------+------+
| number|letter|
+--------------------+------+
| -1.2046735616310666| c|
|-0.43368248218718397| c|
| 0.72280413697864| c|
| 0.9155519068396106| c|
| -1.3880858307767148| b|
| -0.9716600815437003| b|
| 0.06940429153883587| b|
| 0.2776171661553431| b|
| 0.9022557900048648| b|
| 1.1104686646213722| b|
| -1.467402817237783| a|
| -0.5502760564641687| a|
| 0.36685070430944583| a|
| 0.6725596245673173| a|
| 0.9782685448251889| a|
+--------------------+------+
applyInPandas¶
This implements the split-apply-combine pattern and is a method of a grouped DataFrame. The applied function has two variants:
- Variant 1: the function takes a single pandas DataFrame as input
- Variant 2: the function takes a tuple of grouping keys and a pandas DataFrame as input
[79]:
def gscale1(pdf: pd.DataFrame) -> pd.DataFrame:
num = pdf.number
return pdf.assign(
number = (num - num.mean()) / num.std())
[80]:
df.groupby('letter').applyInPandas(gscale1, schema=df.schema).show()
+--------------------+------+
| number|letter|
+--------------------+------+
| -1.2046735616310666| c|
|-0.43368248218718397| c|
| 0.72280413697864| c|
| 0.9155519068396106| c|
| -1.3880858307767148| b|
| -0.9716600815437003| b|
| 0.06940429153883587| b|
| 0.2776171661553431| b|
| 0.9022557900048648| b|
| 1.1104686646213722| b|
| -1.467402817237783| a|
| -0.5502760564641687| a|
| 0.36685070430944583| a|
| 0.6725596245673173| a|
| 0.9782685448251889| a|
+--------------------+------+
[81]:
def gsum(key, pdf):
return pd.DataFrame([key + (pdf.number.sum(),)])
[82]:
df.groupby('letter').applyInPandas(gsum, 'letter string, number long').show()
+------+------+
|letter|number|
+------+------+
| c| 29|
| b| 52|
| a| 24|
+------+------+
Of course, you do not need a UDF in this example!
[83]:
df.groupBy('letter').sum().show()
+------+-----------+
|letter|sum(number)|
+------+-----------+
| c| 29.0|
| b| 52.0|
| a| 24.0|
+------+-----------+
So applyInPandas should only be used for truly custom functions.
[84]:
def func(pdf: pd.DataFrame) -> int:
return (pdf.number.astype('str').str.len()).sum()
def gcustom(key, pdf):
return pd.DataFrame([key + (func(pdf),)])
[85]:
df.groupby('letter').applyInPandas(gcustom, 'letter string, number long').show()
+------+------+
|letter|number|
+------+------+
| c| 14|
| b| 21|
| a| 15|
+------+------+
mapInPandas¶
This works on an iterator of pandas DataFrames and is a method of the DataFrame itself. It can be used to implement a filter.
[86]:
def even(it):
for pdf in it:
yield pdf[pdf.number % 2 == 0]
[87]:
df.mapInPandas(even, 'letter string, number long').show()
+------+------+
|letter|number|
+------+------+
| a| 0|
| b| 2|
| b| 4|
| a| 6|
| a| 8|
| b| 10|
| c| 12|
| b| 14|
+------+------+
Joins¶
[88]:
names = 'ann ann bob bob chcuk'.split()
courses = '821 823 821 824 823'.split()
pdf1 = pd.DataFrame(dict(name=names, course=courses))
[89]:
pdf1
[89]:
    name course
0    ann    821
1    ann    823
2    bob    821
3    bob    824
4  chcuk    823
[90]:
course_id = '821 822 823 824 825'.split()
course_names = 'Unix Python R Spark GLM'.split()
pdf2 = pd.DataFrame(dict(course_id=course_id, name=course_names))
[91]:
pdf2
[91]:
  course_id    name
0       821    Unix
1       822  Python
2       823       R
3       824   Spark
4       825     GLM
[92]:
df1 = spark.createDataFrame(pdf1)
df2 = spark.createDataFrame(pdf2)
[93]:
df1.join(df2, df1.course == df2.course_id, how='inner').show()
+-----+------+---------+-----+
| name|course|course_id| name|
+-----+------+---------+-----+
| ann| 823| 823| R|
|chcuk| 823| 823| R|
| bob| 824| 824|Spark|
| ann| 821| 821| Unix|
| bob| 821| 821| Unix|
+-----+------+---------+-----+
[94]:
df1.join(df2, df1.course == df2.course_id, how='right').show()
+-----+------+---------+------+
| name|course|course_id| name|
+-----+------+---------+------+
| ann| 823| 823| R|
|chcuk| 823| 823| R|
| bob| 824| 824| Spark|
| null| null| 825| GLM|
| null| null| 822|Python|
| ann| 821| 821| Unix|
| bob| 821| 821| Unix|
+-----+------+---------+------+
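The other multi-DataFrame operations listed in the overview (union, intersect, subtract) require both DataFrames to have the same schema. They are not run here, but a minimal sketch using df1 might be:
top2 = df1.limit(2)
df1.union(top2).show()      # concatenate rows; duplicates are kept
df1.intersect(top2).show()  # rows present in both DataFrames
df1.subtract(top2).show()   # rows in df1 that are not in top2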
DataFrame conversions¶
[95]:
sc = spark.sparkContext
[96]:
rdd = sc.parallelize([('ann', 23), ('bob', 34)])
[97]:
df = spark.createDataFrame(rdd, schema='name STRING, age INT')
[98]:
df.show()
+----+---+
|name|age|
+----+---+
| ann| 23|
| bob| 34|
+----+---+
[99]:
df.rdd.map(lambda x: (x[0], x[1]**2)).collect()
[99]:
[('ann', 529), ('bob', 1156)]
[100]:
df.rdd.mapValues(lambda x: x**2).collect()
[100]:
[('ann', 529), ('bob', 1156)]
[101]:
df.toPandas()
[101]:
  name  age
0  ann   23
1  bob   34
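Persistence and output¶
The persistence and output methods listed in the overview are not exercised above. The following is a minimal sketch, not run here; it assumes the spark session and a DataFrame df from earlier in the notebook, and the data/out paths and the choice of storage level are illustrative assumptions.
from pyspark import StorageLevel

df.cache()                                  # shorthand for persist() with the default storage level
df.unpersist()                              # release the cached data
df.persist(StorageLevel.MEMORY_AND_DISK)    # persist() lets you pick a storage level explicitly
df.unpersist()

spark.catalog.cacheTable('df_table')        # cache a registered view/table by name
spark.catalog.clearCache()                  # drop all cached tables and views

df4 = df.repartition(4)                     # returns a new DataFrame with 4 partitions (full shuffle)
df1p = df4.coalesce(1)                      # reduce the partition count without a full shuffle

# write out and read back (paths are illustrative)
df.write.mode('overwrite').parquet('data/out/test_parquet')
df.write.mode('overwrite').csv('data/out/test_csv', header=True)
df.write.mode('overwrite').json('data/out/test_json')
spark.read.parquet('data/out/test_parquet').show(3)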