Pivot Tables and Reshaping Data
In data analysis, one common task is reshaping and transforming data to gain insights. Pandas, a powerful Python library, provides versatile tools for reshaping data, including pivot tables. In this post, we will explore pivot tables and other reshaping functions using Seaborn's built-in Tips dataset as an example.
pivot_table() Function
Pivot tables are used to summarize and aggregate data in a tabular format. They allow you to group data by one or more columns and apply aggregation functions to other columns. Pivot tables are particularly useful for exploring relationships between variables.
The pivot_table() function, called here as a DataFrame method, creates a pivot table from a DataFrame. You choose which column(s) form the rows (index) and the columns of the result, which column(s) to aggregate, and how to aggregate them. Here's the syntax with some of the most common parameters:
DataFrame.pivot_table(values=None, index=None, columns=None, aggfunc='mean', fill_value=None, margins=False, margins_name='All')
values: The column(s) whose values will be aggregated.
index: The column(s) whose unique values become the row labels (index) of the resulting pivot table.
columns: The column(s) whose unique values become the column labels of the resulting pivot table.
aggfunc: The aggregation function(s) to apply to values. Common options include 'mean' (the default), 'sum', 'count', 'min', and 'max'. You can also pass a custom function, a list of functions, or a dict mapping each value column to its own function.
fill_value: The value used to replace missing (NaN) entries in the resulting table, after aggregation.
margins: If True, adds an extra row and column aggregated over all the data with aggfunc (so with the default 'mean' they hold overall means, not sums). Default is False.
margins_name: The label used for the margin row and column. Default is 'All'.
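To see fill_value, margins, margins_name and a custom aggfunc in action, here is a minimal sketch on a tiny hand-built DataFrame. The rows are hypothetical (not taken from tips) so the example runs offline; only documented pivot_table() parameters are used.

```python
import pandas as pd

# Hypothetical mini-dataset (not from tips) so the example runs offline
df = pd.DataFrame({
    "day": ["Sat", "Sat", "Sun", "Sun", "Sun"],
    "time": ["Dinner", "Lunch", "Dinner", "Dinner", "Dinner"],
    "total_bill": [20.0, 10.0, 30.0, 10.0, 20.0],
})

# Sum of bills per day (rows) and time (columns)
totals = df.pivot_table(
    values="total_bill",
    index="day",
    columns="time",
    aggfunc="sum",
    fill_value=0,          # Sun has no Lunch rows: show 0 instead of NaN
    margins=True,          # add a row and a column aggregated over all data
    margins_name="Total",  # label for that extra row and column
)
print(totals)

# A custom aggregation function: the spread (max - min) of bills per day
spread = df.pivot_table(
    values="total_bill",
    index="day",
    aggfunc=lambda s: s.max() - s.min(),
)
print(spread)
```

Because aggfunc='sum' here, the "Total" row and column hold sums; with the default 'mean' they would hold overall means instead.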
Loading and Exploring the Tips Dataset
Let's start by loading the tips dataset and examining its structure.
import pandas as pd
import seaborn as sns
# Load the 'tips' dataset from Seaborn
tips_df = sns.load_dataset('tips')
# Display first few rows of the dataset
print(tips_df.head())
Output:
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4
Pivot Tables with Aggregation
We'll begin by creating a simple pivot table that calculates the average total bill for each combination of day and sex.
# Create a pivot table for average total bill by day and sex
pivot_table = tips_df.pivot_table(
values='total_bill',
index='day',
columns='sex',
aggfunc='mean')
# Display the pivot table
print(pivot_table.round(2))  # round to 2 decimal places
sex Male Female
day
Thur 18.71 16.72
Fri 19.86 14.15
Sat 20.80 19.68
Sun 21.89 19.87
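A pivot table with a single aggregation function can be thought of as a group-by followed by an unstack. The sketch below compares the two on a small hypothetical DataFrame (float bills, every day/sex combination present); it illustrates the idea rather than documenting pandas' internals line by line.

```python
import pandas as pd

# Hypothetical mini-dataset: float bills, every day/sex combination present
df = pd.DataFrame({
    "day": ["Fri", "Fri", "Fri", "Sat", "Sat", "Sat"],
    "sex": ["Male", "Female", "Male", "Female", "Male", "Female"],
    "total_bill": [10.0, 12.0, 20.0, 30.0, 25.0, 10.0],
})

# pivot_table(): mean total_bill per (day, sex), with sex spread across the columns
pivoted = df.pivot_table(values="total_bill", index="day", columns="sex", aggfunc="mean")

# By hand: group on both keys, aggregate, then unstack 'sex' into the columns
by_hand = df.groupby(["day", "sex"])["total_bill"].mean().unstack("sex")

print(pivoted)
print(pivoted.equals(by_hand))
```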
Multi-level Pivot Tables
Pivot tables can have multiple levels of indices and columns. Let's create a pivot table with both 'day' and 'sex' as index levels and calculate both the sum and the mean of total_bill for each combination. Passing a list to aggfunc adds a column level holding the function names.
# multi-level pivot table for sum and mean of total bill by day and sex
multi_level_pivot = tips_df.pivot_table(
values='total_bill',
index=['day', 'sex'],
aggfunc=['sum', 'mean'])
# Display the multi-level pivot table
print(multi_level_pivot.round(2))
                 sum       mean
          total_bill total_bill
day  sex
Thur Male     561.44      18.71
     Female   534.89      16.72
Fri  Male     198.57      19.86
     Female   127.31      14.15
Sat  Male    1227.35      20.80
     Female   551.05      19.68
Sun  Male    1269.46      21.89
     Female   357.70      19.87
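The multi-level result has a MultiIndex on both axes: each row label is a (day, sex) pair and each column label is an (aggregation, value column) pair. A minimal sketch of selecting from such a table, using hypothetical data in place of tips:

```python
import pandas as pd

# Hypothetical mini-dataset standing in for tips
df = pd.DataFrame({
    "day": ["Sat", "Sat", "Sat", "Sun", "Sun"],
    "sex": ["Male", "Female", "Male", "Female", "Male"],
    "total_bill": [10.0, 20.0, 30.0, 40.0, 50.0],
})

multi = df.pivot_table(values="total_bill", index=["day", "sex"], aggfunc=["sum", "mean"])
print(multi)

# Columns are (aggfunc, value column) pairs: select one with a tuple
sums = multi[("sum", "total_bill")]
print(sums)

# Rows are (day, sex) pairs: .loc with a full tuple returns one row,
# while .xs with level= takes a cross-section at a single level
sat_male = multi.loc[("Sat", "Male"), :]
saturday = multi.xs("Sat", level="day")
print(sat_male)
print(saturday)
```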
Reshaping with melt() Function
The melt() function reshapes wide-format data into long format by "melting" columns into rows. It's particularly useful when several columns hold values of the same kind and you want to combine them into a single value column, with a second column recording which original column each value came from.
Example:
Let's use the melt() function to combine the total_bill and tip columns: their column names go into a new 'bill_type' column, and the corresponding values go into a new 'amount' column.
# Using the melt() function to reshape data
melted_data = pd.melt(
tips_df,
id_vars=['day', 'time'],
value_vars=['total_bill', 'tip'],
var_name='bill_type',
value_name='amount')
# Display the reshaped data
print(melted_data.head())
day time bill_type amount
0 Sun Dinner total_bill 16.99
1 Sun Dinner total_bill 10.34
2 Sun Dinner total_bill 21.01
3 Sun Dinner total_bill 23.68
4 Sun Dinner total_bill 24.59
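In this example, the id_vars parameter specifies the columns that remain as identifiers (not melted), while value_vars lists the columns to be melted. The result is a long-format DataFrame in which each original row appears twice, once for 'total_bill' and once for 'tip': the 'bill_type' column records the original column name and the 'amount' column holds its value. Because melt() places all the total_bill rows first and then all the tip rows, head() shows only total_bill entries; with 244 rows in tips, melted_data has 488 rows in total.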
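melt() has a natural inverse in DataFrame.pivot(), which spreads a long table back into wide form, provided each (identifier, variable) pair is unique. The tips rows have no unique key besides their position, so this sketch assumes a hypothetical order_id column; passing a list to index= in pivot() requires pandas 1.1 or later.

```python
import pandas as pd

# Hypothetical wide table; order_id is a made-up unique key per row
wide = pd.DataFrame({
    "order_id": [1, 2, 3],
    "day": ["Sun", "Sun", "Sat"],
    "total_bill": [16.99, 10.34, 21.01],
    "tip": [1.01, 1.66, 3.50],
})

# Wide -> long: one row per (order, bill_type)
long = wide.melt(
    id_vars=["order_id", "day"],
    value_vars=["total_bill", "tip"],
    var_name="bill_type",
    value_name="amount",
)
print(long)

# Long -> wide: pivot() needs every (index, columns) pair to be unique, hence order_id
restored = long.pivot(index=["order_id", "day"], columns="bill_type", values="amount")
restored = restored.reset_index()
restored.columns.name = None  # drop the leftover 'bill_type' label on the column axis
print(restored)
```

Unlike pivot_table(), pivot() does no aggregation, which is why it raises an error instead of silently combining duplicate pairs.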
Reshaping with stack() Function
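The stack() function moves a level of the column labels into the row index (by default the innermost level; the level parameter selects others), converting wide-format data into long format with a multi-level index. When the columns have only one level, as in a typical pivot table, the result is a Series.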
Example:
We'll use the pivot_table() function to create a wide-format pivot table and then apply the stack() function to convert it to long format.
# Create a pivot table for average total bill by day and time
pivot_table = tips_df.pivot_table(
values='total_bill',
index='day',
columns='time',
aggfunc='mean')
# Using the stack() function to reshape data
stacked_data = pivot_table.stack()
# Display the stacked data
print(stacked_data)
day   time
Thur  Lunch     17.664754
      Dinner    18.780000
Fri   Lunch     12.845714
      Dinner    19.663333
Sat   Dinner    20.441379
Sun   Dinner    21.410000
dtype: float64
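In this example, pivot_table is created first, and then stack() is applied to it. The result is a Series with a multi-level index whose levels are 'day' and 'time', with the mean total_bill values stacked vertically in long format. Saturday and Sunday have no lunch records in the data, so those cells are NaN in the pivot table, and stack() left them out of the output above. Newer pandas versions provide a reimplemented stack() (the future_stack=True option added in pandas 2.1) that keeps such NaN entries instead of dropping them.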
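The reverse operation is unstack(), which moves an index level back into the columns. A minimal sketch on a hypothetical day-by-time table with no missing cells (the values are made up, not computed from tips):

```python
import pandas as pd

# Hypothetical wide table shaped like the day x time pivot (values made up)
wide = pd.DataFrame(
    [[17.5, 18.0], [12.5, 19.5]],
    index=pd.Index(["Thur", "Fri"], name="day"),
    columns=pd.Index(["Lunch", "Dinner"], name="time"),
)

# stack(): the 'time' column labels become the inner level of a row MultiIndex
long = wide.stack()
print(long)

# unstack(): moves that inner level back out into the columns
back = long.unstack()
print(back)
```

If the wide table had empty cells, unstack() would reintroduce them as NaN.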
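Both melt() and stack() are valuable tools for transforming data between wide and long formats, depending on the analysis requirements. The main difference is where the identifiers live: melt() keeps them as ordinary columns and returns a DataFrame, whereas stack() works on the index and column labels and returns a Series (or a DataFrame, when the columns have several levels) with a multi-level index.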
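To make that difference concrete, the sketch below builds the same long data both ways from one small hypothetical wide table: melt() works on ordinary columns, while stack() needs the identifier moved into the index first.

```python
import pandas as pd

# Hypothetical wide table: one row per day, one column per bill component
wide = pd.DataFrame({
    "day": ["Sat", "Sun"],
    "total_bill": [20.0, 30.0],
    "tip": [3.0, 4.5],
})

# melt(): identifiers stay ordinary columns; the result is a DataFrame
via_melt = wide.melt(id_vars="day", var_name="bill_type", value_name="amount")

# stack(): identifiers must be in the index; the result is a Series with a MultiIndex
via_stack = (
    wide.set_index("day")
    .rename_axis(columns="bill_type")  # name the column axis so the new index level is labelled
    .stack()
)
from_stack = via_stack.reset_index(name="amount")

# Same (day, bill_type, amount) records; only the row order differs
same_rows = sorted(via_melt.itertuples(index=False, name=None)) == sorted(
    from_stack.itertuples(index=False, name=None)
)
print(via_melt)
print(via_stack)
print(same_rows)
```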