Create Plots Using Pandas Crosstab() in Python

When I was working on a data analysis project, I needed to visualize the relationship between two categorical variables in my dataset. The challenge was finding an efficient way to both tabulate and visualize this relationship in one go. That’s when pandas’ crosstab() function came to my rescue.

In this article, I’ll show you how to create insightful plots using pandas crosstab() in Python. The crosstab() function is incredibly useful for analyzing relationships between categorical variables, and when combined with visualization, it becomes an efficient tool in your data analysis toolkit.

Let’s dive in and see how to use crosstab() to create several types of plots!

Pandas Crosstab()

pandas’ crosstab() computes a cross-tabulation of two or more factors. By default, that is a frequency table counting how often each combination occurs. Think of it as building a spreadsheet pivot table, but in Python.

The basic syntax looks like this:

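pd.crosstab(index, columns, values=None, rownames=None, colnames=None, aggfunc=None, margins=False, margins_name='All', dropna=True, normalize=False)

Where:

  • index: The values to group by in the rows (a Series, an array, or a list of them)
  • columns: The values to group by in the columns
  • values: Optional values to aggregate; must be passed together with aggfunc
  • rownames / colnames: Optional labels for the row and column levels
  • aggfunc: The aggregation applied to values, such as 'mean' or 'sum'
  • margins: Adds row and column totals, labeled with margins_name ('All' by default)
  • dropna: Leaves out columns whose entries are all NaN
  • normalize: Converts counts to proportions: 'index' (per row), 'columns' (per column), or True/'all' (over the whole table)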
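To see what a few of these parameters do before any plotting, here is a minimal sketch on a tiny hand-made DataFrame. The orders name and its values are hypothetical and exist only for illustration. It shows how rownames and colnames label the axes, and how margins=True appends an 'All' row and column holding the totals.

```python
import pandas as pd

# Hypothetical mini-dataset, for illustration only
orders = pd.DataFrame({
    "Gender": ["Male", "Female", "Female", "Male", "Female"],
    "Category": ["Books", "Books", "Clothing", "Clothing", "Clothing"],
})

# Count each Gender/Category combination; margins=True adds an "All" row and column of totals
table = pd.crosstab(
    orders["Gender"],
    orders["Category"],
    rownames=["Customer gender"],
    colnames=["Product category"],
    margins=True,
)
print(table)
```

Rows and columns come out in sorted order (Female before Male, Books before Clothing), with the 'All' totals appended last.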

Before we jump into creating plots, let’s make sure we have our environment set up properly.

Set Up Your Environment

First, let’s import the necessary libraries:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

Now, let’s create a sample dataset we’ll use throughout this tutorial:

# Creating a sample dataset of customer purchases
np.random.seed(42)
data = {
    'Age_Group': np.random.choice(['18-25', '26-35', '36-45', '46+'], size=1000),
    'Gender': np.random.choice(['Male', 'Female'], size=1000),
    'Product_Category': np.random.choice(['Electronics', 'Clothing', 'Home Goods', 'Books', 'Groceries'], size=1000),
    'Purchase_Amount': np.random.normal(50, 25, size=1000),
    'State': np.random.choice(['California', 'Texas', 'New York', 'Florida', 'Illinois'], size=1000)
}

df = pd.DataFrame(data)

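This dataset represents customer purchases with demographics (age group and gender), product category, purchase amount, and US state. That is typical of the data you’d encounter in real-world business analysis. Keep in mind that every column is drawn at random and independently of the others, so the plots below demonstrate techniques rather than real patterns. Purchase_Amount can even be negative occasionally, because it comes from a normal distribution with mean 50 and standard deviation 25.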

Create Plots Using Pandas Crosstab() in Python

Below are six methods for creating plots with pandas crosstab() in Python.

Method 1: Create a Basic Heatmap with Crosstab

Let’s start with a simple but effective visualization: a heatmap showing the relationship between age groups and product categories.

# Create a crosstab between Age_Group and Product_Category
ct = pd.crosstab(df['Age_Group'], df['Product_Category'])

# Create a heatmap
plt.figure(figsize=(10, 6))
sns.heatmap(ct, annot=True, cmap='YlGnBu', fmt='d')
plt.title('Purchase Frequency by Age Group and Product Category')
plt.tight_layout()
plt.show()

You can see the output in the screenshot below.

This code creates a heatmap that shows the count of purchases for each combination of age group and product category. The annot=True parameter writes the count into each cell, and fmt='d' formats those annotations as integers, which makes the map easier to read.

The resulting heatmap shows at a glance which age groups buy which product categories most often. With real data, you might notice that the '26-35' age group buys more electronics, while the '46+' group leans toward home goods. With this randomly generated sample, however, the differences between cells are just noise.
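If you want to double-check what each heatmap cell holds, here is a minimal sketch showing that a plain crosstab() matches counting rows with groupby(). It assumes a smaller, hypothetical stand-in for the tutorial’s df: 200 random rows with three product categories, generated with NumPy’s default_rng rather than np.random.seed.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the tutorial's df: 200 random rows, two categorical columns
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "Age_Group": rng.choice(["18-25", "26-35", "36-45", "46+"], size=200),
    "Product_Category": rng.choice(["Electronics", "Clothing", "Books"], size=200),
})

# What the heatmap plots: one count per (Age_Group, Product_Category) pair
ct = pd.crosstab(df["Age_Group"], df["Product_Category"])

# The same counts via groupby: size of each group, then pivot categories into columns
counts = (
    df.groupby(["Age_Group", "Product_Category"])
    .size()
    .unstack(fill_value=0)
)

# Compare labels and numbers
same = (
    ct.index.equals(counts.index)
    and ct.columns.equals(counts.columns)
    and np.array_equal(ct.to_numpy(), counts.to_numpy())
)
print(same)  # → True
```

In practice crosstab() is the shorter spelling. The groupby() version is handy when you already have a grouped object to work with.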

Method 2: Create Normalized Crosstab Plots

Sometimes, the raw counts don’t tell the whole story. We might want to see proportions instead. This is where the normalize parameter comes in handy.

# Create a normalized crosstab (row-wise)
ct_norm = pd.crosstab(df['Age_Group'], df['Product_Category'], normalize='index')

# Plot as a stacked bar chart
ct_norm.plot(kind='bar', stacked=True, figsize=(12, 6), colormap='viridis')
plt.title('Proportion of Product Categories Purchased by Age Group')
plt.xlabel('Age Group')
plt.ylabel('Proportion')
plt.legend(title='Product Category', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

You can see the output in the screenshot below.

The normalize='index' parameter normalizes the values row-wise, showing the proportion of each product category within each age group. This creates a stacked bar chart where each bar (representing an age group) adds up to 1.

This visualization is particularly useful when you want to compare the distribution of purchases across product categories within each age group, regardless of the total number of purchases.
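The three normalize modes answer different questions, and it is easy to mix them up. Here is a minimal sketch on a hypothetical five-row table (illustrative values only) that checks what each mode sums to:

```python
import pandas as pd

# Hypothetical purchases, for illustration only
orders = pd.DataFrame({
    "Age_Group": ["18-25", "18-25", "18-25", "46+", "46+"],
    "Product_Category": ["Books", "Books", "Clothing", "Books", "Clothing"],
})

rows, cols = orders["Age_Group"], orders["Product_Category"]

by_row = pd.crosstab(rows, cols, normalize="index")    # share of each category within an age group
by_col = pd.crosstab(rows, cols, normalize="columns")  # share of each age group within a category
overall = pd.crosstab(rows, cols, normalize="all")     # share of each cell in the whole table

print(by_row.sum(axis=1))        # every age group sums to 1.0
print(by_col.sum(axis=0))        # every category sums to 1.0
print(overall.to_numpy().sum())  # the whole table sums to 1.0
```

For the stacked bar chart above, normalize='index' is the right choice, because each bar is one age group and should reach 1.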

Method 3: Use Crosstab with Values and Aggregation

So far, we’ve only been counting occurrences. But what if we want to analyze the purchase amounts? The values and aggfunc parameters allow us to do this.

# Create a crosstab with purchase amount values and aggregation
ct_values = pd.crosstab(
    df['Gender'], 
    df['Product_Category'], 
    values=df['Purchase_Amount'], 
    aggfunc='mean'
)

# Create a grouped bar chart
ct_values.plot(kind='bar', figsize=(12, 6))
plt.title('Average Purchase Amount by Gender and Product Category')
plt.xlabel('Gender')
plt.ylabel('Average Purchase Amount ($)')
plt.legend(title='Product Category', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

You can see the output in the screenshot below.

This code creates a bar chart showing the average purchase amount for each combination of gender and product category. The values parameter specifies the column to aggregate, and aggfunc='mean' tells pandas to calculate the mean of those values.

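This visualization helps us compare spending patterns across genders and product categories. With real data, you might find that one group spends noticeably more on average in a particular category. In this randomly generated sample, the bars should all sit fairly close to the overall mean of about 50.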
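Because crosstab() with values and aggfunc behaves like a pivot table, it can help to see both side by side. Here is a minimal sketch on a hypothetical five-row DataFrame (made-up amounts) comparing crosstab(..., aggfunc='mean') with DataFrame.pivot_table(). Note that crosstab() expects values and aggfunc together and raises an error if only one of them is given.

```python
import pandas as pd

# Hypothetical purchases with amounts, for illustration only
orders = pd.DataFrame({
    "Gender": ["Male", "Male", "Male", "Female", "Female"],
    "Product_Category": ["Books", "Books", "Clothing", "Books", "Clothing"],
    "Purchase_Amount": [10.0, 30.0, 45.0, 25.0, 40.0],
})

# crosstab takes Series directly...
ct_mean = pd.crosstab(
    orders["Gender"],
    orders["Product_Category"],
    values=orders["Purchase_Amount"],
    aggfunc="mean",
)

# ...while pivot_table takes column names of a single DataFrame
pt_mean = orders.pivot_table(
    index="Gender",
    columns="Product_Category",
    values="Purchase_Amount",
    aggfunc="mean",
)

print(ct_mean)
print(pt_mean)
```

Both produce the same table. crosstab() is convenient when your factors are separate Series or arrays, while pivot_table() reads naturally when everything already lives in one DataFrame.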

Method 4: Multi-Level Crosstab Visualization

A useful feature of crosstab() is that it accepts several row (or column) factors at once. Passing a list of Series produces a MultiIndex, so we can build a visualization that incorporates three categorical variables:

# Create a multi-level crosstab
ct_multi = pd.crosstab([df['Gender'], df['Age_Group']], df['Product_Category'])

# Plot this as a heatmap
plt.figure(figsize=(14, 10))
sns.heatmap(ct_multi, annot=True, cmap='coolwarm', fmt='d')
plt.title('Purchase Frequency by Gender, Age Group, and Product Category')
plt.tight_layout()
plt.show()

This creates a hierarchical heatmap where we can see the purchase counts for each combination of gender, age group, and product category. This is especially useful for identifying complex patterns that might not be visible in simpler visualizations.
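Once the heatmap gets too tall, it often helps to slice or reshape the multi-level table before plotting. Here is a minimal sketch on a hypothetical five-row DataFrame. It uses two standard pandas tools: xs() pulls out one gender’s sub-table, and unstack() moves Gender into the columns for a side-by-side layout.

```python
import pandas as pd

# Hypothetical purchases, for illustration only
orders = pd.DataFrame({
    "Gender": ["Female", "Female", "Male", "Male", "Male"],
    "Age_Group": ["18-25", "46+", "18-25", "18-25", "46+"],
    "Product_Category": ["Books", "Clothing", "Books", "Clothing", "Books"],
})

# A list of Series for the rows gives a MultiIndex named after those Series
ct_multi = pd.crosstab([orders["Gender"], orders["Age_Group"]], orders["Product_Category"])
print(list(ct_multi.index.names))

# One gender's sub-table, indexed by Age_Group only
female = ct_multi.xs("Female", level="Gender")

# Gender moved into the columns: (Product_Category, Gender) pairs
wide = ct_multi.unstack(level="Gender", fill_value=0)
print(wide)
```

Either result can be passed to sns.heatmap() in the same way as ct_multi.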

Method 5: Regional Breakdown with Crosstab

Since our dataset includes US states, let’s create a visualization that shows the distribution of product categories across states:

# Create a crosstab between State and Product_Category
ct_state = pd.crosstab(df['State'], df['Product_Category'])

# Create a stacked bar chart
ct_state.plot(kind='bar', stacked=True, figsize=(12, 6), colormap='tab10')
plt.title('Distribution of Product Categories Across States')
plt.xlabel('State')
plt.ylabel('Number of Purchases')
plt.legend(title='Product Category', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

This stacked bar chart shows how product categories are distributed across different states. It’s particularly useful for businesses looking to understand regional preferences and tailor their marketing strategies accordingly.
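Raw counts mix two things: how many purchases a state has and what those purchases are. Here is a minimal sketch that separates the two. It assumes a hypothetical 500-row stand-in for the tutorial’s data, generated with NumPy’s default_rng. The states are sorted by total purchases, and normalize='index' gives each state’s category mix.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the tutorial's df
rng = np.random.default_rng(1)
df = pd.DataFrame({
    "State": rng.choice(["California", "Texas", "New York", "Florida", "Illinois"], size=500),
    "Product_Category": rng.choice(["Electronics", "Clothing", "Books"], size=500),
})

# Counts per state, reordered so the state with the most purchases comes first
ct_state = pd.crosstab(df["State"], df["Product_Category"])
order = ct_state.sum(axis=1).sort_values(ascending=False).index
ct_state = ct_state.loc[order]

# Category mix per state (each row sums to 1), kept in the same state order
shares = pd.crosstab(df["State"], df["Product_Category"], normalize="index").loc[order]
print(shares.round(2))
```

Plotting shares with kind='bar', stacked=True, as in the chart above, produces bars of equal height. That makes it easier to compare regional preferences when states differ a lot in volume.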

Method 6: Advanced Visualization with Seaborn and Crosstab

We can also use the output of crosstab with Seaborn’s more advanced plotting functions:

# Create a crosstab between Age_Group and Product_Category with purchase amount
ct_amount = pd.crosstab(
    df['Age_Group'], 
    df['Product_Category'], 
    values=df['Purchase_Amount'], 
    aggfunc='sum'
)

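# Create a clustermap; sns.clustermap() builds its own figure, so no plt.figure() call is needed
g = sns.clustermap(
    ct_amount, 
    cmap='YlGnBu', 
    standard_scale=1,  # Rescale each column (product category) to the 0-1 range
    figsize=(12, 10)
)
# plt.title() and plt.tight_layout() would act on the wrong axes here, so title the whole figure
g.figure.subplots_adjust(top=0.9)
g.figure.suptitle('Clustered Heatmap of Total Purchase Amounts')
plt.show()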

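This creates a clustered heatmap that shows the relationship between age groups and product categories. It also reorders the rows and columns with hierarchical clustering, so similar groups sit next to each other, and the dendrograms show how they were merged. Because standard_scale=1 rescales each product category separately, the colors compare age groups within a category rather than across categories. This can reveal patterns and similarities that might not be immediately obvious.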
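To see what standard_scale=1 does to the numbers, here is a minimal sketch that reproduces it with plain pandas. It assumes seaborn’s documented behavior: for each column, subtract the minimum and divide by the range. The ct_amount values here are hypothetical, not the tutorial’s actual sums.

```python
import pandas as pd

# Hypothetical totals per age group and category, for illustration only
ct_amount = pd.DataFrame(
    {"Books": [100.0, 300.0, 200.0], "Clothing": [50.0, 50.0, 150.0]},
    index=pd.Index(["18-25", "26-35", "36-45"], name="Age_Group"),
)

# Column-wise min-max scaling: each column's smallest value becomes 0 and its largest becomes 1
scaled = (ct_amount - ct_amount.min()) / (ct_amount.max() - ct_amount.min())
print(scaled)
```

One caveat with this kind of scaling: a column whose values are all equal has a range of zero, so it comes out as NaN rather than a color.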

Tips for Effective Crosstab Visualizations

After working with crosstab visualizations for years, I’ve found a few tips that can help you create more effective plots:

  1. Choose the right normalization: Use normalize='index' to compare distributions within rows, normalize='columns' for columns, and normalize='all' to see the proportion of each cell relative to the entire table.
  2. Select appropriate colors: Use color scales that make sense for your data. For example, sequential color maps (like ‘YlGnBu’) work well for quantities, while diverging color maps (like ‘coolwarm’) are good for showing deviations from a central value.
  3. Add annotations: Adding numbers to your heatmaps with annot=True makes them much easier to interpret.
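  4. Handle missing values: Plain count tables already show 0 for combinations that never occur. Tables built with values and aggfunc, however, leave those cells as NaN. Fill them by calling .fillna(0) on the result, because pd.crosstab() has no fill_value parameter (that parameter belongs to pd.pivot_table()).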
  5. Format your visualizations: Always add titles, labels, and legends to make your plots self-explanatory.

I hope you found this tutorial on creating plots with pandas crosstab() helpful. With these techniques and tips, you can effectively visualize relationships between categorical variables and gain deeper insights into your data.
