Apache Spark for Data Science - User-Defined Functions (UDF) Explained

Do you find Python easier than SQL? User-Defined Functions in PySpark might be what you're looking for.

Data scientists aren't necessarily the best SQL users. Maybe you're proficient in Python, but you don't know how to translate that knowledge into SQL. That shouldn't stop you from leveraging everything Spark and PySpark have to offer.

With User-Defined Functions (UDFs), you can write functions in Python and use them when writing Spark SQL queries. Today I'll show you how to declare and register 5 Python functions and use them to clean and reformat the well-known Titanic dataset. Toward the end of the article, you'll also learn how to filter records based on UDF results.


Dataset Used and Spark Session Initialization

To keep things extra simple, we'll use the Titanic dataset. Download it from this URL and store it somewhere you'll remember:

Image 1 - Titanic dataset (image by author)

The dataset has many more features than, say, the Iris dataset. These will come in handy later. But first, let's see how to load it with Spark.

If you're working with PySpark in a notebook environment, always run this code snippet for better output formatting. Otherwise, DataFrames with too many columns to fit on the screen at once tend to wrap and become hard to read:

from IPython.display import HTML, display

# Keep wide DataFrame output on one line per row instead of wrapping it
display(HTML("<style>pre { white-space: pre !important; }</style>"))

When that's out of the way, we can initialize a new Spark session:

from pyspark.sql import SparkSession


spark = SparkSession.builder.appName("spark-sql").getOrCreate()

To read a CSV file, pass its path to the csv() function of spark.read. The inferSchema and header parameters aren't strictly required, but you'll want both here. Without inferSchema, Spark reads every column as a string. Without header, it treats the header row as actual data:

titanic = spark.read.csv(
    path="file://<dataset-path>", 
    inferSchema=True, 
    header=True
)
titanic.show(10)
Image 2 - First 10 rows of the Titanic dataset (image by author)

With the data loaded, let's declare our first User-Defined Function.


How to Declare User-Defined Functions (UDFs) in Spark and Python

You can declare a User-Defined Function just like any other Python function. The trick comes later when you register a Python function with Spark, but more on that in a bit.

Our first function will extract the title from the passenger name. In every name, the title is the first word after the comma, so it's easy to extract by chaining two split() calls:

def extract_title(name):
    return name.split(', ')[-1].split()[0]

extract_title("Braund, Mr. Owen Harris")

>>> 'Mr.'
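One thing to watch for: Spark passes SQL NULL values to Python UDFs as None. Calling split() on None raises an AttributeError, and that fails the whole query. The Name column may well have no missing values in this dataset, but a guard costs little. Here's a minimal null-safe variant of the same function:

```python
def extract_title(name):
    # Spark hands SQL NULLs to Python UDFs as None; return None instead of raising
    if name is None:
        return None
    # "Braund, Mr. Owen Harris" -> "Mr. Owen Harris" -> "Mr."
    return name.split(", ")[-1].split()[0]

print(extract_title("Braund, Mr. Owen Harris"))  # Mr.
print(extract_title(None))                       # None
```

Returning None from a UDF shows up as NULL in the Spark result, so missing names stay missing instead of crashing the job.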

The title alone isn't a machine-learning-friendly attribute. We want a new binary attribute with a value of 1 if the passenger's title is common (Miss, Mr, Mrs, or Master) and 0 otherwise:

def is_common_title(title):
    return 1 if title in ["Miss.", "Mr.", "Mrs.", "Master."] else 0

is_common_title("Dr.")

>>> 0
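Because both are plain Python functions, you can test the chaining you'll later do in SQL before Spark is involved at all. Here's a small sketch using a few names. "Doe, Dr. Jane" and "Roe, Rev. Richard" are made up for illustration. It also swaps the list for a set, which gives the same membership result:

```python
# Same titles as above; a set and a list give the same result for `in`
COMMON_TITLES = {"Miss.", "Mr.", "Mrs.", "Master."}

def extract_title(name):
    return name.split(", ")[-1].split()[0]

def is_common_title(title):
    return 1 if title in COMMON_TITLES else 0

# The last two names are hypothetical examples
names = ["Braund, Mr. Owen Harris", "Doe, Dr. Jane", "Roe, Rev. Richard"]

# Nesting the calls mirrors is_common_title(extract_title(Name)) in SQL
flags = [is_common_title(extract_title(n)) for n in names]
print(flags)  # [1, 0, 0]
```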

The Sex column is another attribute that needs some transformation. The remap_gender() function returns 1 if Sex is "male", and 0 otherwise:

def remap_gender(gender):
    return 1 if gender == "male" else 0

remap_gender("female")

>>> 0

The Titanic dataset is full of missing values, and the Cabin attribute is no exception. Not all passengers had a dedicated cabin, or their cabin number wasn't known. We'll declare a function that returns 1 if a passenger had a cabin and 0 otherwise:

def has_cabin(cabin):
    return 1 if cabin else 0

has_cabin("")

>>> 0
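The function relies on Python truthiness, so it covers two ways a cabin can look missing. Spark hands a SQL NULL to Python as None, and both None and an empty string are falsy. A quick check:

```python
def has_cabin(cabin):
    # None (SQL NULL) and "" are both falsy, so both map to 0
    return 1 if cabin else 0

results = {value: has_cabin(value) for value in [None, "", "C85"]}
print(results)  # {None: 0, '': 0, 'C85': 1}
```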

Finally, we'll write a function for imputing passenger age. Unlike the previous functions, this one accepts two parameters:

  • age - Incoming value from the dataset, either a number or null (Spark passes nulls to Python as None).
  • value - Value we'll replace the age with.

The logic is simple: if age is missing, return value; otherwise, return age:

def replace_missing_age(age, value):
    # Check for None explicitly so a legitimate age of 0 isn't treated as missing
    return age if age is not None else value

replace_missing_age(15, -1)

>>> 15
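There's a subtle typing issue here. With inferSchema, Age is read as a double, so the function returns a float for known ages but an int (-1) for missing ones. That's harmless with the default string return type. If you register the UDF with a numeric type such as DoubleType, though, regular Python UDFs can return null for values whose Python type doesn't match the declared type. Here's a minimal sketch that returns floats throughout, assuming you want a DoubleType result:

```python
def replace_missing_age(age, value):
    # `is not None` replaces only true nulls; a plain truthiness check
    # would also replace a legitimate age of 0.0.
    # float() gives every return value the same type, matching DoubleType.
    return float(age) if age is not None else float(value)

print(replace_missing_age(22.0, -1))  # 22.0
print(replace_missing_age(None, -1))  # -1.0
print(replace_missing_age(0.0, -1))   # 0.0
```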

That's all for the function declarations, and now it's time to use them in Spark. To do so, you'll first have to register them through the spark.udf.register() function. For a Python function, it accepts these parameters:

  • name - A string, the function name you'll use in SQL queries.
  • f - A Python function that contains the programming logic.
  • returnType - Optional. The Spark SQL type of the result. If you leave it out, Spark treats the result as a string (StringType).

It's common practice to register each function under its Python name, just to avoid confusion later:

spark.udf.register("extract_title", extract_title)
spark.udf.register("is_common_title", is_common_title)
spark.udf.register("remap_gender", remap_gender)
spark.udf.register("has_cabin", has_cabin)
spark.udf.register("replace_missing_age", replace_missing_age)

And that's it - you're now ready to use these Python User-Defined Functions in Spark SQL!


How to Use User-Defined Functions in Spark SQL

If you've followed my Introduction to SQL in Spark article, you know that you'll first have to create a temporary view on a DataFrame. That's done by calling the createOrReplaceTempView() method and passing in a view name (any name will do):

titanic.createOrReplaceTempView("titanic")

Once created, pass any SQL statement in a call to spark.sql(). The example below uses all five of our User-Defined Functions:

spark.sql("""
    SELECT
        Survived,
        Pclass,
        extract_title(Name) AS Title,
        is_common_title(extract_title(Name)) AS IsCommonTitle,
        remap_gender(Sex) as IsMale,
        replace_missing_age(Age, -1) as Age,
        SibSp,
        Parch,
        Fare,
        has_cabin(Cabin) as HasCabin,
        Embarked
    FROM titanic
""").show(10)
Image 3 - Using User-Defined Functions in PySpark (image by author)

As you can see, our functions work like a charm! I've set replace_missing_age() to replace any missing value with -1, but you're free to change that. You can also chain functions together, as shown with the is_common_title(extract_title(Name)) call.

But there's a catch. Column aliases defined in the SELECT list, such as IsCommonTitle, aren't visible to the WHERE clause of the same query, because SQL logically evaluates WHERE before SELECT. That means you can't filter records on these aliases directly.

What you can do instead is to wrap the entire query inside another one (subquery) and then do any SQL magic you want:

spark.sql("""
    SELECT * FROM (
        SELECT
            Survived,
            Pclass,
            extract_title(Name) AS Title,
            is_common_title(extract_title(Name)) AS IsCommonTitle,
            remap_gender(Sex) as IsMale,
            replace_missing_age(Age, -1) as Age,
            SibSp,
            Parch,
            Fare,
            has_cabin(Cabin) as HasCabin,
            Embarked
        FROM titanic
    ) WHERE IsCommonTitle = 0
""").show(10)
Image 4 - User-Defined Functions with additional filters (image by author)

Now only records with uncommon passenger titles are shown. They're mostly older males, which is expected.


Summary of User-Defined Functions in PySpark

Today you've learned how to work with User-Defined Functions (UDFs) in Python and Spark. This is a huge milestone if you use Python daily and aren't the biggest fan of SQL. Sure, UDFs aren't a replacement for solid SQL knowledge, but nothing stops you from using them.

Some things are easier to do in Python, and UDFs make sure you don't waste time figuring out SQL commands when time is of the essence. Just remember the steps needed to use UDFs:

  1. Declare a Python function
  2. Register Python function with Spark
  3. Use the function in Spark SQL statements

It's that easy! Up next, you'll learn how to apply a machine learning algorithm to this dataset, so stay tuned.
