Source code for pyrolite.util.plot.axes

"""
Functions for creating, ordering and modifying :class:`~matplotlib.axes.Axes`.
"""
import warnings

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

from ..log import Handle
from ..meta import subkwargs

# module-level logger obtained from the package's ``log.Handle`` helper
logger = Handle(__name__)


def get_ordered_axes(fig):
    """
    Get the axes from a figure, which may or may not have been modified by
    pyrolite functions. This ensures that ordering is preserved.

    Parameters
    -----------
    fig : :class:`matplotlib.figure.Figure`
        Figure to retrieve the axes from.

    Returns
    --------
    :class:`list`
        The figure's ``orderedaxes`` attribute if one has been set, otherwise
        ``fig.axes``.
    """
    if not hasattr(fig, "orderedaxes"):
        # figure untouched by pyrolite helpers: matplotlib's own axes list
        return fig.axes
    # previously modified: use the ordering recorded on the figure
    return fig.orderedaxes
def get_axes_index(ax):
    """
    Get the gridspec geometry and (1-based) index of a subplot in a regular
    grid, in the form accepted by :meth:`~matplotlib.figure.Figure.add_subplot`.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axis to get the gridspec index for.

    Returns
    -----------
    :class:`tuple`
        Rows, columns and axis index for the gridspec. The index is the
        1-based position of ``ax`` within the figure's (ordered) axes.
    """
    # public accessor rather than the private ``_nrows``/``_ncols`` attributes
    nrow, ncol = ax.get_gridspec().get_geometry()
    index = get_ordered_axes(ax.figure).index(ax)
    return nrow, ncol, index + 1
def replace_with_ternary_axis(ax):
    """
    Replace a specified axis with a ternary equivalent.

    The ternary axis is created at the same subplot position, the original
    axis is removed from the figure, and the figure's ``orderedaxes`` is
    updated so the ternary axis takes the original's place in the ordering.

    Parameters
    ------------
    ax : :class:`~matplotlib.axes.Axes`
        Axis to replace. If it is already a ternary axis it is returned
        unchanged.

    Returns
    ------------
    tax : :class:`~mpltern.ternary.TernaryAxes`
    """
    if ax.name == "ternary":
        # already ternary: nothing to convert
        return ax

    if not check_default_axes(ax):
        if not check_empty(ax):
            warnings.warn(
                "Non-empty, non-default bivariate axes being replaced with ternary axes."
            )
        else:
            logger.info("Non-default bivariate axes being replaced with ternary axes.")

    fig = ax.figure
    axes = get_ordered_axes(fig)
    idx = axes.index(ax)
    tax = fig.add_subplot(*get_axes_index(ax), projection="ternary")
    fig.add_axes(tax)  # make sure the axis is added to fig.children
    fig.delaxes(ax)  # remove the original axes
    # keep the ternary axis at the original axis' position in the ordering
    fig.orderedaxes = [tax if ix == idx else a for ix, a in enumerate(axes)]
    return tax
def label_axes(ax, labels=None, **kwargs):
    """
    Convenience function for labelling rectilinear and ternary axes.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to label.
    labels : :class:`list`
        List of labels: [x, y] | or [t, l, r]
    **kwargs
        Passed to each of the label-setting methods.

    Raises
    -------
    NotImplementedError
        If the number of labels doesn't match the axes: three labels are only
        accepted for ternary axes, otherwise two labels are required.
    """
    # avoid a shared mutable default argument
    labels = [] if labels is None else labels
    if (ax.name == "ternary") and (len(labels) == 3):
        tvar, lvar, rvar = labels
        ax.set_tlabel(tvar, **kwargs)
        ax.set_llabel(lvar, **kwargs)
        ax.set_rlabel(rvar, **kwargs)
    elif len(labels) == 2:
        xvar, yvar = labels
        ax.set_xlabel(xvar, **kwargs)
        ax.set_ylabel(yvar, **kwargs)
    else:
        raise NotImplementedError(
            "Expected 2 labels (x, y) or, for ternary axes, 3 labels (t, l, r); "
            "got {}.".format(len(labels))
        )
