You find Python easier than SQL? User-Defined Functions in PySpark might be what you’re looking for
Data scientists aren’t necessarily the best SQL users. Maybe you’re proficient in Python, but you don’t know how to translate that knowledge into SQL. That shouldn’t stop you from leveraging everything Spark and PySpark have to offer.
With User-Defined Functions (UDFs), you can write functions in Python and use them when writing Spark SQL queries. Today I’ll show you how to declare and register 5 Python functions and use them to clean and reformat the well-known Titanic dataset. Towards the end of the article, you’ll also learn how to filter records based on the results of UDFs.
Don’t feel like reading? Watch my video instead:
Dataset Used and Spark Session Initialization
To keep things extra simple, we’ll use the Titanic dataset. Download it from this URL and store it somewhere you’ll remember:
The dataset packs many more features than, say, the Iris dataset. These will come in handy later. But first, let’s see how to load it with Spark.
If you’re working with PySpark in a notebook environment, always use this code snippet for better output formatting. Otherwise, wide DataFrames are likely to wrap whenever there are too many columns to fit on the screen at once:
from IPython.core.display import HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))
When that’s out of the way, we can initialize a new Spark session:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("spark-sql").getOrCreate()
To read a CSV file, simply pass the dataset path to the csv() function of the read module. The inferSchema and header parameters aren’t strictly mandatory, but you’ll want them practically every time you read a CSV file. Without them, Spark casts every column to string and treats the header row as actual data:
titanic = spark.read.csv(
path="file://<dataset-path>",
inferSchema=True,
header=True
)
titanic.show(10)
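A quick way to double-check that inferSchema did its job is to print the schema:

titanic.printSchema()

You should see numeric types for Age and Fare, and strings for columns like Name, Sex, Cabin, and Embarked.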
And with that out of the way, let’s declare our first User-Defined Function.
How to Declare User-Defined Functions (UDFs) in Spark and Python
You can declare a User-Defined Function just like any other Python function. The trick comes later when you register a Python function with Spark, but more on that in a bit.
Our first function will extract the title from the passenger name. The title starts right after the first comma and spans a single word, so it’s easy to extract by chaining two split() calls:
def extract_title(name):
return name.split(', ')[-1].split()[0]
extract_title("Braund, Mr. Owen Harris")
>>> 'Mr.'
The title alone isn’t a machine-learning-friendly attribute. We want a new, binary attribute with a value of 1 if a person’s title is common - Miss, Mr, Mrs, and Master - and 0 otherwise:
def is_common_title(title):
return 1 if title in ["Miss.", "Mr.", "Mrs.", "Master."] else 0
is_common_title("Dr.")
>>> 0
The Sex column is another attribute that needs some transformation. The remap_gender() function returns 1 if Sex is “male”, and 0 otherwise:
def remap_gender(gender):
return 1 if gender == "male" else 0
remap_gender("female")
>>> 0
The Titanic dataset is full of missing values, and the Cabin attribute is no exception. The thing is, not all passengers had a dedicated cabin, or their cabin number simply wasn’t recorded. We’ll declare a function that returns 1 if a passenger had a cabin and 0 otherwise:
def has_cabin(cabin):
return 1 if cabin else 0
has_cabin("")
>>> 0
Finally, we’ll write a function for imputing passenger age. Unlike the previous functions, this one accepts two parameters:
- age - the incoming value from the dataset; it can be a number or null.
- value - the value we’ll replace a missing age with.
The logic is simple - if age is missing, return value, otherwise return age:
def replace_missing_age(age, value):
return age if age else value
replace_missing_age(15, -1)
>>> 15
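And with a missing value - Python’s None, which is how nulls reach a Python UDF - it falls back to the replacement:

replace_missing_age(None, -1)
>>> -1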
That’s all from the function declaration end, and now it’s time to use them in Spark. To do so, you’ll first have to register them through the spark.udf.register() function. It accepts two parameters:
- name - a string, the function name you’ll use in SQL queries.
- f - the Python function that contains the programming logic.
It’s a common practice to give identical values to both parameters, just to avoid confusion later:
spark.udf.register("extract_title", extract_title)
spark.udf.register("is_common_title", is_common_title)
spark.udf.register("remap_gender", remap_gender)
spark.udf.register("has_cabin", has_cabin)
spark.udf.register("replace_missing_age", replace_missing_age)
And that’s it - you’re now ready to use these Python User-Defined Functions in Spark SQL!
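One detail worth knowing: when you register a plain Python function this way, Spark assumes a string return type by default, so the derived columns come back as strings. If you’d rather keep them numeric, you can pass an explicit return type as the optional third argument - a minimal sketch for remap_gender(), and the same idea applies to the other functions:

from pyspark.sql.types import IntegerType

spark.udf.register("remap_gender", remap_gender, IntegerType())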
How to Use User-Defined Functions in Spark SQL
If you’ve followed my Introduction to SQL in Spark article, you know that you’ll first have to create a temporary view on a DataFrame. That’s done by calling the createOrReplaceTempView() function and passing in an arbitrary view name:
titanic.createOrReplaceTempView("titanic")
Once created, pass any SQL statement in a call to spark.sql(). The example below uses all five of our User-Defined Functions:
spark.sql("""
SELECT
Survived,
Pclass,
extract_title(Name) AS Title,
is_common_title(extract_title(Name)) AS IsCommonTitle,
remap_gender(Sex) as IsMale,
replace_missing_age(Age, -1) as Age,
SibSp,
Parch,
Fare,
has_cabin(Cabin) as HasCabin,
Embarked
FROM titanic
""").show(10)
As you can see, our functions work like a charm! I’ve set replace_missing_age() to replace any missing value with -1, but you’re free to change that. You can also chain functions together, as shown with the is_common_title(extract_title(Name)) call.
But you know what the problem is? The WHERE clause is evaluated before the SELECT aliases are applied, so you can’t reference a UDF-derived column like IsCommonTitle directly in WHERE to filter out records.
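One quick workaround is to repeat the UDF call inside the WHERE clause instead of referencing the alias - a minimal sketch, trimmed to a couple of columns for brevity:

spark.sql("""
    SELECT
        extract_title(Name) AS Title,
        remap_gender(Sex) AS IsMale
    FROM titanic
    WHERE is_common_title(extract_title(Name)) = 0
""").show(10)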
A cleaner option is to wrap the entire query inside another one (a subquery) and then do any SQL magic you want in the outer query:
spark.sql("""
SELECT * FROM (
SELECT
Survived,
Pclass,
extract_title(Name) AS Title,
is_common_title(extract_title(Name)) AS IsCommonTitle,
remap_gender(Sex) as IsMale,
replace_missing_age(Age, -1) as Age,
SibSp,
Parch,
Fare,
has_cabin(Cabin) as HasCabin,
Embarked
FROM titanic
) WHERE IsCommonTitle = 0
""").show(10)
Now you can see that only records with an uncommon passenger title show up. It’s mostly older males, which is expected.
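As a side note, registered UDFs aren’t limited to spark.sql() calls - you can also reach them from the DataFrame API through the expr() function. Here’s a minimal sketch of the same filter:

from pyspark.sql.functions import expr

titanic.select(
    expr("extract_title(Name)").alias("Title"),
    expr("replace_missing_age(Age, -1)").alias("Age"),
    "Sex"
).where(expr("is_common_title(extract_title(Name)) = 0")).show(10)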
Summary of User-Defined Functions in PySpark
Today you’ve learned how to work with User-Defined Functions (UDFs) in Python and Spark. This is a huge milestone if you’re using Python daily and aren’t the biggest fan of SQL. Sure, they’re not a replacement for adequate SQL knowledge, but no one can stop you from using them.
Some things are easier to do in Python, and UDFs make sure you don’t waste time figuring out SQL commands when time is of the essence. Just remember the steps needed to use UDFs:
- Declare a Python function
- Register the Python function with Spark
- Use the function in Spark SQL statements
It’s that easy! Up next, you’ll learn how to apply a machine learning algorithm to this dataset, so stay tuned.