Visualization with Matplotlib

Let’s take an in-depth look at Python’s Matplotlib package for visualization. Matplotlib supports dozens of backends and output types, which means you can count on it to work regardless of which operating system you are using or which output format you need. This cross-platform, everything-to-everyone approach has been one of the great strengths of Matplotlib.

✏️ The examples are inspired by [Van17].
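As a quick illustration of the output types mentioned above, the sketch below saves one figure as PNG, PDF, and SVG. `savefig()` picks the format from the file extension. The temporary directory and the `sine.*` file names are arbitrary choices for this example.

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# savefig() infers the output format from the file extension
out_dir = tempfile.mkdtemp()
paths = [os.path.join(out_dir, f"sine.{ext}") for ext in ("png", "pdf", "svg")]
for path in paths:
    fig.savefig(path)
plt.close(fig)

# Dictionary mapping each writable extension to a short description
supported = fig.canvas.get_supported_filetypes()
print(sorted(supported))
```

The same figure object produces every format, so switching output types requires no change to the plotting code.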

We will use the standard shorthands for importing Matplotlib, NumPy, and Pandas:

import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

Simple Line Plots

For all Matplotlib plots, we start by creating a figure and axes. In their simplest form, a figure and axes can be created as follows:

fig = plt.figure()
ax = plt.axes()
[Figure: an empty figure containing a single set of axes]

In Matplotlib, the figure (an instance of the class plt.Figure) can be thought of as a single container that holds all the objects representing axes, graphics, text, and labels. The axes (an instance of the class plt.Axes) is what the code above draws: a bounding box with ticks and labels, which will eventually contain the plot elements that make up our visualization.
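A minimal sketch of the same setup in one call: `plt.subplots()` returns the figure and its single axes together, and the checks below confirm that the axes is stored inside the figure container.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

# plt.subplots() creates a figure with one axes and returns both objects
fig, ax = plt.subplots()

print(isinstance(fig, Figure), isinstance(ax, Axes))
# The axes knows its parent figure, and the figure lists its axes
print(ax.figure is fig, fig.axes == [ax])
```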

Let’s start with a simple sinusoid:

fig = plt.figure()
ax = plt.axes()

x = np.linspace(0, 10, 1000)
ax.plot(x, np.sin(x));
[Figure: a sine curve for x from 0 to 10]

Alternatively, we can use the MATLAB-style pyplot interface and let the figure and axes be created for us in the background:

plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x));
[Figure: sine and cosine curves on the same axes]
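Under the hood, pyplot functions act on the "current" axes, which pyplot creates on demand and exposes through `plt.gca()`. A short check that both curves end up on one axes object:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)

plt.figure()                # start a fresh figure; no axes exist yet
plt.plot(x, np.sin(x))      # implicitly creates the current axes
plt.plot(x, np.cos(x))      # draws on the same current axes

ax = plt.gca()              # "get current axes"
print(len(ax.lines))        # → 2
```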

Adjusting the Plot

Line Colors and Styles

The first adjustment you might wish to make to a plot is to control the line colors and styles:

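plt.plot(x, np.sin(x - 0), color='blue')   # specify color by name
plt.plot(x, np.sin(x - 1), color='g')      # short color code (rgbcmyk)
plt.plot(x, np.sin(x - 2), color='0.75');  # grayscale between 0 (black) and 1 (white)
[Figure: three phase-shifted sine curves in blue, green, and light gray]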

Matplotlib will automatically cycle through a set of default colors for multiple lines if no color is specified.
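The default colors come from the `axes.prop_cycle` entry in `rcParams`. The sketch below reads that cycle and confirms that lines drawn without an explicit color take successive entries from it (assuming no custom style has changed the cycle).

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import same_color

# Colors of the active property cycle (hex strings such as '#1f77b4' by default)
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
lines = [ax.plot(x, np.sin(x - k))[0] for k in range(3)]

# Each unstyled line picks up the next color in the cycle
for line, expected in zip(lines, cycle_colors):
    print(line.get_color(), same_color(line.get_color(), expected))
```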

Similarly, the line style can be adjusted using the linestyle keyword:

plt.plot(x, x + 4, linestyle='-')    # solid
plt.plot(x, x + 5, linestyle='--')   # dashed
plt.plot(x, x + 6, linestyle='-.');  # dash-dot
[Figure: three parallel straight lines drawn solid, dashed, and dash-dotted]
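For quick plots, a line style and a one-letter color code can be combined into a single format string passed as the third positional argument. The sketch below checks how Matplotlib interprets a few such strings.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_hex

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()

# Format string = line style followed by a one-letter color code
solid_green, = ax.plot(x, x + 0, '-g')
dashed_cyan, = ax.plot(x, x + 1, '--c')
dashdot_black, = ax.plot(x, x + 2, '-.k')
dotted_red, = ax.plot(x, x + 3, ':r')

for line in (solid_green, dashed_cyan, dashdot_black, dotted_red):
    print(line.get_linestyle(), to_hex(line.get_color()))
```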

Axes Limits

The most basic way to adjust axis limits is to use the plt.xlim() and plt.ylim() functions:

plt.plot(x, np.sin(x))

plt.xlim(-1, 11)
plt.ylim(-1.5, 1.5);
[Figure: a sine curve with x-limits (-1, 11) and y-limits (-1.5, 1.5)]
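With an explicit axes object, the equivalent methods are `ax.set_xlim()` and `ax.set_ylim()`. Passing the limits in reverse order flips the direction of an axis, as this minimal sketch shows:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

ax.set_xlim(-1, 11)      # same effect as plt.xlim(-1, 11)
ax.set_ylim(1.5, -1.5)   # reversed order: y now increases downward

print(ax.get_xlim(), ax.get_ylim(), ax.yaxis_inverted())
```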

Labeling Plots

Titles and axis labels are the simplest kinds of labels, and pyplot provides functions to set them quickly:

plt.plot(x, np.sin(x))
plt.title("A Sine Curve")
plt.xlabel("x")
plt.ylabel("sin(x)");
[Figure: a sine curve titled "A Sine Curve" with axis labels "x" and "sin(x)"]
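In the object-oriented interface these pyplot functions map to `set_` methods (plt.title() → ax.set_title(), plt.xlabel() → ax.set_xlabel(), and so on), and `ax.set()` can set several properties in one call. A sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# One call instead of separate set_xlim/set_ylim/set_xlabel/set_ylabel/set_title
ax.set(xlim=(0, 10), ylim=(-2, 2),
       xlabel="x", ylabel="sin(x)", title="A Sine Curve")

print(ax.get_title(), ax.get_xlabel(), ax.get_ylabel())
```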

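When multiple lines are shown within a single axes, it helps to add a legend that labels each line. Below, the third argument of plt.plot() is a format string ('-g' is a solid green line, ':b' a dotted blue one), and the label keyword supplies each line's legend entry: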

plt.plot(x, np.sin(x), '-g', label='sin(x)')
plt.plot(x, np.cos(x), ':b', label='cos(x)')

plt.legend();
[Figure: a solid green sin(x) curve and a dotted blue cos(x) curve with a legend]
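The legend collects each line's `label`. The `loc` argument controls its placement, and the legend's text entries can be read back, as in this short sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), '-g', label='sin(x)')
ax.plot(x, np.cos(x), ':b', label='cos(x)')

# loc='upper right' pins the legend; the default 'best' searches for free space
legend = ax.legend(loc='upper right', frameon=False)
labels = [text.get_text() for text in legend.get_texts()]
print(labels)   # → ['sin(x)', 'cos(x)']
```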

Simple Scatter Plots

Another commonly used plot type is the simple scatter plot. Instead of points being joined by line segments, the points are represented individually with a dot, circle, or another shape.

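rng = np.random.RandomState(0)
x = rng.randn(100)
y = rng.randn(100)
colors = rng.rand(100)          # one value per point, mapped through the colormap
sizes = 1000 * rng.rand(100)    # marker area in points^2, one per point

plt.scatter(x, y, c=colors, s=sizes, alpha=0.3, cmap='viridis');
[Figure: scatter plot of 100 semi-transparent points of varying color and size]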
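plt.scatter() is needed when each point's size or color varies. When all points share one appearance, plt.plot() with a marker and no line style draws the same kind of picture and is generally the more efficient choice for large datasets. The sketch below compares what the two calls return.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(100)
y = rng.randn(100)

fig, (ax1, ax2) = plt.subplots(1, 2)

# plot(): a single Line2D whose markers all share one size and color
line, = ax1.plot(x, y, 'o', color='black')

# scatter(): a PathCollection that stores per-point sizes and colors
points = ax2.scatter(x, y, c=rng.rand(100), s=1000 * rng.rand(100),
                     alpha=0.3, cmap='viridis')

print(line.get_linestyle())        # 'None': markers only, no connecting line
print(points.get_offsets().shape)  # one (x, y) row per point
print(len(points.get_sizes()))     # one size per point
```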

Grids of Subplots

Let’s create an entire grid of subplots with plt.subplots(), which returns the axes in a NumPy array:

fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')

# axes are in a two-dimensional array, indexed by [row, col]
for i in range(2):
    for j in range(3):
        ax[i, j].text(0.5, 0.5, str((i, j)), fontsize=18, ha='center')
[Figure: a 2 × 3 grid of subplots, each labeled with its (row, col) index]

Because we specified sharex and sharey, Matplotlib automatically removes the inner tick labels (x tick labels appear only on the bottom row and y tick labels only on the left column), which makes the plot cleaner.
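Sharing is more than cosmetic: axes in the same column share their x-limits and axes in the same row share their y-limits, so changing one propagates to the others. A quick check:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')
print(ax.shape)                 # NumPy array of axes, indexed [row, col]

ax[0, 0].set_xlim(0, 5)         # column 0 shares x, so ax[1, 0] follows
ax[0, 0].set_ylim(-1, 1)        # row 0 shares y, so ax[0, 1] and ax[0, 2] follow

print(ax[1, 0].get_xlim())      # same as ax[0, 0]
print(ax[1, 1].get_xlim())      # different column, so unchanged
print(ax[0, 2].get_ylim())      # same as ax[0, 0]
```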

Histograms, Binnings, and Density

A simple histogram can be a useful first step in understanding a dataset.

Let’s generate some data:

x1 = np.random.normal(0, 0.8, 1000)
x2 = np.random.normal(-2, 1, 1000)
x3 = np.random.normal(3, 2, 1000)

Then we call plt.hist() with several options that tune both the calculation and the display:

kwargs = dict(histtype='stepfilled', alpha=0.3, bins=40)

plt.hist(x1, **kwargs)
plt.hist(x2, **kwargs)
plt.hist(x3, **kwargs);
[Figure: three overlapping semi-transparent step-filled histograms]
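When you only need the bin counts and not the drawing, NumPy's `np.histogram()` computes them directly. With `density=True` the heights are normalized so that the total area of the bars equals 1. A sketch on freshly generated data:

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(0, 0.8, 1000)

# 40 equal-width bins spanning the data range: 40 counts and 41 edges
counts, edges = np.histogram(x1, bins=40)
print(counts.sum(), len(edges))

# density=True divides each count by (total count * bin width)
density, edges = np.histogram(x1, bins=40, density=True)
area = np.sum(density * np.diff(edges))
print(round(area, 6))
```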

Let’s create histograms in two dimensions by dividing points among two-dimensional bins:

mean = [0, 0]
cov = [[1, 1], [1, 2]]
x, y = np.random.multivariate_normal(mean, cov, 10000).T

and visualize the data as a 2D histogram:

plt.hist2d(x, y, bins=30, cmap='Blues')
cb = plt.colorbar()
cb.set_label('counts in bin');
[Figure: 2D histogram in shades of blue with a colorbar labeled "counts in bin"]
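Likewise, the counts behind a 2D histogram can be computed with `np.histogram2d()`. Note that the returned array is indexed [x_bin, y_bin], so it has to be transposed before displaying it with y on the vertical axis (for example with plt.imshow). A sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
mean = [0, 0]
cov = [[1, 1], [1, 2]]
x, y = rng.multivariate_normal(mean, cov, 10000).T

# 30 x 30 grid of bins; counts[i, j] holds the points with x in bin i and y in bin j
counts, xedges, yedges = np.histogram2d(x, y, bins=30)
print(counts.shape, counts.sum())
```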

Seaborn

The Matplotlib API is relatively low level, and because Matplotlib predates Pandas, it was not designed with Pandas DataFrames in mind. When Matplotlib falls short, Seaborn comes to the rescue. Seaborn provides an API on top of Matplotlib that offers sane choices for plot style and color defaults, defines simple high-level functions for common statistical plot types, and integrates with the functionality provided by Pandas DataFrames.

We will use the standard shorthand for importing seaborn:

import seaborn as sns

# Suppress all warnings; recent seaborn API changes emit many deprecation warnings
import warnings
warnings.filterwarnings("ignore")

Histograms, KDE, and densities

Often in statistical data visualization, all you want is to plot histograms and joint distributions of variables.

Let’s generate some data:

data = np.random.multivariate_normal([0, 0], [[5, 2], [2, 2]], size=2000)
data = pd.DataFrame(data, columns=['x', 'y'])

Rather than a histogram, we can get a smooth estimate of the distribution using a kernel density estimation, which Seaborn does with sns.kdeplot:

sns.kdeplot(data['x'], fill=True)
sns.kdeplot(data['y'], fill=True);
[Figure: filled KDE curves of the x and y columns]
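Conceptually, a kernel density estimate places a Gaussian "bump" on every observation and averages them. The sketch below is a from-scratch illustration using Scott's rule for the bandwidth (seaborn's default `bw_method`). The function name `gaussian_kde_1d` is hypothetical, and this is not seaborn's actual implementation.

```python
import numpy as np

def gaussian_kde_1d(samples, grid):
    """Evaluate a Gaussian KDE of `samples` at every point of `grid` (hypothetical helper)."""
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    # Scott's rule in one dimension: h = sample std * n ** (-1/5)
    h = samples.std(ddof=1) * n ** (-1 / 5)
    # Standardized distance from every grid point to every sample
    z = (grid[:, None] - samples[None, :]) / h
    kernels = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
    # Average the kernels and rescale by h so the estimate integrates to 1
    return kernels.mean(axis=1) / h

rng = np.random.default_rng(0)
samples = rng.normal(0, 2, 500)
grid = np.linspace(-12, 12, 2001)
density = gaussian_kde_1d(samples, grid)

# A Riemann sum over the uniform grid approximates the total area
area = density.sum() * (grid[1] - grid[0])
print(round(area, 3))
```

The bandwidth h controls smoothness: a smaller h follows individual points more closely, while a larger h blurs the estimate.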

Histograms and KDE can be combined using sns.histplot with kde=True (the older distplot function is deprecated):

sns.histplot(data['x'], kde=True)
sns.histplot(data['y'], kde=True);
[Figure: histograms of the x and y columns with KDE curves overlaid]

We can see the joint distribution and the marginal distributions together using sns.jointplot:

sns.jointplot(x="x", y="y", data=data, kind='kde');
[Figure: KDE joint plot of x and y with marginal KDE curves]

Pair Plot

When you generalize joint plots to datasets of larger dimensions, you end up with pair plots. This is very useful for exploring correlations between multidimensional data when you’d like to plot all pairs of values against each other.

iris = sns.load_dataset("iris")
sns.pairplot(iris, hue='species');
[Figure: pair plot of the iris measurements colored by species]

Let’s revisit some other familiar visualizations from the previous session.

For example, let’s check the species distribution in the dataset:

sns.countplot(x="species", data=iris);
[Figure: bar chart of the number of samples per iris species]

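Let’s measure the pairwise Pearson correlations between the numeric measurements and display them as an annotated heatmap:

# corr() needs numeric data, so keep only the numeric columns (dropping 'species')
sns.heatmap(iris.select_dtypes("number").corr(method="pearson"), annot=True);
[Figure: annotated heatmap of pairwise Pearson correlations between the iris measurements]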
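Each cell of the heatmap is a Pearson coefficient: the covariance of two columns divided by the product of their standard deviations, r = cov(a, b) / (σ_a σ_b). A NumPy sketch on synthetic data (not the iris columns) that checks this formula against `np.corrcoef`; the helper name `pearson` is hypothetical:

```python
import numpy as np

def pearson(a, b):
    """Pearson correlation: centered dot product divided by the product of norms."""
    a = np.asarray(a, dtype=float) - np.mean(a)
    b = np.asarray(b, dtype=float) - np.mean(b)
    return np.sum(a * b) / np.sqrt(np.sum(a**2) * np.sum(b**2))

rng = np.random.default_rng(0)
a = rng.normal(size=200)
b = 0.8 * a + 0.6 * rng.normal(size=200)   # constructed to correlate with a

r = pearson(a, b)
print(round(r, 3), round(np.corrcoef(a, b)[0, 1], 3))
```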

Let’s compare the measurement distributions of the classes with box plots:

sns.boxplot(data=iris, y="sepal_length", x="species");
[Figure: box plots of sepal_length for each iris species]
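Each box spans the first to the third quartile with a line at the median. By default the whiskers reach the most extreme data points within 1.5 × IQR of the box, and points beyond them are drawn individually as outliers. Matplotlib's `matplotlib.cbook.boxplot_stats()` computes these numbers; the sketch below compares them with NumPy percentiles on synthetic data containing one deliberate outlier.

```python
import numpy as np
from matplotlib.cbook import boxplot_stats

rng = np.random.default_rng(0)
data = np.append(rng.normal(5.8, 0.8, 50), 9.5)   # 9.5 is a deliberate outlier

stats = boxplot_stats(data, whis=1.5)[0]
q1, med, q3 = np.percentile(data, [25, 50, 75])
iqr = q3 - q1

print(stats["q1"], stats["med"], stats["q3"])
# Whiskers stop at the most extreme points inside the 1.5 * IQR fences
print(stats["whislo"], stats["whishi"], stats["fliers"])
```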

We can also use violin plots. Instead of a box, each violin shows a kernel density estimate of the distribution, mirrored on both sides, and by default seaborn draws a miniature box plot inside it:

sns.violinplot(data=iris, y="sepal_length", x="species");
[Figure: violin plots of sepal_length for each iris species]

Faceted histograms

Sometimes the best way to view data is via histograms of subsets. Seaborn’s FacetGrid makes this extremely simple.

tips = sns.load_dataset('tips')
tips['tip_pct'] = 100 * tips['tip'] / tips['total_bill']
grid = sns.FacetGrid(tips, row="sex", col="time", margin_titles=True)
grid.map(plt.hist, "tip_pct", bins=np.linspace(0, 40, 15));
[Figure: 2 × 2 grid of tip_pct histograms, with rows split by sex and columns by time]
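FacetGrid automates a split-then-plot pattern: one panel per combination of the row and column variables. The plain Matplotlib sketch below reproduces that idea on a small synthetic DataFrame. The column names mirror the tips example, but the values are random and hypothetical.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({
    "sex": rng.choice(["Male", "Female"], n),
    "time": rng.choice(["Lunch", "Dinner"], n),
    "tip_pct": rng.normal(16, 5, n),       # synthetic tip percentages
})

rows = ["Male", "Female"]
cols = ["Lunch", "Dinner"]
bins = np.linspace(0, 40, 15)
fig, axes = plt.subplots(len(rows), len(cols), sharex=True, sharey=True)

# One histogram per (row value, column value) subset
for i, sex in enumerate(rows):
    for j, time in enumerate(cols):
        subset = df.loc[(df["sex"] == sex) & (df["time"] == time), "tip_pct"]
        axes[i, j].hist(subset, bins=bins)
        axes[i, j].set_title(f"sex = {sex} | time = {time}", fontsize=9)

panel_counts = {(s, t): int(((df["sex"] == s) & (df["time"] == t)).sum())
                for s in rows for t in cols}
print(panel_counts)
```

FacetGrid adds conveniences on top of this loop, such as margin titles and consistent styling, but the underlying grouping is the same.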

Exercises

Write a Python program that draws a line with suitable labels on the x-axis and y-axis and a title

# TODO: your answer here

Resources

[Van17] Jacob T. Vanderplas. Python Data Science Handbook: Essential Tools for Working with Data. O'Reilly, 2017.