
Plot 3D Functions With Matplotlib and NumPy

Math is beautiful, and the software we have at our fingertips represents an invaluable way to explore that beauty. Never before have we been able to visualize and experiment with mathematical objects in such an accessible way.

In this short tutorial I would like to break down the steps required to plot a function of two variables using Python.

Along the way, we’ll learn about a few NumPy procedures that’ll have you manipulating matrices like a wizard. I’ve broken the process down into three conceptual steps that’ll also help refresh the underlying math.

Finally, I put everything together into a function you can use out of the box, so feel free to skip to that if it’s all you need.

Define the Function Domain

A mathematical function is a map that associates each element of one set with exactly one element of another set. The set of inputs is called the domain of the function. In our case, the domain consists of ordered pairs of real numbers.

Although there are infinitely many real numbers inside any interval, we obviously can’t store an infinite set for our domain. For our plot to look nice, it’s sufficient to sample enough points within our domain so the end product will look smooth and not unnaturally jagged.

With this in mind, we can define our domain and store it in a set of arrays in three steps. The snippets below assume NumPy and Matplotlib's pyplot module have been imported as np and plt.

1. Decide on the boundaries for each of the two variables in our domain:

x_interval = (-2, 2)
y_interval = (-2, 2)

2. Sample points within each of these intervals:

x_points = np.linspace(x_interval[0], x_interval[1], 100)
y_points = np.linspace(y_interval[0], y_interval[1], 100)

3. Take the Cartesian product of these two sampled sets to produce two arrays that (when stacked) form a set of ordered pairs we can compute a function on:

X, Y = np.meshgrid(x_points, y_points)
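To see what these three steps produce, here is a minimal sketch on a deliberately tiny grid. The sample counts of 5 and 3 are illustrative assumptions, not the values used above. It shows that np.linspace includes both endpoints, and that the two arrays returned by np.meshgrid pair up element by element into the ordered pairs of the domain.

```python
import numpy as np

# Tiny illustrative grid: 5 samples in x, 3 in y (assumed values for the demo)
x_points = np.linspace(-2, 2, 5)   # endpoints included: -2, -1, 0, 1, 2
y_points = np.linspace(-2, 2, 3)   # -2, 0, 2

X, Y = np.meshgrid(x_points, y_points)

# Both arrays have one row per y sample and one column per x sample
print(X.shape, Y.shape)

# Stacking X and Y along a new last axis gives one (x, y) pair per grid point
pairs = np.stack([X, Y], axis=-1).reshape(-1, 2)
print(pairs[:5])
```

Note that the shape is (len(y_points), len(x_points)): rows follow y and columns follow x, which is the default 'xy' indexing of np.meshgrid.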

The next step is to associate an output value with every point in our input domain. For this purpose, we define our math function as a Python function of two scalar inputs:

def func3d(x, y):
    return -np.sin(10 * (x**2 + y**2)) / 10

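Then we produce a vectorized version of the function that can be called on vectors or matrices of inputs, and evaluate it on the grid to get the array of heights Z that we will plot:

func3d_vectorized = np.vectorize(func3d)
Z = func3d_vectorized(X, Y)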
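As a side note, np.vectorize matters most when the scalar function contains logic that does not work on whole arrays, such as an if statement. The sketch below uses a hypothetical function, ridge, which is not part of this tutorial, to show the difference. Since func3d is built only from NumPy operations, it can also be called on the grid arrays directly; NumPy's documentation describes np.vectorize as a convenience that is essentially a for loop, not a performance tool.

```python
import numpy as np

def func3d(x, y):
    return -np.sin(10 * (x**2 + y**2)) / 10

# Hypothetical example function with a branch; it works on scalars only
def ridge(x, y):
    if x > y:
        return x - y
    return 0.0

X, Y = np.meshgrid(np.linspace(-2, 2, 4), np.linspace(-2, 2, 4))

# Calling ridge on arrays fails because `x > y` is an array, not a single bool
try:
    ridge(X, Y)
    direct_call_works = True
except ValueError:
    direct_call_works = False

# np.vectorize applies ridge to each pair of elements in turn
Z_ridge = np.vectorize(ridge)(X, Y)

# func3d uses only NumPy operations, so both call styles agree
Z_direct = func3d(X, Y)
Z_vectorized = np.vectorize(func3d)(X, Y)
same_result = np.allclose(Z_direct, Z_vectorized)
```

For large grids, the direct call is usually the faster of the two, so it is worth trying first whenever the function has no branches.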

Plot the Function

From this point, things proceed in nearly the same way as they would in making a 2D plot with Matplotlib. Only a few argument and method names need to change in order to produce beautiful 3D visualizations.

1. Set up a plotting figure and axes with projection='3d':

plt.figure(figsize=(20, 10))
ax = plt.axes(projection='3d')

2. Select a plotting method of the axes object and call it on our function data:

ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                cmap='terrain', edgecolor=None)

3. Set other attributes of the plot, such as the title and axis labels:

ax.set(xlabel="x", ylabel="y", zlabel="f(x, y)",
       title="Cool Function")
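Other plotting methods of the 3D axes take the same X, Y, Z arrays. As one illustration, the sketch below swaps plot_surface for plot_wireframe, which draws only the grid lines. The coarser 30-point grid, the figure size, the title, and the output file name are assumptions made for the demo; it also selects the non-interactive Agg backend so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

def func3d(x, y):
    return -np.sin(10 * (x**2 + y**2)) / 10

# Coarser grid than in the tutorial so individual wires stay visible
points = np.linspace(-0.5, 0.5, 30)
X, Y = np.meshgrid(points, points)
Z = func3d(X, Y)

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(projection="3d")
ax.plot_wireframe(X, Y, Z, color="black", linewidth=0.5)
ax.set(xlabel="x", ylabel="y", zlabel="f(x, y)", title="Wireframe View")
fig.savefig("wireframe.png")  # hypothetical output file name
```

A wireframe can be easier to read than a filled surface when the function oscillates quickly, since the grid lines show the shape without a colormap.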

To make the process easier to reuse, I’ve packaged all these steps together into a Python function for producing quick surface plots.

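import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d  # registers the '3d' projection; only required on older Matplotlib versions

def plot_surface(domain, fn, grid_samples=100, title=None, **plot_kwargs):
    """Plot fn(x, y) as a surface over domain = [(x_min, x_max), (y_min, y_max)]."""
    # Sample each interval and build the grid of input points
    x = np.linspace(domain[0][0], domain[0][1], grid_samples)
    y = np.linspace(domain[1][0], domain[1][1], grid_samples)
    X, Y = np.meshgrid(x, y)

    # Evaluate the function at every grid point
    fn_vectorized = np.vectorize(fn)
    Z = fn_vectorized(X, Y)

    # Draw the surface; extra keyword arguments go straight to ax.plot_surface
    fig = plt.figure(figsize=(20, 10))
    ax = plt.axes(projection="3d")
    ax.plot_surface(X, Y, Z, **plot_kwargs)
    ax.set(xlabel="x", ylabel="y", zlabel="f(x, y)", title=title)

    return fig, ax

# now let's try it out!
def func(x, y):
    return -np.sin(10 * (x**2 + y**2)) / 10

domain = [(-0.5, 0.5), (-0.5, 0.5)]
fig, ax = plot_surface(domain, func, rstride=1, cstride=1, cmap='terrain', edgecolor=None)
plt.show()  # needed when running as a script rather than in a notebook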