Making Plots With matplotlib/seaborn

Overview

Teaching: min
Exercises: min
Questions
  • How can I visualize data in Python?

Objectives
  • Create default seaborn plots.

  • Set universal plot settings.

  • Modify existing plots.

  • Change the aesthetics of a plot such as color.

  • Edit the axis labels.

  • Build complex plots using a step-by-step approach.

  • Create scatter plots, box plots, and time series plots.

  • Create customized plot styles to meet your needs.

Python has powerful plotting capabilities through the matplotlib library. For this episode, we will be using the seaborn package, which is built on top of matplotlib and works well with pandas dataframes.

There are other plotting packages, but these two are the most popular. Another popular one is plotnine, which is inspired by the ggplot2 plotting library in R. The material this episode is adapted from uses plotnine, if you want to see that library in action.

Going back to seaborn: as with other packages, seaborn needs to be imported before use. Here we import both matplotlib and seaborn. Rather than loading everything from a package (for example from seaborn import *), it is good practice to import it under a short alias, as we did with pd for pandas:

# Jupyter-only command: display plots inside the notebook
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

From now on, seaborn functions are available through sns. (for example sns.scatterplot) and matplotlib functions through plt. (for example plt.figure). For the exercises, we will use the surveys.csv data set with the rows containing NA values removed:

import pandas as pd

surveys_complete = pd.read_csv('data/surveys.csv')
# Drop every row that has a missing value in any column
surveys_complete = surveys_complete.dropna()
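If you want to see what dropna() removes before running it on the full file, here is a minimal sketch on a few made-up rows (the toy dataframe and its values are hypothetical, not taken from surveys.csv). By default dropna() drops a row when any of its columns is missing, so it can remove more rows than a single column's NA count suggests:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for a few rows of surveys.csv (made-up values);
# np.nan marks a missing value, as pandas does for "NA" entries in a CSV.
toy = pd.DataFrame({
    "species_id": ["DM", "DM", "PE", "PE", "DO"],
    "sex": ["F", "M", np.nan, "F", "M"],
    "weight": [40.0, np.nan, 22.0, 20.0, 48.0],
    "hindfoot_length": [35.0, 36.0, 20.0, np.nan, 35.0],
})

# Number of missing values in each column
print(toy.isna().sum())

# Keep only the rows with no missing value in any column
toy_complete = toy.dropna()
print(len(toy), "rows before,", len(toy_complete), "rows after")
```

Here each column has at most one missing value, yet three of the five rows are dropped, because the missing values sit in different rows. To drop rows based on selected columns only, dropna() also accepts a subset argument, e.g. dropna(subset=["weight"]).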

Plotting with matplotlib/seaborn

The seaborn package supports the creation of complex plots from data in a dataframe. Its default settings help produce publication-quality plots with a minimal amount of tweaking.

Before we use the seaborn package, let’s make the simplest plot we can with matplotlib alone. Most matplotlib plotting functions expect the x and y values that make up the points or lines of the plot to be passed explicitly as arrays. For example, to plot the weight versus the hindfoot length from our surveys data, we need to extract both columns from the dataframe and pass them to the plotting function:

plt.scatter(x=surveys_complete["weight"], y=surveys_complete["hindfoot_length"])
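The same explicit-array style can also be written with matplotlib's object-oriented interface, where plt.subplots() returns a figure and an Axes and we call methods on the Axes. This makes visible that matplotlib only receives the numbers, so the axis labels have to be added by hand. A minimal sketch on made-up values (the toy dataframe is hypothetical):

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for surveys_complete (made-up values)
toy = pd.DataFrame({
    "weight": [40.0, 48.0, 22.0, 20.0, 37.0],
    "hindfoot_length": [35.0, 35.0, 20.0, 21.0, 34.0],
})

fig, ax = plt.subplots()
# Pass the x and y values themselves, extracted column by column
points = ax.scatter(toy["weight"], toy["hindfoot_length"])

# matplotlib was never told the column names, so we label the axes ourselves
ax.set_xlabel("weight")
ax.set_ylabel("hindfoot_length")
```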

Because seaborn integrates well with pandas, we can instead pass the whole dataframe as data and refer to the columns by name for the x and y values. The same plot in seaborn is:

sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length")
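One practical difference from the matplotlib version is that seaborn looks the columns up by name, so it can also use those names as the default axis labels. A minimal sketch that checks this on made-up values (toy is hypothetical; passing ax= just makes the target Axes explicit):

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for surveys_complete (made-up values)
toy = pd.DataFrame({
    "weight": [40.0, 48.0, 22.0, 20.0, 37.0],
    "hindfoot_length": [35.0, 35.0, 20.0, 21.0, 34.0],
})

fig, ax = plt.subplots()
sns.scatterplot(data=toy, x="weight", y="hindfoot_length", ax=ax)

# The column names given as x and y become the axis labels
print(ax.get_xlabel(), ax.get_ylabel())
```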

Building your plots iteratively

Building plots and interactive visualisations is typically an iterative process: we make an initial plot and then, depending on the result, adjust various parameters so that, for example, we can see what we are interested in more clearly.

We start with the same scatter plot as before:

sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length")

Then we modify this plot to extract more information from it. For instance, we can add transparency (alpha) to reduce the effect of overplotting:

sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length", alpha=0.5)

We can also use a different colour for all the points:

sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length", alpha=0.25, color="red")

Alternatively, to give each species its own color, we can map the species_id column to the hue parameter:

sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length", hue="species_id", alpha=0.25)
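To check what these arguments did, we can inspect the matplotlib objects seaborn created. A minimal sketch on made-up values (the toy dataframe is hypothetical): the alpha value ends up on the scatter collection, and mapping species_id to hue produces one legend entry per species:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for surveys_complete (made-up values)
toy = pd.DataFrame({
    "weight": [40.0, 48.0, 22.0, 20.0, 37.0],
    "hindfoot_length": [35.0, 35.0, 20.0, 21.0, 34.0],
    "species_id": ["DM", "DM", "PE", "PE", "DO"],
})

fig, ax = plt.subplots()
sns.scatterplot(data=toy, x="weight", y="hindfoot_length",
                hue="species_id", alpha=0.25, ax=ax)

# Transparency applied to all points of the scatter
print(ax.collections[0].get_alpha())

# Legend labels created by the hue mapping
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
print(legend_labels)
```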

Apart from adapting the arguments of the seaborn plot, we can also add and control plot-level elements such as the axis labels and the legend. For this we need access to the plot-level objects, so we create a figure and an axes object with matplotlib, pass the axes to seaborn to draw on, and then use it to change the x-axis label and the legend layout:

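fig = plt.figure()
ax = fig.add_subplot()
# Draw the seaborn plot onto the matplotlib Axes we created
sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length",
                alpha=0.5, hue="species_id", ax=ax)
# Use the same Axes to change the x-axis label and the legend layout
ax.set_xlabel("Weight (g)")
ax.legend(ncol=3)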
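Calling plt.legend(ncol=3) or ax.legend(ncol=3) builds a new legend from the plotted elements, which can drop the legend title seaborn had set. Since seaborn 0.11.2, sns.move_legend rebuilds seaborn's existing legend with new options while keeping its handles, labels and title. A minimal sketch on made-up values (toy is hypothetical):

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for surveys_complete (made-up values)
toy = pd.DataFrame({
    "weight": [40.0, 48.0, 22.0, 20.0, 37.0],
    "hindfoot_length": [35.0, 35.0, 20.0, 21.0, 34.0],
    "species_id": ["DM", "DM", "PE", "PE", "DO"],
})

fig, ax = plt.subplots()
sns.scatterplot(data=toy, x="weight", y="hindfoot_length",
                alpha=0.5, hue="species_id", ax=ax)
ax.set_xlabel("Weight (g)")

# Recreate seaborn's legend in three columns; "best" lets matplotlib pick the corner
sns.move_legend(ax, "best", ncol=3)
```

To place the legend outside the axes, the seaborn documentation shows calls such as sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1)).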

Challenge - Bar plots

Working on the surveys_complete data set, use the plot_id column to create a bar plot that counts the number of records for each plot. Check the documentation of barplot and the seaborn API reference to find the plotting function that handles the counting for you.

Answers

# countplot counts the records per plot_id; hue="sex" additionally splits each bar by sex
sns.countplot(data=surveys_complete, x="plot_id", hue="sex")
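countplot does the counting for you. If you prefer sns.barplot, as the challenge hints, one option is to compute the counts with pandas first and pass them as the y variable. A minimal sketch on made-up rows (toy and the n_records column name are hypothetical):

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical records (made-up rows): plot 1 has 2, plot 2 has 1, plot 3 has 3
toy = pd.DataFrame({"plot_id": [1, 1, 2, 3, 3, 3]})

# Count the records per plot_id with pandas ...
counts = toy.groupby("plot_id").size().reset_index(name="n_records")

# ... and let barplot draw one bar per plot_id with that height
fig, ax = plt.subplots()
sns.barplot(data=counts, x="plot_id", y="n_records", ax=ax)
```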

Plotting distributions

Visualizing distributions is a common task during data exploration and analysis. To visualize the distribution of weight within each species_id group, a boxplot can be used:

sns.boxplot(data=surveys_complete, x="species_id", y="weight")
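Each box summarizes the quartiles of weight within one species_id group: the line inside the box is the median and the box spans the 25th to 75th percentiles, while by default the whiskers reach the most extreme points within 1.5 times the interquartile range. A minimal sketch computing the same quartiles with pandas on made-up values (toy is hypothetical); with the default linear interpolation these should match the box edges matplotlib draws:

```python
import pandas as pd

# Hypothetical weights (made-up values) for two species
toy = pd.DataFrame({
    "species_id": ["DM"] * 5 + ["PE"] * 4,
    "weight": [30, 40, 45, 50, 52, 18, 20, 22, 30],
})

# 25th, 50th and 75th percentiles of weight within each species,
# reshaped to one row per species and one column per quantile
quartiles = (
    toy.groupby("species_id")["weight"]
    .quantile([0.25, 0.5, 0.75])
    .unstack()
)
print(quartiles)
```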

By plotting the individual observations as points, we get a better idea of the number of measurements and of their distribution:

sns.stripplot(data=surveys_complete, x="species_id", y="weight")
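The two views can also be combined on the same axes: draw the boxes first, then overlay the individual points. A minimal sketch on made-up values (toy is hypothetical); fliersize=0 hides the boxplot's own outlier markers so those observations are not drawn twice:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical weights (made-up values) for two species
toy = pd.DataFrame({
    "species_id": ["DM"] * 5 + ["PE"] * 4,
    "weight": [30, 40, 45, 50, 52, 18, 20, 22, 30],
})

fig, ax = plt.subplots()
# Boxes first, without outlier markers ...
sns.boxplot(data=toy, x="species_id", y="weight", fliersize=0, ax=ax)
# ... then every observation as a semi-transparent point on the same Axes
sns.stripplot(data=toy, x="species_id", y="weight",
              color="black", alpha=0.5, ax=ax)
```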

Sometimes boxplots can hide the ‘shape’ of the distribution; violin plots can be used instead. See the seaborn tutorial on visualizing distributions of data for more, including histograms and 2D distributions.
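A minimal violin plot sketch on made-up values (toy is hypothetical). Each violin is a mirrored kernel density estimate of weight for one species, so with only a handful of observations its shape mostly reflects the smoothing rather than the data; with the full surveys_complete data set it is more informative:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical weights (made-up values) for two species
toy = pd.DataFrame({
    "species_id": ["DM"] * 5 + ["PE"] * 4,
    "weight": [30, 40, 45, 50, 52, 18, 20, 22, 30],
})

fig, ax = plt.subplots()
# One violin (density estimate of weight) per species_id
sns.violinplot(data=toy, x="species_id", y="weight", ax=ax)
```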

Plotting time series data

Let’s count the number of records per year for each species. To do that, we first group the data by year and species_id and then count the records within each group:

yearly_counts = surveys_complete.groupby(['year', 'species_id'])['species_id'].count()
yearly_counts

The result is a Series indexed by both year and species_id (a two-level row index). We can reset this index to turn both into regular columns, storing the counts in a new column named counts:

yearly_counts = yearly_counts.reset_index(name='counts')
yearly_counts
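A minimal sketch of these two steps on made-up rows (toy is hypothetical) shows why the name argument is needed: the counted Series is itself named species_id, which is also one of the index levels, so reset_index() without a new name would try to create two columns called species_id and fail:

```python
import pandas as pd

# Hypothetical records (made-up rows)
toy = pd.DataFrame({
    "year": [1977, 1977, 1977, 1978, 1978],
    "species_id": ["DM", "DM", "PE", "DM", "PE"],
})

# Count the records in each (year, species_id) group
yearly_counts = toy.groupby(["year", "species_id"])["species_id"].count()
print(yearly_counts.index.names)  # two index levels: year and species_id

# Move both index levels into columns and store the counts in "counts"
yearly_counts = yearly_counts.reset_index(name="counts")
print(yearly_counts)
```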

Time series data can be visualised as a line plot with years on the x axis and counts on the y axis:

sns.lineplot(data=yearly_counts, x='year', y='counts', hue='species_id')
plt.legend(ncol=3)
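sns.lineplot works directly on this long table, with one row per year and species. For comparison, pandas' own DataFrame.plot expects a wide table with one column per line. A minimal sketch on made-up counts (the values are hypothetical) that converts between the two with pivot:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical long-format counts (made-up values); PE has no row for 1978
yearly_counts = pd.DataFrame({
    "year": [1977, 1977, 1978, 1979, 1979],
    "species_id": ["DM", "PE", "DM", "DM", "PE"],
    "counts": [2, 1, 4, 5, 3],
})

# Long -> wide: one row per year, one column per species
wide = yearly_counts.pivot(index="year", columns="species_id", values="counts")
print(wide)

fig, ax = plt.subplots()
wide.plot(ax=ax)  # pandas draws one line per column
```

One difference worth knowing: a year missing for a species becomes NaN in the wide table, which matplotlib draws as a gap, whereas sns.lineplot on the long table simply connects the neighbouring years.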

Faceting

As with other plotting libraries, seaborn supports faceting: splitting one plot into multiple panels based on a categorical variable in the dataset.

Consider our scatter plot of the weight versus the hindfoot_length from the previous sections:

sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length")

We can keep largely the same call and add a col argument that splits the graph into a separate panel for each group of a chosen variable. As an example, use sex:

sns.relplot(data=surveys_complete, x="weight", y="hindfoot_length",
            col="sex", hue="species_id")
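relplot returns a seaborn FacetGrid rather than a single matplotlib Axes. Its .axes attribute holds one Axes per panel, and grid-wide methods such as set_axis_labels and set_titles change every panel at once. A minimal sketch on made-up values (toy is hypothetical):

```python
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for surveys_complete (made-up values)
toy = pd.DataFrame({
    "weight": [40.0, 48.0, 22.0, 20.0, 37.0, 45.0],
    "hindfoot_length": [35.0, 35.0, 20.0, 21.0, 34.0, 36.0],
    "sex": ["F", "M", "F", "M", "F", "M"],
    "species_id": ["DM", "DM", "PE", "PE", "DO", "DO"],
})

g = sns.relplot(data=toy, x="weight", y="hindfoot_length",
                col="sex", hue="species_id")

# One row of panels, one column per value of sex
print(g.axes.shape)

# Label the outer axes and title each panel with just its sex value
g.set_axis_labels("Weight (g)", "Hindfoot length")
g.set_titles(col_template="{col_name}")
```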

Note that we switched to the figure-level function relplot to do the faceting, instead of axes-level functions like scatterplot and lineplot.

We can apply the same concept on any of the available categorical variables:

sns.relplot(data=surveys_complete, x="weight", y="hindfoot_length",
            col="plot_id", hue="species_id", col_wrap=5)

Note that we use the col_wrap parameter to wrap the plots to multiple rows.
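With col_wrap, the FacetGrid returned by relplot stores its panels as a flat, one-dimensional array with one Axes per level of col, and FacetGrid.axes_dict maps each level to its panel, which is handy for adjusting a single panel. A minimal sketch on made-up values (toy and the highlighted panel are hypothetical):

```python
import pandas as pd
import seaborn as sns

# 14 made-up rows spread over 7 hypothetical plot_ids
toy = pd.DataFrame({
    "plot_id": list(range(1, 8)) * 2,
    "weight": [20 + 3 * i for i in range(14)],
    "hindfoot_length": [15 + i for i in range(14)],
})

# 7 panels wrapped at 3 per row: rows of 3, 3 and 1 panels
g = sns.relplot(data=toy, x="weight", y="hindfoot_length",
                col="plot_id", col_wrap=3)
print(g.axes.shape)

# Look up the panel for one plot_id and change only that panel
g.axes_dict[3].set_title("plot_id = 3 (highlighted)")
```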

Challenge - faceting

Create a separate plot for each of the species that depicts how the average weight of the species changes through the years.

Answers

yearly_weight = surveys_complete.groupby(['year', 'species_id'])['weight'].mean().reset_index()

sns.relplot(data=yearly_weight, x="year", y="weight",
            col="species_id", col_wrap=5, kind="line")

Challenge - faceting

Based on the previous exercise, visually compare how the weights of males and females have changed through time by creating a separate plot for each sex and assigning an individual color to each species_id.

Answers

yearly_weight = surveys_complete.groupby(['year', 'species_id', 'sex'])['weight'].mean().reset_index()


# One panel per sex, one colored line per species_id
sns.relplot(data=yearly_weight, x="year", y="weight", col="sex",
            hue="species_id", kind="line")

Further customization

Plots made with seaborn, matplotlib, and other Python packages can be customised heavily. One look at the API documentation of your favourite plot function shows how many parameters it accepts, and matplotlib adds its own functions on top of that to modify a plot's behaviour. In the end, the final look of a plot is a matter of personal preference, or it may be dictated by external factors such as publication standards.
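One of this episode's objectives is setting universal plot settings. In seaborn, sns.set_theme changes matplotlib's defaults for every plot created afterwards in the session, and sns.axes_style can be used as a context manager for a temporary change. A minimal sketch (the style and context names are examples of seaborn's built-in options; the plotted numbers are made up):

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Session-wide settings: every plot created afterwards uses a white
# background with grid lines, and fonts/lines scaled up for presentations.
sns.set_theme(style="whitegrid", context="talk")
print(mpl.rcParams["axes.grid"])

# Temporary style for one figure only; the previous settings are
# restored when the with-block ends.
with sns.axes_style("ticks"):
    fig, ax = plt.subplots()
    ax.plot([1977, 1978, 1979], [2, 4, 5])
```

The built-in styles are darkgrid, whitegrid, dark, white and ticks, and the contexts paper, notebook, talk and poster scale the plot elements from smallest to largest.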

Take for example a bar chart that plots the number of surveys per year:

sns.countplot(data=surveys_complete, x="year")

Notice that we use year here as a categorical variable. As a result, the individual year labels overlap each other. Methods of the matplotlib Axes object let us rotate the x-axis labels:

fig = plt.figure()
ax = fig.add_subplot()
sns.countplot(data=surveys_complete, x="year", ax=ax)
ax.tick_params(axis='x', rotation=90)
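A 90-degree rotation separates the labels but makes them harder to read. A common alternative is a smaller angle with each label right-aligned to its tick, using plt.setp to set properties on all tick-label Text objects at once. A minimal sketch on made-up rows (toy is hypothetical):

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical survey years (made-up rows)
toy = pd.DataFrame({"year": [1977, 1977, 1978, 1979, 1979, 1979, 1980, 1981]})

fig, ax = plt.subplots(figsize=(6, 3))
sns.countplot(data=toy, x="year", ax=ax)

# Rotate the labels by 45 degrees and right-align them, so that the end
# of each label sits under its tick
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
```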

Challenge - customization

Please take another five minutes to either improve one of the plots generated in this exercise or create a beautiful graph of your own.

Here are some ideas:

After creating your plot, you can save it to a file in your favourite format. You can change the dimensions (and resolution) of your plot by adjusting the figsize and dpi arguments when you create the figure object:

# figsize is in inches; dpi sets the resolution in dots per inch
fig = plt.figure(figsize=(3, 2), dpi=300)
ax = fig.add_subplot()
sns.scatterplot(data=surveys_complete, x="weight", y="hindfoot_length",
                alpha=0.25, hue="species_id", ax=ax)
fig.savefig("scatterplot.png")
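The figure size is given in inches and dpi is dots (pixels) per inch, so the saved image should be figsize × dpi pixels: 3 × 300 by 2 × 300, i.e. 900 × 600 here. A minimal sketch that checks this by saving to an in-memory buffer instead of a file (io.BytesIO and the made-up points are just for illustration):

```python
import io
import struct

import matplotlib.pyplot as plt

# 3 x 2 inches at 300 dots per inch
fig, ax = plt.subplots(figsize=(3, 2), dpi=300)
ax.scatter([40, 48, 22, 20], [35, 35, 20, 21])  # made-up points

# Save into an in-memory buffer instead of a file on disk
buf = io.BytesIO()
fig.savefig(buf, format="png")
png = buf.getvalue()

# PNG stores width and height as big-endian 32-bit integers in the
# IHDR chunk, at byte offsets 16 to 23
width, height = struct.unpack(">II", png[16:24])
print(width, height)
```

If you pass bbox_inches="tight" to savefig, matplotlib trims the surrounding whitespace, so the pixel size will then differ from figsize × dpi.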

Key Points