def axes_to_ternary(ax):
    """
    Set axes to ternary projection after axis creation.

    Each axis is replaced by a new ternary axis at the same subplot position,
    and the original axes objects are removed from the figure. The figure's
    axis ordering is preserved. Use the returned list, rather than the
    original references, to access the converted axes.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes` | :class:`list` (:class:`~matplotlib.axes.Axes`)
        Axis (or axes) to convert projection for. Lists, tuples and arrays of
        any shape (e.g. the 2D array from :func:`~matplotlib.pyplot.subplots`)
        are accepted.

    Returns
    ---------
    axes : :class:`list` (:class:`~matplotlib.axes.Axes`, :class:`~mpltern.ternary.TernaryAxes`)
        All axes of the figure, in order. Empty if an empty collection is
        passed.
    """
    if isinstance(ax, (list, np.ndarray, tuple)):  # multiple Axes specified
        # flatten so nested/2D collections of axes are handled too
        targets = list(np.asarray(ax, dtype=object).flat)
        if not targets:
            return []
        fig = targets[0].figure
    else:  # a single Axes is passed
        targets = [ax]
        fig = ax.figure
    for a in targets:  # axes to set to ternary
        replace_with_ternary_axis(a)
    # ``fig.orderedaxes`` only exists once an axis has actually been replaced
    return get_ordered_axes(fig)
def check_default_axes(ax):
    """
    Simple test to check whether an axis still has the default (0, 1, 0, 1)
    extent and contains no artists.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to check for artists and scaling.

    Returns
    -------
    :class:`bool`
    """
    default_extent = np.array([0, 1, 0, 1])
    if not np.allclose(ax.axis(), default_extent):
        # limits have been changed from the default
        return False
    return check_empty(ax)
def check_empty(ax):
    """
    Simple test to check whether an axis is empty of artists.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to check for artists.

    Returns
    -------
    :class:`bool`

    Notes
    -----
    Lines, collections, patches, generic artists, texts, images and tables
    are considered.
    """
    artist_groups = (
        ax.lines,
        ax.collections,
        ax.patches,
        ax.artists,
        ax.texts,
        ax.images,
        getattr(ax, "tables", ()),
    )
    # inspect each container directly rather than concatenating them
    return not any(len(group) for group in artist_groups)
def init_axes(ax=None, projection=None, minsize=1.0, **kwargs):
    """
    Get or create an Axes from an optionally-specified starting Axes.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes`
        Specified starting axes, optional.
    projection : :class:`str`
        Whether to create a projected (e.g. ternary) axes.
    minsize : :class:`float`
        Minimum figure dimension (inches).
    **kwargs
        Figure-creation keywords (e.g. ``figsize``), filtered via ``subkwargs``
        for :func:`matplotlib.pyplot.subplots` / :func:`matplotlib.pyplot.figure`.
        Only used when a new figure is created.

    Returns
    --------
    ax : :class:`~matplotlib.axes.Axes`
    """
    if "figsize" in kwargs:
        fs = kwargs["figsize"]
        # enforce a minimum figure size
        kwargs["figsize"] = (max(fs[0], minsize), max(fs[1], minsize))

    if projection is None:
        if ax is None:
            _, ax = plt.subplots(1, **subkwargs(kwargs, plt.subplots, plt.figure))
        return ax

    # projected axes requested (e.g. ternary)
    if ax is None:
        _, ax = plt.subplots(
            1,
            subplot_kw=dict(projection=projection),
            **subkwargs(kwargs, plt.subplots, plt.figure)
        )
        return ax

    if ax.name == "ternary":
        return ax

    # NOTE(review): a passed non-ternary ``ax`` is converted to a *ternary*
    # axis whatever ``projection`` was requested - confirm this is intended.
    current_axes = get_ordered_axes(ax.figure)
    try:
        ix = current_axes.index(ax)
    except ValueError:  # ax is not in list
        # if an axis is converted previously, but the original axes reference
        # is used again, it is no longer part of the figure.
        # ASSUMPTION due to mis-referencing: take the first ternary one
        # (raises IndexError if there is none).
        return [a for a in current_axes if a.name == "ternary"][0]
    # only the membership lookup is guarded, so a ValueError raised during
    # the conversion itself is not mistaken for a stale reference
    return axes_to_ternary(ax)[ix]
def share_axes(axes, which="xy"):
    """
    Link the x, y or both axes across a group of :class:`~matplotlib.axes.Axes`.

    Every axis after the first is linked to the first axis in the group.

    Parameters
    -----------
    axes : :class:`list`
        List of axes to link.
    which : :class:`str`
        Which axes to link. If :code:`x`, link the x-axes; if :code:`y` link
        the y-axes; :code:`xy` or :code:`both` links both. A string containing
        neither ``x`` nor ``y`` links nothing.
    """
    if which == "both":
        which = "xy"
    # plain loops: the share calls are made purely for their side effects
    if "x" in which:
        for follower in axes[1:]:
            follower.sharex(axes[0])
    if "y" in which:
        for follower in axes[1:]:
            follower.sharey(axes[0])
