Apache Spark: Partition Tuning

Partitioning is a tricky business in Spark. Get it right and Spark can approach optimal performance; get it wrong and Out-Of-Memory (OOM) errors and slow-running jobs will haunt you.

There are three key principles to keep in mind when working with partitions in Spark:

  1. Smart partitioning with .coalesce or .repartition helps with transformations that do not require shuffling. These types of transformations are also known as narrow transformations (see the sketch after this list).
  2. As soon as you perform a wide transformation (any transformation that involves shuffling), .coalesce or .repartition no longer help, so there is no need to repartition before a join. Here, spark.sql.shuffle.partitions is the key setting.
  3. If you write data with .partitionBy, your data gets sliced in addition to your existing Spark partitions. If you have 20 Spark partitions and do a .partitionBy on a different column with 30 distinct values, you can end up with up to 20 * 30 = 600 files on disk.
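
A minimal sketch of point 1, with assumed table and column names ("events", "status"): after a selective filter most partitions are nearly empty, and a shuffle-free .coalesce merges them.

from pyspark.sql import functions as F

df = spark.table("events")                       # assumed table name
errors = df.where(F.col("status") == "error")    # narrow transformation

# Merge the now nearly-empty partitions without a shuffle.
errors = errors.coalesce(8)

# .repartition(8) would give a similar result, but with a full shuffle,
# which also rebalances skew at extra cost.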

To clarify points 2 and 3:

spark.table(...)
  .repartition(F.col("country"))
  .write
  .partitionBy("country")
  .format("parquet")
  .mode("overwrite")
  .saveAsTable(...)

» Because of the upfront .repartition(F.col("country")), all rows for a given country land in a single Spark partition, so the number of files is determined by the “country” column (roughly one file per country) and any per-file limit such as spark.sql.files.maxRecordsPerFile.

spark.table(...)
  .write
  .partitionBy("country")
  .format("parquet")
  .mode("overwrite")
  .saveAsTable(...)

» Here the number of files is determined by the “country” column, the current Spark partitions, and the same per-file limit: every existing Spark partition writes one file per country value it contains. When shuffling data for joins or aggregations, Spark uses spark.sql.shuffle.partitions (default 200) partitions. This means you can easily create a lot of files if you don’t pay attention.
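
A minimal sketch of how that plays out; the table names ("sales", "sales_by_country") and columns ("country", "product", "amount") are assumptions. Capping spark.sql.shuffle.partitions before the aggregation bounds the worst-case file count at (shuffle partitions) × (distinct countries).

from pyspark.sql import functions as F

spark.conf.set("spark.sql.shuffle.partitions", 32)    # default is 200

(
    spark.table("sales")                               # assumed table name
    .groupBy("country", "product")                     # wide transformation -> shuffle
    .agg(F.sum("amount").alias("amount"))
    .write
    .partitionBy("country")
    .format("parquet")
    .mode("overwrite")
    .saveAsTable("sales_by_country")                   # assumed target name
)
# Worst case: 32 shuffle partitions × number of distinct countries = files on disk.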

Insight

Before you start fiddling with Spark configuration parameters, make sure you know how your data is structured and where the bottleneck lies. You can do this by inspecting your job with the Spark UI or Spark history server. Look for data skews and long-running tasks.

A look at the size (bytes) and rows per partition might help, too.
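
For rows per partition, a quick sketch (the table name is an assumption) using spark_partition_id:

from pyspark.sql import functions as F

df = spark.table("events")    # assumed table name

(
    df.withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id")
    .count()
    .orderBy("partition_id")
    .show()
)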

For the size in bytes, I use this function for Spark V1 tables:

import re

from pyspark.sql import functions as F


def get_table_size(spark, fqtn):
    """Return the size of a V1 table in bytes, or None if no statistics are available."""
    # Refresh the table-level statistics without scanning the data.
    spark.sql(f"ANALYZE TABLE {fqtn} COMPUTE STATISTICS NOSCAN")
    stats = (
        spark.sql(f"DESCRIBE TABLE EXTENDED {fqtn}")
        .where(F.col("col_name") == "Statistics")
        .select("data_type")
        .head()
    )
    if not stats:
        return None
    # The statistics value looks like e.g. "123456 bytes, 789 rows".
    m = re.search(r"(\d+)\s+bytes", stats[0])
    if not m:
        return None
    return int(m.group(1))
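
Usage, with a hypothetical fully qualified table name:

size_bytes = get_table_size(spark, "analytics.sales")    # hypothetical FQTN
if size_bytes is not None:
    print(f"{size_bytes / 1024**2:.1f} MiB")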

For V2 tables, you have to rely on the tooling of the storage framework you use.
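
For example, if that framework happens to be Delta Lake, DESCRIBE DETAIL reports size and file count (the table name is a hypothetical):

detail = spark.sql("DESCRIBE DETAIL analytics.sales_delta").head()
print(detail["sizeInBytes"], detail["numFiles"])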

Can I figure out how many (shuffle) partitions I need before I run a (wide) transformation?

Not really up front. UI-wise, Spark is not the best in this regard; unfortunately, you need to keep poking around in the Spark UI to find the right numbers.
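
As a starting point before that poking around, here is a rough back-of-the-envelope sketch using the get_table_size helper above and a common (unofficial) rule of thumb of roughly 128 MB per shuffle partition; the target size and table name are assumptions:

TARGET_PARTITION_BYTES = 128 * 1024**2    # assumed rule-of-thumb target, not an official guideline

size = get_table_size(spark, "analytics.sales")    # hypothetical table
if size:
    spark.conf.set("spark.sql.shuffle.partitions", max(1, size // TARGET_PARTITION_BYTES))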
