Compare the rows of two Spark DataFrames


PySpark is a well-known tool in data engineering. A common practice is to test your code against a small test dataset. That works fine until a test starts failing and you need to figure out where the error is. In these cases, an assertion that “diffs” the rows can save you a lot of time:

from pyspark.sql import Row

import tests.kit.spark as tks


def test_row_equality(sample_df):
    # Expected rows; assert_equal sorts both sides, so their order does not matter.
    data = [
        Row(name="Walter", age=33, country="Germany", salary=10000.0),
        Row(name="John", age=39, country="England", salary=20000.0),
        Row(name="Alice", age=17, country="Wonderland", salary=120000.0),
    ]

    tks.assert_equal(sample_df.collect(), data)

$ pytest -v tests/test_spark_assert.py
...
    def assert_equal(actual, expected, sort=True):
>       assert rows_to_dicts(actual, sort=sort) == rows_to_dicts(expected, sort=sort)
E       AssertionError: assert [{'age': 17, ...ry': 20000.0}] == [{'age': 17, ...ry': 20000.0}]
E         At index 1 diff: {'name': 'Walter', 'age': 32, 'country': 'Germany', 'salary': 10000.0}
E             != {'name': 'Walter', 'age': 33, 'country': 'Germany', 'salary': 10000.0}
E         Full diff:
E           [
E            {'age': 17, 'country': 'Wonderland', 'name': 'Alice', 'salary': 120000.0},
E         -  {'age': 33, 'country': 'Germany', 'name': 'Walter', 'salary': 10000.0},
E         ?           ^
E         +  {'age': 32, 'country': 'Germany', 'name': 'Walter', 'salary': 10000.0},...
E
E         ...Full output truncated (4 lines hidden), use '-vv' to show
...
====================================== 1 failed in 5.16s ======================================

You can immediately see that Walter from Germany should be 33, not 32, years old. Luckily we have this test; otherwise German bureaucracy would give Walter a big headache.

If you want to build a (more or less) custom assert function, you need a couple of ingredients:

  1. Two sets of rows that can be brought down to a common denominator, in our case dicts.
  2. An ordering approach that keeps the row order stable and prevents flaky tests due to the non-deterministic behaviour of df.collect().
  3. A way to wire your assertion function to the pytest machinery.

Comparing Spark DataFrame rows

The easy part is converting the rows to dicts and writing an assertion function that compares them.

def lowercase_dict_keys(d):
    # Normalise column names so that e.g. "Age" and "age" count as the same column.
    return {k.lower(): v for k, v in d.items()}


def rows_to_dicts(rows, case_sensitive_colnames=False, sort=True):
    # Row.asDict() turns each Spark Row into a plain {column: value} dict.
    res = [row.asDict() for row in rows]
    if not case_sensitive_colnames:
        res = [lowercase_dict_keys(d) for d in res]
    if sort:
        # calc_row_hash (defined below) gives each row a stable sort key.
        return sorted(res, key=calc_row_hash)
    return res


def assert_equal(actual, expected, *, case_sensitive_colnames=False, sort=True):
    # Normalise both sides the same way before comparing them.
    assert rows_to_dicts(
        actual, case_sensitive_colnames=case_sensitive_colnames, sort=sort
    ) == rows_to_dicts(expected, case_sensitive_colnames=case_sensitive_colnames, sort=sort)

(Spark resolves column names case-insensitively by default, controlled by the spark.sql.caseSensitive setting, so by default the helper ignores column-name case as well.)

There are two issues that might cause trouble:

  1. the row order might change
  2. the column order might change
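In plain Python (no Spark involved), the two issues look like this: comparing two dicts does not care about the order of their keys, whereas comparing two lists does care about the order of their elements.

```python
# Two rows with the same values but a different column order.
row_a = {"name": "Walter", "age": 33}
row_b = {"age": 33, "name": "Walter"}

# Dict equality ignores key (column) order ...
columns_match = row_a == row_b

# ... but list equality depends on element (row) order.
alice = {"name": "Alice", "age": 17}
rows_match = [row_a, alice] == [alice, row_a]

print(columns_match, rows_match)  # True False
```

So converting rows to dicts already handles the column order; only the row order needs extra work.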

Dict comparison already ignores the column order. To make the row order deterministic too, we introduce a hashing function that turns each row into a sort key:

def calc_row_hash(d):
    return "||".join([str(i[1]) for i in sorted(d.items(), key=lambda x: x[0])])

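This function sorts the dictionary items by column name and joins the values with ||. Despite its name, the result is not a hash in the cryptographic sense but a plain string sort key: two rows with the same columns and values get the same key, no matter in which order their columns appear. Since the key only has to be consistent for both sides of the comparison, it doesn't matter that the values are compared as strings.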
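To check that the pieces fit together without starting a SparkSession, here is a minimal sketch that reuses the functions above on a hypothetical `FakeRow` stand-in. `FakeRow` is an assumption for illustration only: a dict subclass exposing `asDict()`, which is the only `Row` method the helpers call. The two row lists differ in row order, column order and column-name case, yet compare equal by default.

```python
# Hypothetical stand-in for pyspark.sql.Row: the helpers only need asDict().
class FakeRow(dict):
    def asDict(self):
        return dict(self)


def lowercase_dict_keys(d):
    return {k.lower(): v for k, v in d.items()}


def calc_row_hash(d):
    # Sort items by column name, then join the stringified values.
    return "||".join([str(i[1]) for i in sorted(d.items(), key=lambda x: x[0])])


def rows_to_dicts(rows, case_sensitive_colnames=False, sort=True):
    res = [row.asDict() for row in rows]
    if not case_sensitive_colnames:
        res = [lowercase_dict_keys(d) for d in res]
    if sort:
        return sorted(res, key=calc_row_hash)
    return res


def assert_equal(actual, expected, *, case_sensitive_colnames=False, sort=True):
    assert rows_to_dicts(
        actual, case_sensitive_colnames=case_sensitive_colnames, sort=sort
    ) == rows_to_dicts(expected, case_sensitive_colnames=case_sensitive_colnames, sort=sort)


# Same data, but different row order, column order and column-name case.
actual = [FakeRow(name="John", age=39), FakeRow(name="Walter", age=33)]
expected = [FakeRow(AGE=33, NAME="Walter"), FakeRow(age=39, name="John")]

print(calc_row_hash({"name": "Walter", "age": 33}))  # 33||Walter

# Default: column names are lower-cased and rows sorted, so this passes.
assert_equal(actual, expected)

# With case-sensitive column names, "AGE" and "age" are different columns.
try:
    assert_equal(actual, expected, case_sensitive_colnames=True)
    case_sensitive_passed = True
except AssertionError:
    case_sensitive_passed = False

print(case_sensitive_passed)  # False
```

Outside pytest the bare `assert` only raises a plain `AssertionError`; the detailed row diff appears once the module is registered for assertion rewriting.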

Wiring our assertion to pytest

I usually keep a separate testkit package in the tests/ directory (see (2) in the tree below). This package contains all the helper functions I need to write tests effectively. However, a simple helpers file works just as well (see (1) in the tree below).

tests
├── __init__.py
├── conftest.py
├── helpers.py             # (1)
├── kit                    # (2)
│   ├── __init__.py
│   └── spark.py
└── test_spark_assert.py

I usually add

[tool.pytest.ini_options]
norecursedirs="tests/kit"

to my pyproject.toml to prevent accidental test discovery in my testkit module. One last thing to do: connect your code to pytest.

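pytest rewrites assert statements to produce detailed diffs like the one above, but it does so for test modules and conftest.py files, not for arbitrary helper modules they import. Because the assert lives in your helper module, you need to register that package/module for rewriting before it is imported, for example at the top of tests/conftest.py: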

import pytest

pytest.register_assert_rewrite("tests.kit.spark")
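Putting it together, a tests/conftest.py could look like the sketch below. The `spark` and `sample_df` fixtures are assumptions (the test at the top uses `sample_df` without showing where it comes from), and the sample data simply mirrors the failing run, where Walter's age comes out as 32.

```python
# tests/conftest.py -- hypothetical sketch matching the directory tree above
import pytest

# Must run before tests.kit.spark is imported, so that pytest rewrites the
# assert inside tks.assert_equal and prints the detailed row diff.
pytest.register_assert_rewrite("tests.kit.spark")


@pytest.fixture(scope="session")
def spark():
    # Imported inside the fixture so that loading conftest.py alone needs no Spark.
    from pyspark.sql import SparkSession

    session = SparkSession.builder.master("local[1]").appName("tests").getOrCreate()
    yield session
    session.stop()


@pytest.fixture
def sample_df(spark):
    # Hypothetical data standing in for the output of the code under test.
    # Walter's age is 32 here, which reproduces the failing run shown above.
    return spark.createDataFrame(
        [
            ("Walter", 32, "Germany", 10000.0),
            ("John", 39, "England", 20000.0),
            ("Alice", 17, "Wonderland", 120000.0),
        ],
        schema="name string, age int, country string, salary double",
    )
```

A session-scoped `spark` fixture starts the (slow) SparkSession only once per test run, while `sample_df` is rebuilt for every test that requests it.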

Conclusion

With this setup you are good to go, and comparing Spark DataFrame rows in your tests shouldn’t be a problem anymore.