Testing in Python
Python gives you nice tools to test your code. The most prominent are pytest and unittest.mock. Their documentation covers everything you need to write Python tests and mocks. Still, sometimes you might wonder: how should I test this function or framework? Here you might find an answer!
Mocking nested calls of a Spark instance
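First get acquainted with Mock from unittest.mock. SparkSession and DataFrame have many functions that return an object of the same kind, so that you can chain invocations. To emulate this behaviour, you can create a mock and let its methods return that same mock, standing in for the SparkSession/DataFrame, instead of a new mock (which is the default behaviour).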
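To make the difference concrete, here is a minimal sketch that needs only unittest.mock. The names default_mock and chain_mock are made up for illustration; the chained calls mimic a Spark reader but no Spark is involved.

```python
from unittest.mock import Mock

# Default behaviour: every call on a Mock returns a new child mock,
# so the end of a chain is a different object than the start.
default_mock = Mock()
default_result = default_mock.format("csv").option("sep", ";")
default_is_same = default_result is default_mock

# Chainable behaviour: the methods return the mock itself,
# so the whole chain stays on one object.
chain_mock = Mock()
chain_mock.format.return_value = chain_mock
chain_mock.option.return_value = chain_mock
chained_result = chain_mock.format("csv").option("sep", ";")
chained_is_same = chained_result is chain_mock

# All calls were recorded on the one mock and can be checked there.
chain_mock.format.assert_called_once_with("csv")
chain_mock.option.assert_called_once_with("sep", ";")
```

With the second variant, every assertion is made on the one mock you created, no matter how long the chain gets.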
Let’s assume you have a function that takes a spark session as input and loads a CSV file:
def load_csv(spark, file):
    return (
        spark.read.format("csv")
        .option("sep", ";")
        .option("inferSchema", "true")
        .option("header", "true")
        .load(file)
        # ... probably more transformations and filters
    )
Your goal is to test whether the load_csv() function calls spark.load() with the given file. I admit that this example is quite artificial, but bear with me.
The need to create mocks like this is often caused by functions that mix data loading/ingestion with transformation/filtering (aka domain logic). In load_csv() this is indicated by the comment »... probably more transformations and filters«.
I try to avoid this situation as much as possible, but when working with existing codebases you might not be able to. Then you are in a chicken-and-egg situation: to refactor safely you need tests, but to write a good test you need to refactor first. This approach provides a shortcut: it lets you build a test first, as a basis for refactoring later.
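For orientation, this is roughly what the later refactoring could look like: one function that only loads, one that only holds domain logic. It is a sketch under assumptions, not the actual code: the names read_csv and transform and the filter condition are invented for illustration.

```python
from unittest.mock import Mock


def read_csv(spark, file):
    """Data loading only: the part that needs a (mocked) Spark session."""
    return (
        spark.read.format("csv")
        .option("sep", ";")
        .option("inferSchema", "true")
        .option("header", "true")
        .load(file)
    )


def transform(df):
    """Domain logic only. The filter is a hypothetical placeholder."""
    return df.filter("age >= 18")


def load_csv(spark, file):
    # The original entry point just wires the two parts together.
    return transform(read_csv(spark, file))


# Quick check without Spark: a plain Mock stands in for the DataFrame.
df = Mock()
result = transform(df)
df.filter.assert_called_once_with("age >= 18")
```

Once the code is split like this, transform() can be tested with a small real DataFrame, and only read_csv() still needs a mocked session.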
The mock spark session looks like this:
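import pytest
from unittest.mock import Mock


@pytest.fixture
def spark_mock():
    spark_mock = Mock()
    # read and write are plain attributes, not calls: set them on the mock's
    # type so that they resolve to the same mock
    type(spark_mock).write = spark_mock
    type(spark_mock).read = spark_mock
    # chainable methods return the same mock instead of a new child mock
    spark_mock.table.return_value = spark_mock
    spark_mock.format.return_value = spark_mock
    spark_mock.option.return_value = spark_mock
    spark_mock.mode.return_value = spark_mock
    # save() ends a write chain and returns nothing
    spark_mock.save.return_value = None
    return spark_mock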
Now I can feed the mock to the load_csv() function:
import pytest
import pyspark.sql.functions as F
from py3_spielwiese.spark import load_csv
from pyspark.sql import Row
from pyspark.sql.types import StringType, StructField, StructType, IntegerType, DoubleType


@pytest.fixture
def sample_df(spark):
    data = [("Walter", 32, "Germany", 10000.0)]
    # for simple use cases you can also omit the schema
    # return spark.createDataFrame(data, ["name", "age", "country", "salary"])
    schema = StructType(
        [
            StructField("name", StringType(), False),
            StructField("age", IntegerType(), False),
            StructField("country", StringType(), False),
            StructField("salary", DoubleType(), False),
        ]
    )
    return spark.createDataFrame(data, schema)


def test_mock(spark_mock, sample_df):
    # return a sample df when spark.load is called
    spark_mock.load.return_value = sample_df

    # here we supply the mock to load_csv
    df = load_csv(spark_mock, "resources/people.csv")

    assert spark_mock.format.called
    assert not spark_mock.save.called
    spark_mock.option.assert_any_call("sep", ";")
    # here we check if load was indeed called with the supplied file
    spark_mock.load.assert_called_with("resources/people.csv")

    row = df.withColumn("salary", 1.5 * F.col("salary")).head()
    assert row == Row(name="Walter", age=32, country="Germany", salary=15000.0)
Why is this handier than the default behaviour of the Mock() implementation? The assertions are closer to what you expect from a spark object: you make them directly on the one mock. With a default Mock() every call returns a new child mock, so you have to know the exact order in which the nested functions are called, which makes the test (even) more fragile. Fragile, because the order of option() invocations can easily change and break the test.
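The following sketch contrasts the two kinds of assertions. It repeats the reader chain from load_csv() without further transformations, uses only unittest.mock, and the variable names plain and flat are made up for illustration.

```python
from unittest.mock import Mock


def load_csv(spark, file):
    # Same reader chain as in the example above, nothing else.
    return (
        spark.read.format("csv")
        .option("sep", ";")
        .option("inferSchema", "true")
        .option("header", "true")
        .load(file)
    )


# Default Mock(): the assertion has to walk the chain exactly as the
# function did, one return_value per call.
plain = Mock()
load_csv(plain, "people.csv")
(
    plain.read.format.return_value
    .option.return_value
    .option.return_value
    .option.return_value
    .load.assert_called_with("people.csv")
)
# Each option() call sits at its own position in the chain.
plain.read.format.return_value.option.assert_called_with("sep", ";")

# Self-returning mock: all calls are recorded on the same object,
# regardless of their position in the chain.
flat = Mock()
type(flat).read = flat
flat.format.return_value = flat
flat.option.return_value = flat
load_csv(flat, "people.csv")
flat.load.assert_called_with("people.csv")
flat.option.assert_any_call("header", "true")
```

If someone swaps two option() calls in load_csv(), the positional assertion on plain fails, while the assertions on flat keep passing.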
Additionally, you can set
spark_mock.load.return_value directly, and return your own
sample data frame!
I’ve set up the Spark instance like this:
import pytest
from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf


@pytest.fixture(scope="session")
def spark_context(request):
    conf = SparkConf().setMaster("local").setAppName("pytest-pyspark-local-testing")
    sc = SparkContext(conf=conf)
    yield sc
    # runs once, after the last test of the session
    sc.stop()


@pytest.fixture(scope="session")
def spark(request):
    spark_conf = SparkConf().setMaster("local").setAppName("pytest-pyspark2.+-local-testing")
    spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
    yield spark
    # runs once, after the last test of the session
    spark.stop()
The spark mock and spark context/session fixtures reside in my