A pivot table is a tool used for summarizing data, allowing you to group and aggregate information based on categorical columns.
In PySpark, a pivot table transforms the unique values of one column into multiple columns, aggregating the values with functions like sum, count, avg, etc.
This is useful for reshaping data into a more digestible format, especially for reporting or analytics purposes.
The following PySpark code demonstrates how to create a pivot table from a small sample dataset:
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum  # note: shadows Python's built-in sum in this script

# Create (or reuse) a local Spark session
spark = SparkSession.builder \
    .appName("Pivot Table Example") \
    .getOrCreate()
# Sample data as a list of dictionaries
data = [
    {"employee": "Alice", "region": "North", "sales": 100},
    {"employee": "Bob", "region": "North", "sales": 200},
    {"employee": "Alice", "region": "South", "sales": 300},
    {"employee": "Bob", "region": "South", "sales": 400},
    {"employee": "Alice", "region": "East", "sales": 150},
    {"employee": "Bob", "region": "West", "sales": 250}
]
# Create a DataFrame and inspect the raw rows
df = spark.createDataFrame(data)
df.show()
# Pivot the table to show sales by employee per region
pivot_df = df.groupBy("employee").pivot("region").agg(sum("sales"))
# Show the pivoted DataFrame
pivot_df.show()
spark.stop()
groupBy("employee")
groups the data by employee.pivot("region")
creates new columns for each unique value in the "region" column (e.g., "North", "South", "East", "West").agg(sum("sales"))
aggregates the sales data by summing the values for each combination of employee and region.The output would look something like this:
+--------+------+-----+
|employee|region|sales|
+--------+------+-----+
|   Alice| North|  100|
|     Bob| North|  200|
|   Alice| South|  300|
|     Bob| South|  400|
|   Alice|  East|  150|
|     Bob|  West|  250|
+--------+------+-----+
And the pivoted DataFrame printed by pivot_df.show():
+--------+----+-----+-----+----+
|employee|East|North|South|West|
+--------+----+-----+-----+----+
|   Alice| 150|  100|  300|null|
|     Bob|null|  200|  400| 250|
+--------+----+-----+-----+----+
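Two refinements are worth knowing, sketched below using the same df (run this before spark.stop(), while the session is still active). First, pivot("region") normally launches an extra job just to collect the distinct region values; passing the expected values explicitly skips that scan and fixes the column order. Second, the null gaps can be replaced with zeros via na.fill.

# Pivot with an explicit value list: no distinct-value scan,
# and the output columns always appear in this order.
regions = ["East", "North", "South", "West"]
pivot_explicit = df.groupBy("employee").pivot("region", regions).agg(sum("sales"))

# Replace nulls (employee/region pairs with no sales) with 0
pivot_explicit.na.fill(0).show()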
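agg can also take several aggregate functions at once; Spark then suffixes each pivoted column with the aggregation's alias (e.g. North_total, North_deals). A minimal sketch under the same assumptions, with count imported alongside sum:

from pyspark.sql.functions import count

# Two aggregations per cell: summed sales and the number of sales rows.
# Output columns: East_total, East_deals, North_total, North_deals, ...
multi_df = df.groupBy("employee").pivot("region").agg(
    sum("sales").alias("total"),
    count("sales").alias("deals")
)
multi_df.show()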