Integration with matplotlib
The drawing capabilities of graph-tool (see the draw module) can be integrated with matplotlib, as demonstrated below.
Note
Integration with matplotlib works with every backend, but vector drawing only works with a cairo-based backend (e.g. cairo or GTK3Cairo). The backend can be changed by calling matplotlib.pyplot.switch_backend():
import matplotlib.pyplot as plt
plt.switch_backend("cairo")
When using a backend not based on cairo, rasterization will be used instead. In this case, the resolution can be controlled via the dpi parameter of matplotlib.figure.Figure.
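When rasterization is used, the size of the resulting image in pixels is roughly the figure size in inches multiplied by the dpi value. The sketch below shows only this arithmetic; the helper name raster_size is hypothetical and is not part of matplotlib or graph-tool.

```python
def raster_size(figsize, dpi):
    """Return the approximate (width, height) in pixels of a figure
    rasterized at `dpi`.

    `figsize` is a (width, height) pair in inches, as passed to
    matplotlib.pyplot.subplots(figsize=...).
    """
    width_in, height_in = figsize
    return round(width_in * dpi), round(height_in * dpi)

# A 12 x 11.5 inch figure at two resolutions.
print(raster_size((12, 11.5), 100))  # (1200, 1150)
print(raster_size((12, 11.5), 300))  # (3600, 3450)
```

In practice the value is given to matplotlib itself, for instance via the dpi argument of matplotlib.figure.Figure or of fig.savefig().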
Drawing with matplotlib is done by calling graph_draw() and passing a container (e.g. matplotlib.axes.Axes) as the mplfig parameter. When this option is passed, the function will return a GraphArtist() object that has been added to the figure.
Warning
Axis autoscaling will not work with GraphArtist(), so the axis limits need to be set explicitly with matplotlib.axes.Axes.set_xlim() and matplotlib.axes.Axes.set_ylim().
More conveniently, GraphArtist() offers a fit_view() method that does this automatically.
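Setting the limits by hand amounts to computing the bounding box of the vertex positions and padding it slightly. The plain-Python sketch below shows that computation; the helper name fit_limits and the 5% default margin are illustrative assumptions, not a description of how fit_view() is implemented.

```python
def fit_limits(positions, margin=0.05):
    """Return ((xmin, xmax), (ymin, ymax)) enclosing all `positions`.

    `positions` is an iterable of (x, y) pairs. Each range is padded by
    `margin` times its extent so that vertices do not touch the border.
    """
    xs, ys = zip(*positions)
    limits = []
    for values in (xs, ys):
        lo, hi = min(values), max(values)
        pad = (hi - lo) * margin
        if hi == lo:
            pad = 0.5  # degenerate range: give it a nonzero width
        limits.append((lo - pad, hi + pad))
    return tuple(limits)

xlim, ylim = fit_limits([(0, 0), (10, 5), (4, 2)])
print(xlim, ylim)  # (-0.5, 10.5) (-0.25, 5.25)
# With an Axes object `ax`, these would be applied as:
# ax.set_xlim(*xlim); ax.set_ylim(*ylim)
```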
Tip
When graph_draw() is called without matplotlib integration, the node positions are interpreted as cairo coordinates, which have their origin in the upper-left corner, with the y axis increasing from top to bottom.
For the visualization to look the same when matplotlib is used, the y axis needs to be flipped by inverting the limits with matplotlib.axes.Axes.set_ylim().
Alternatively, the yflip option can be passed to graph_tool.draw.GraphArtist.fit_view() to do this automatically.
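Inverting the y limits has the same visual effect as reflecting every y coordinate within the visible range. The plain-Python sketch below illustrates both views; the helper names flipped_ylim and reflect_y are hypothetical. A vertex near the top of a cairo drawing (small y) ends up near the top of a matplotlib plot with the usual orientation (large y).

```python
def flipped_ylim(ylim):
    """Return y limits in inverted order, as one would pass to Axes.set_ylim()."""
    bottom, top = ylim
    return top, bottom

def reflect_y(y, ylim):
    """Map a cairo-style y coordinate (increasing downwards) to an
    upward-increasing one that appears at the same place within `ylim`."""
    lo, hi = ylim
    return lo + hi - y

ylim = (0.0, 10.0)
print(flipped_ylim(ylim))    # (10.0, 0.0)
print(reflect_y(1.0, ylim))  # 9.0, i.e. near the top in both conventions
```

With an Axes object ax, the first option corresponds to ax.set_ylim(*flipped_ylim(ax.get_ylim())), which matplotlib also provides directly as ax.invert_yaxis().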
The example below shows how to plot several graphs in different subplots of the same figure.
import graph_tool.all as gt
import matplotlib.pyplot as plt
plt.switch_backend("cairo") # to enable vector drawing
fig, ax = plt.subplots(2, 2, figsize=(12, 11.5))
g = gt.collection.data["polbooks"]
a = gt.graph_draw(g, g.vp.pos, vertex_size=1.5, mplfig=ax[0,0])
a.fit_view(yflip=True)
ax[0,0].set_xlabel("$x$ coordinate")
ax[0,0].set_ylabel("$y$ coordinate")
state = gt.minimize_nested_blockmodel_dl(g)
a = state.draw(mplfig=ax[0,1])[0]
a.fit_view(yflip=True)
ax[0,1].set_xlabel("$x$ coordinate")
ax[0,1].set_ylabel("$y$ coordinate")
g = gt.collection.data["lesmis"]
a = gt.graph_draw(g, g.vp.pos, vertex_size=1.5, mplfig=ax[1,0])
a.fit_view(yflip=True)
ax[1,0].set_xlabel("$x$ coordinate")
ax[1,0].set_ylabel("$y$ coordinate")
state = gt.minimize_nested_blockmodel_dl(g)
a = state.draw(mplfig=ax[1,1])[0]
a.fit_view(yflip=True)
ax[1,1].set_xlabel("$x$ coordinate")
ax[1,1].set_ylabel("$y$ coordinate")
plt.subplots_adjust(left=0.08, right=0.99, top=0.99, bottom=0.06)
fig.savefig("gt-mpl.svg")
Integration with basemap
As a slightly more elaborate example, below we show how we can draw the European airline graph on a map using mpl_toolkits.basemap.
import numpy as np
import matplotlib.pyplot as plt
import graph_tool.all as gt
from itertools import chain
from mpl_toolkits.basemap import Basemap
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
g = gt.collection.ns["eu_airlines"]
pos = gt.group_vector_property([g.vp.nodeLong, g.vp.nodeLat])
m = Basemap(projection='ortho', resolution=None,
            lat_0=g.vp.nodeLat.fa.mean(), lon_0=g.vp.nodeLong.fa.mean())
m.shadedrelief(scale=.2)
lats = m.drawparallels(np.linspace(-90, 90, 13))
lons = m.drawmeridians(np.linspace(-180, 180, 13))
lat_lines = chain(*(tup[1][0] for tup in lats.items()))
lon_lines = chain(*(tup[1][0] for tup in lons.items()))
all_lines = chain(lat_lines, lon_lines)
for line in all_lines:
    line.set(linestyle='-', alpha=0.3, color='w')
a = gt.graph_draw(g, pos=pos.t(lambda x: m(*x)),  # project positions
                  edge_color=(.1,.1,.1,.1), mplfig=ax)
a.fit_view()
a.set_zorder(10)  # draw the graph above the map layers
fig.tight_layout()
fig.savefig("gt-map.svg")
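In the code above, pos.t(lambda x: m(*x)) replaces each vertex's (longitude, latitude) pair with the map coordinates returned by the Basemap instance m, which for projection='ortho' is an orthographic projection centred at (lon_0, lat_0). The sketch below writes out the standard spherical orthographic formulas for illustration only: it assumes a unit sphere and omits the radius in metres and the constant offset that Basemap applies to its map coordinates, so its numbers will not match m(lon, lat). The function name ortho_project is hypothetical.

```python
import math

def ortho_project(lon, lat, lon_0, lat_0, radius=1.0):
    """Orthographic projection of (lon, lat), in degrees, onto the plane
    tangent to a sphere of `radius` at (lon_0, lat_0).

    Returns (x, y, visible); `visible` is False for points on the far side
    of the globe, which an orthographic map cannot show.
    """
    lam, phi = math.radians(lon), math.radians(lat)
    lam0, phi0 = math.radians(lon_0), math.radians(lat_0)
    dlam = lam - lam0
    x = radius * math.cos(phi) * math.sin(dlam)
    y = radius * (math.cos(phi0) * math.sin(phi)
                  - math.sin(phi0) * math.cos(phi) * math.cos(dlam))
    # Cosine of the angular distance between the point and the map centre.
    cos_c = (math.sin(phi0) * math.sin(phi)
             + math.cos(phi0) * math.cos(phi) * math.cos(dlam))
    return x, y, cos_c >= 0

lon_0, lat_0 = 10.0, 50.0  # arbitrary map centre for the example
print(ortho_project(lon_0, lat_0, lon_0, lat_0))  # (0.0, 0.0, True)
# The antipode of the centre lies on the hidden side of the globe.
print(ortho_project(lon_0 + 180, -lat_0, lon_0, lat_0)[2])  # False
```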