def get_twins(ax, which="y"):
    """
    Get twin axes of a specified axis.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to get twins for.
    which : :class:`str`
        Which twins to get (shared :code:`'x'`, shared :code:`'y'` or the
        concatenation of both, :code:`'xy'`; :code:`'both'` is accepted as an
        alias for :code:`'xy'`).

    Returns
    --------
    :class:`list`
        Axes sharing an axis with ``ax`` and occupying the same bounding box,
        excluding ``ax`` itself, in first-seen order without duplicates.

    Notes
    ------
    This function was designed to assist in avoiding creating a series of
    duplicate axes when replotting on an existing axis using a function which
    would typically create a twin axis.
    """
    if which == "both":  # consistent with share_axes
        which = "xy"
    siblings = []
    if "y" in which:
        siblings += ax.get_shared_y_axes().get_siblings(ax)
    if "x" in which:
        siblings += ax.get_shared_x_axes().get_siblings(ax)
    # twins overlay ax exactly; dict.fromkeys de-duplicates while keeping a
    # deterministic order (a set would not)
    twins = dict.fromkeys(
        a for a in siblings if a is not ax and a.bbox.bounds == ax.bbox.bounds
    )
    return list(twins)
def subaxes(ax, side="bottom", width=0.2, moveticks=True):
    """
    Append a sub-axes to one side of an axes.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to append a sub-axes to.
    side : :class:`str`
        Which side to append the axes on (:code:`'bottom'`, :code:`'top'`,
        :code:`'left'` or :code:`'right'`).
    width : :class:`float`
        Size of the subaxes, passed as ``size`` to ``append_axes`` (a float is
        an absolute size; a string such as ``"20%"`` is relative to ``ax``).
    moveticks : :class:`bool`
        Whether to move ticks to the outer axes.

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Subaxes instance.
    """
    div = make_axes_locatable(ax)
    ax.divider = div

    if side in ["bottom", "top"]:
        # horizontal strip sharing the x-axis with ax
        which = "x"
        subax = div.append_axes(side, width, pad=0, sharex=ax)
        subax.yaxis.set_visible(False)
        hidden_spines = ("left", "right")
        tick_kw = dict(bottom=False, top=False, labelbottom=False)
    else:
        # vertical strip: it must share the *y*-axis with ax, and it is the
        # independent x-axis that is hidden
        which = "y"
        subax = div.append_axes(side, width, pad=0, sharey=ax)
        subax.xaxis.set_visible(False)
        hidden_spines = ("top", "bottom")
        # Axes.tick_params discards top/bottom keywords for the y-axis, so the
        # left/right equivalents are needed for the ticks to actually move
        tick_kw = dict(left=False, right=False, labelleft=False)
    div.subax = subax
    for spine in hidden_spines:
        subax.spines[spine].set_visible(False)

    share_axes([ax, subax], which=which)
    if moveticks:
        ax.tick_params(axis=which, which="both", **tick_kw)
    return subax
def add_colorbar(mappable, **kwargs):
    """
    Adds a colorbar to a given mappable object.

    Source: http://joseph-long.com/writing/colorbars/

    Parameters
    ----------
    mappable
        The Image, ContourSet, etc. to which the colorbar applies.
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes to attach the colorbar to. Defaults to the mappable's ``axes``
        (or ``ax``) attribute.
    position : :class:`str`
        Side on which to append the colorbar (non-ternary axes only);
        default :code:`"right"`.
    size : :class:`str` | :class:`float`
        Colorbar size (non-ternary axes only); default :code:`"5%"`.
    pad : :class:`float`
        Padding between the axes and colorbar (non-ternary axes only);
        default :code:`0.05`.
    **kwargs
        Passed to :meth:`~matplotlib.figure.Figure.colorbar`.

    Returns
    -------
    :class:`matplotlib.colorbar.Colorbar`

    Raises
    ------
    ValueError
        If no axes is given and none can be found on the mappable.

    Todo
    ----
    * Where no mappable specified, get most recent axes, and check for collections etc
    """
    # popped rather than read: a colorbar axes (cax) is always supplied below,
    # so ``ax`` should not also be forwarded to Figure.colorbar
    ax = kwargs.pop("ax", None)
    if ax is None:
        if hasattr(mappable, "axes"):
            ax = mappable.axes
        elif hasattr(mappable, "ax"):
            ax = mappable.ax
    if ax is None:
        raise ValueError("No axes specified and none could be found on the mappable.")

    position = kwargs.pop("position", "right")
    size = kwargs.pop("size", "5%")
    pad = kwargs.pop("pad", 0.05)

    fig = ax.figure
    if ax.name == "ternary":
        # inset colorbar to the right of the ternary diagram, in axes coordinates
        cax = ax.inset_axes([1.05, 0.1, 0.05, 0.9], transform=ax.transAxes)
    else:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes(position, size=size, pad=pad)
    return fig.colorbar(mappable, cax=cax, **kwargs)