Mocking nested calls of a Spark instance (or dataframe)


Python gives you nice tools to test your code, most prominently pytest and unittest.mock. Their documentation covers everything you need to write Python tests and mocks. Still, you might sometimes wonder: how should I test this Spark function? Here you might find an answer!

First, get acquainted with unittest.mock.

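PySpark’s SparkSession and DataFrame are built for chaining invocations: the reader you get from spark.read and the writer you get from df.write return themselves from methods such as format() and option() (and mode() on the writer), and DataFrame transformations return another DataFrame. To emulate this behaviour, you can create a mock and let these methods return the mock itself instead of a new mock (returning a fresh child mock on every call is the default behaviour of unittest.mock).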
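A minimal sketch of the two behaviours, using only unittest.mock so it runs without Spark (the names default and chained are just for illustration):

```python
from unittest.mock import Mock

# Default behaviour: a call returns an automatically created child mock.
default = Mock()
first = default.option("sep", ";")
assert first is default.option.return_value
assert first is not default  # the chain moves on to a different mock

# Self-returning behaviour: option() hands back the mock it was called on,
# so the chain stays on one object, like a builder method that returns self.
chained = Mock()
chained.option.return_value = chained
result = chained.option("sep", ";").option("header", "true")
assert result is chained
assert chained.option.call_count == 2
```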

Let’s assume you have a function that takes a spark session as input and loads a CSV file:

def load_csv(spark, file):
    return (
        spark.read.format("csv")
        .option("sep", ";")
        .option("inferSchema", "true")
        .option("header", "true")
        .load(file)
        # ... probably more transformations and filters
    )

Your goal is to test whether load_csv() calls load() on the reader it gets from spark.read. I admit that this example is quite artificial, but bear with me.

Side note

The need to create mocks like this is often caused by functions that mix data loading/ingestion with transformation/filtering (aka domain logic). In load_csv() it is indicated by »... probably more transformations and filters«.

I try to avoid this situation as much as possible, but when working with an existing codebase, you might not be able to. Then you are in a chicken-and-egg situation: to refactor, you need tests, but to write a good test, you need to refactor. This approach provides a shortcut: it lets you build a test that serves as a basis for refactoring later.
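Once such a test exists, one possible refactoring target is to keep reading and domain logic in separate functions. A sketch under assumptions: read_people(), keep_adults() and load_adults() are hypothetical names, and the filter uses a SQL expression string, which DataFrame.filter() accepts. Mocks stand in for the Spark session and the DataFrame so the snippet runs without Spark; in a real test you would check keep_adults() against a sample DataFrame instead.

```python
from unittest.mock import Mock


def read_people(spark, file):
    # hypothetical: I/O only, build the reader chain and load the file
    return (
        spark.read.format("csv")
        .option("sep", ";")
        .option("header", "true")
        .load(file)
    )


def keep_adults(df):
    # hypothetical: domain logic only, independent of where df came from
    return df.filter("age >= 18")


def load_adults(spark, file):
    # thin composition of the two steps
    return keep_adults(read_people(spark, file))


# the domain logic only needs something that behaves like a DataFrame
df = Mock()
keep_adults(df)
df.filter.assert_called_once_with("age >= 18")

# the loader only needs the same self-returning wiring as the spark_mock fixture
spark = Mock()
type(spark).read = spark
spark.format.return_value = spark
spark.option.return_value = spark
out = load_adults(spark, "people.csv")
spark.load.assert_called_once_with("people.csv")
```

The loader test stays as thin as the mock-based test in this article, while the interesting logic moves into a function that needs no mocking at all.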

The mock spark session looks like this:

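import pytest
from unittest.mock import Mock

@pytest.fixture
def spark_mock():
    spark_mock = Mock()
    # read and write are properties, not methods: accessing them returns the mock itself
    # (each Mock instance has its own class, so this does not affect other mocks)
    type(spark_mock).write = spark_mock
    type(spark_mock).read = spark_mock
    # chainable methods return the mock itself instead of a new child mock
    spark_mock.table.return_value = spark_mock
    spark_mock.format.return_value = spark_mock
    spark_mock.option.return_value = spark_mock
    spark_mock.mode.return_value = spark_mock
    return spark_mock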
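The fixture also makes write and mode() return the mock, so the writer side of the API can be checked the same way. A minimal sketch with a hypothetical save_parquet() function; the mock is built inline, mirroring the fixture, so the snippet runs without pytest or Spark:

```python
from unittest.mock import Mock


def save_parquet(df, path):
    # hypothetical function under test: writes a DataFrame as Parquet
    df.write.format("parquet").mode("overwrite").save(path)


# same wiring as the spark_mock fixture, built inline
df_mock = Mock()
type(df_mock).write = df_mock
df_mock.format.return_value = df_mock
df_mock.mode.return_value = df_mock

save_parquet(df_mock, "out/people.parquet")

df_mock.format.assert_called_once_with("parquet")
df_mock.mode.assert_called_once_with("overwrite")
df_mock.save.assert_called_once_with("out/people.parquet")
```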

Now I can feed the mock to the load_csv() function:

import pytest
import pyspark.sql.functions as F
from py3_spielwiese.spark import load_csv
from pyspark.sql import Row
from pyspark.sql.types import StringType, StructField, StructType, IntegerType, DoubleType

@pytest.fixture
def sample_df(spark):
    data = [("Walter", 32, "Germany", 10000.0)]

    # for simple use cases you can also omit the schema
    # return spark.createDataFrame(data, ["name", "age", "country", "salary"])

    schema = StructType([
        StructField("name", StringType(), False),
        StructField("age", IntegerType(), False),
        StructField("country", StringType(), False),
        StructField("salary", DoubleType(), False),
    ])
    return spark.createDataFrame(data, schema)

def test_mock(spark_mock, sample_df):
    # return a sample df when spark.load is called
    spark_mock.load.return_value = sample_df

    # here we supply the mock to load_csv
    df = load_csv(spark_mock, "resources/people.csv")

    assert spark_mock.format.called
    spark_mock.option.assert_any_call("sep", ";")

    # here we check if load was indeed called with the supplied file
    spark_mock.load.assert_called_once_with("resources/people.csv")

    row = df.withColumn("salary", 1.5 * F.col("salary")).head()
    assert row == Row(name="Walter", age=32, country="Germany", salary=15000.0)

Why is this handier than the default behaviour of Mock()? The assertions are closer to what you expect from a Spark object. With a default Mock, every call returns a new child mock, so you would have to know the exact order in which the chained methods are called and walk spark_mock.read.format.return_value.option.return_value... to reach the call you want to check. That makes the test (even) more fragile: the order of the option() invocations can easily change and break the test.
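To make the difference concrete, a minimal sketch with unittest.mock only: read_csv_chain() is a local copy of the reader chain in load_csv(), so the snippet is self-contained. With the default Mock, checking the header option means walking the exact chain of return_value attributes; with the self-returning mock, assert_has_calls(..., any_order=True) finds the options regardless of their order.

```python
from unittest.mock import Mock, call


def read_csv_chain(spark, file):
    # local copy of the reader chain in load_csv()
    return (
        spark.read.format("csv")
        .option("sep", ";")
        .option("inferSchema", "true")
        .option("header", "true")
        .load(file)
    )


# Default Mock: each option() call lives on a different child mock,
# so the assertion has to mirror the exact order of the chain.
plain = Mock()
read_csv_chain(plain, "people.csv")
third_option = (
    plain.read.format.return_value
    .option.return_value  # result of .option("sep", ...)
    .option.return_value  # result of .option("inferSchema", ...)
    .option
)
third_option.assert_called_once_with("header", "true")

# Self-returning mock: all option() calls are recorded on the same method,
# so their order does not matter for the assertion.
chained = Mock()
type(chained).read = chained
chained.format.return_value = chained
chained.option.return_value = chained
read_csv_chain(chained, "people.csv")
chained.option.assert_has_calls(
    [call("header", "true"), call("sep", ";")], any_order=True
)
chained.load.assert_called_once_with("people.csv")
```

Swapping the order of the option() calls in read_csv_chain() breaks the first assertion but not the second.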

Additionally, you can set spark_mock.load.return_value directly to return your own sample DataFrame!
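If the code under test loads more than one file, a side_effect function can hand out a different sample per path. A sketch under assumptions: the strings below are placeholders for the sample DataFrames you would build with the spark fixture, and the paths are made up.

```python
from unittest.mock import Mock

# same wiring as the spark_mock fixture, built inline
spark_mock = Mock()
type(spark_mock).read = spark_mock
spark_mock.format.return_value = spark_mock
spark_mock.option.return_value = spark_mock

# placeholders for real sample DataFrames, keyed by the path that is loaded
samples = {
    "resources/people.csv": "people sample df",
    "resources/orders.csv": "orders sample df",
}


def load_sample(path):
    # called with the same arguments as load(); its result becomes the return value.
    # Unknown paths raise KeyError, which makes unexpected reads fail the test.
    return samples[path]


spark_mock.load.side_effect = load_sample

people = (
    spark_mock.read.format("csv")
    .option("header", "true")
    .load("resources/people.csv")
)
orders = spark_mock.read.format("csv").load("resources/orders.csv")
```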

I set up the Spark instance like this:

import pytest
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark_context(request):
    conf = SparkConf().setMaster("local[2]").setAppName("pytest-pyspark-local-testing")
    sc = SparkContext(conf=conf)
    yield sc

@pytest.fixture(scope="session")
def spark(request):
    # reuse a session that is already running
    spark = SparkSession.getActiveSession()
    if spark is not None:
        yield spark
        return

    conf = SparkConf().setMaster("local[2]").setAppName("pytest-pyspark2.+-local-testing")
    # pulls in hadoop-aws (S3A filesystem support)
    conf.set("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.4")
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    yield spark

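The spark mock and the Spark context/session fixtures reside in my shared fixtures file (with pytest, fixtures shared across test modules typically go into conftest.py).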