Source code for das.utils_plot

"""Plot utilities."""
import matplotlib.pyplot as plt
from matplotlib_scalebar.scalebar import ScaleBar
import matplotlib_scalebar
import matplotlib as mpl
import matplotlib.legend as mlegend
from matplotlib.patches import Rectangle
from matplotlib.backends.backend_pdf import PdfPages
import colorcet
import numpy as np
import itertools
from typing import Optional, List


def scalebar(length, dx=1, units='', label=None, axis=None, location='lower right', frameon=False, **kwargs):
    """Add a scalebar to an axis.

    Usage:
        plt.subplot(122)
        plt.plot([0, 1, 2, 3], [1, 4, 2, 5])
        scalebar(0.5, units='femtoseconds', label='duration', location='lower right')

    Args:
        length (float): Length of the scalebar in units of axis ticks - a length of 1.0
                        corresponds to the spacing between two major x-ticks.
                        Passed to ScaleBar as `fixed_value`.
        dx (int, optional): Scale factor for length. E.g. if the scale factor is 10,
                            a scalebar of length 1.0 will span 10 ticks. Defaults to 1.
        units (str, optional): Unit label (e.g. 'milliseconds'). Defaults to ''.
        label (str, optional): Title for the scale bar (e.g. 'Duration'). Defaults to None.
        axis (matplotlib.axes.Axes, optional): Axes to add the scalebar to.
                                               Defaults to None (currently active axis - plt.gca()).
        location (str, optional): Where in the axes to put the scalebar
                                  (upper/lower/'', left/right/center). Defaults to 'lower right'.
        frameon (bool, optional): Add background (True) or not (False). Defaults to False.
        kwargs: Further keyword arguments forwarded unchanged to ScaleBar, e.g.
                pad, border_pad, sep, color, box_color, box_alpha, scale_loc,
                label_loc, font_properties, label_formatter, animated.
                If no 'dimension' is given, one is created from `units`.

    Returns:
        Handle to scalebar object
    """
    if axis is None:
        axis = plt.gca()

    # Only build a dimension when the caller did not supply one.
    # NOTE(review): presumably this lets arbitrary unit strings through - confirm
    # against matplotlib_scalebar; `_Dimension` is a private class of that package.
    if 'dimension' not in kwargs:
        kwargs['dimension'] = matplotlib_scalebar.dimension._Dimension(units)

    bar = ScaleBar(
        dx=dx,
        units=units,
        label=label,
        fixed_value=length,
        location=location,
        frameon=frameon,
        **kwargs,
    )
    axis.add_artist(bar)
    return bar
def remove_axes(axis=None, all=False, which='tblr'):
    """Remove the top & right border around a plot, or all borders & ticks.

    Args:
        axis (matplotlib.axes.Axes, optional): Axes to modify.
                                               Defaults to None (currently active axis - plt.gca()).
        all (bool, optional): Remove all borders & ticks (True) or only the
                              top & right border (False). Defaults to False.
        which (str, optional): Not used by this function. Defaults to 'tblr'.
    """
    ax = plt.gca() if axis is None else axis

    # Hide the right and top spines
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    # Only show ticks on the left and bottom spines
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    if not all:
        return

    # Hide the left and bottom spines as well
    for side in ('left', 'bottom'):
        ax.spines[side].set_visible(False)

    # Remove all ticks (and with them the tick labels)
    ax.yaxis.set_ticks([])
    ax.xaxis.set_ticks([])
def despine(which='tr', axis=None):
    """Hide selected spines of an axis and adjust its ticks accordingly.

    Args:
        which (str, optional): Any combination of the letters 't' (top), 'b' (bottom),
                               'l' (left) and 'r' (right) naming the spines to hide.
                               Defaults to 'tr'.
        axis (matplotlib.axes.Axes, optional): Axes to modify.
                                               Defaults to None (currently active axis - plt.gca()).
    """
    if axis is None:
        axis = plt.gca()

    spine_for_letter = {'t': 'top', 'b': 'bottom', 'l': 'left', 'r': 'right'}

    # Hide the spines
    for letter in which:
        axis.spines[spine_for_letter[letter]].set_visible(False)

    # Removing right/top moves the ticks to the opposite side ...
    if 'r' in which:
        axis.yaxis.set_ticks_position('left')
    if 't' in which:
        axis.xaxis.set_ticks_position('bottom')
    # ... while removing left/bottom drops the ticks of that axis entirely.
    if 'l' in which:
        axis.yaxis.set_ticks([])
    if 'b' in which:
        axis.xaxis.set_ticks([])


# Module-level imports kept at this position: label_axes uses `string` and `cycle`.
import string
from itertools import cycle
def label_axes(fig=None, labels=None, loc=None, **kwargs):
    """
    Walks through the axes of a figure and annotates each with a label.

    kwargs are collected and passed to `annotate`. If no 'size' is given,
    a size of 13 is used.

    Parameters
    ----------
    fig : Figure
         Figure object to work on. Defaults to the current figure (plt.gcf()).

    labels : iterable or None
        iterable of strings to use to label the axes.
        If None, upper case letters are used. Labels are re-used
        cyclically if there are more axes than labels.

    loc : len=2 tuple of floats
        Where to put the label in axes-fraction units.
        Defaults to (-0.2, .9).
    """
    figure = plt.gcf() if fig is None else fig
    label_source = string.ascii_uppercase if labels is None else labels
    position = (-0.2, .9) if loc is None else loc

    kwargs.setdefault('size', 13)

    # cycle() re-uses labels rather than stopping once they run out
    for ax, text in zip(figure.axes, cycle(label_source)):
        ax.annotate(text, xy=position, xycoords='axes fraction', **kwargs)
def downsample_plot(x, y, ds=20):
    """Reduces complexity of exported pdfs w/o impairing visual appearance.

    Modified from pyqtgraph - downsampleMethod=peak.
    Keeps peaks so the envelope of the waveform is preserved.

    The data is cut into consecutive bins of `ds` samples. Each bin contributes
    two output points: (x of first sample in bin, max of y in bin) and
    (x of first sample in bin, min of y in bin). Trailing samples that do not
    fill a complete bin are dropped.

    Args:
        x (array-like): 1D x values. Lists/tuples are accepted and converted to arrays.
        y (array-like): 1D y values, at least as long as the part of `x` that is used.
        ds (int, optional): Number of samples per bin. Defaults to 20.

    Returns:
        Tuple of two 1D float arrays (x0, y0), each with 2 * (len(x) // ds) elements.
    """
    # Plain sequences support neither np.newaxis indexing nor .reshape
    x = np.asarray(x)
    y = np.asarray(y)

    n = len(x) // ds  # number of complete bins

    # Every bin start is emitted twice - once for the max and once for the min
    x_pairs = np.empty((n, 2))
    x_pairs[:] = x[:n * ds:ds, np.newaxis]

    bins = y[:n * ds].reshape((n, ds))
    y_pairs = np.empty((n, 2))
    y_pairs[:, 0] = bins.max(axis=1)
    y_pairs[:, 1] = bins.min(axis=1)

    # Row-major flattening interleaves the pairs: max0, min0, max1, min1, ...
    return x_pairs.reshape(n * 2), y_pairs.reshape(n * 2)
class Pdf:
    """Thin context-manager wrapper around PdfPages with autosave and temporary style.

    On entering, an optional matplotlib style is applied. On leaving, the
    figure is saved to the pdf (if `autosave` is set), the rcParams that were
    active before entering are restored and the pdf is closed.
    """

    def __init__(self, savename, autosave=True, style=None, **savefig_kws):
        """Open the pdf and remember the options.

        Args:
            savename: Passed to PdfPages to open the output pdf.
            autosave (bool, optional): Call `savefig` on the pdf when the
                                       `with` block is left. Defaults to True.
            style (optional): If not None, passed to `mpl.style.use` when the
                              `with` block is entered. Defaults to None.
            savefig_kws: Keyword arguments forwarded to the `savefig` call made on exit.
        """
        self.pdf = PdfPages(savename)
        self.autosave = autosave
        self.style = style
        self.orig = None  # copy of rcParams taken on enter, only if a style is applied
        self.savefig_kws = savefig_kws

    def __enter__(self):
        """Apply the style (if any) and return this object for use with `as`."""
        if self.style is not None:
            self.orig = mpl.rcParams.copy()
            mpl.style.use(self.style)
        self.pdf.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Save the figure (if autosave), restore rcParams and close the pdf.

        Restoring and closing happen even if saving raises, so a failed save
        neither leaks the open pdf nor leaves the temporary style active.
        Exceptions are not suppressed.
        """
        try:
            if self.autosave:
                self.pdf.savefig(**self.savefig_kws)
        finally:
            try:
                if self.orig is not None:
                    # Written via dict.update, i.e. not through RcParams' own item assignment
                    dict.update(mpl.rcParams, self.orig)
            finally:
                self.pdf.__exit__(exc_type, exc_val, exc_tb)
# from https://stackoverflow.com/a/60345118/2301098
def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
    """
    Place a table legend on the axes.

    Creates a legend where the labels are not directly placed with the artists,
    but are used as row and column headers, looking like this:

    ```
    title_label   | col_labels[1] | col_labels[2] | col_labels[3]
    -------------------------------------------------------------
    row_labels[1] |
    row_labels[2] |              <artists go there>
    row_labels[3] |
    ```

    If neither `col_labels` nor `row_labels` is given, a regular legend is created.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The artist that contains the legend table, i.e. current axes instant.
    col_labels : list of str, optional
        A list of labels to be used as column headers in the legend table.
        `len(col_labels)` needs to match `ncol`.
    row_labels : list of str, optional
        A list of labels to be used as row headers in the legend table.
        `len(row_labels)` needs to match `len(handles) // ncol`.
        The sequence is copied, the caller's object is never modified.
    title_label : str, optional
        Label for the top left corner in the legend table.
    ncol : int
        Number of columns. Required (as keyword) whenever `col_labels` or
        `row_labels` is given.

    Other Parameters
    ----------------
    Refer to `matplotlib.legend.Legend` for other parameters.

    Returns
    -------
    The created legend, which is also stored as `ax.legend_`.

    Raises
    ------
    TypeError
        If more than two non-keyword legend arguments are given.
    KeyError
        If a table legend is requested without `ncol`.
    AssertionError
        If the number of row/column labels does not match the table shape.
    """
    #################### same as `matplotlib.axes.Axes.legend` #####################
    # NOTE(review): `_parse_legend_args` is private matplotlib API; the 4-tuple
    # unpacked here may not match every matplotlib version - confirm against the
    # version pinned by the project.
    handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
    if len(extra_args):
        raise TypeError('legend only accepts two non-keyword arguments')

    if col_labels is None and row_labels is None:
        ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
        ax.legend_._remove_method = ax._remove_legend
        return ax.legend_

    #################### modifications for table legend ############################
    ncol = kwargs.pop('ncol')
    # Negative padding pulls the column header text over its (invisible) handle
    handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
    title_label = [title_label]

    # blank rectangle handle, used wherever a cell shows text only
    extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]

    # empty label, used wherever a cell shows an artist only
    empty = [""]

    # number of rows infered from number of handles and desired number of columns
    nrow = len(handles) // ncol

    # organise the list of handles and labels for table construction
    # (legend entries are listed column by column, starting with the header column)
    if col_labels is None:
        assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
        leg_handles = extra * nrow
        # Copy: the list is extended in place below and must not alias the caller's row_labels
        leg_labels = list(row_labels)
    elif row_labels is None:
        assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
        leg_handles = []
        leg_labels = []
    else:
        assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
        assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
        leg_handles = extra + extra * nrow
        # list(...) so that tuples (or other sequences) of row labels work as well
        leg_labels = title_label + list(row_labels)

    for col in range(ncol):
        if col_labels is not None:
            leg_handles += extra
            leg_labels += [col_labels[col]]
        leg_handles += handles[col*nrow:(col+1)*nrow]
        leg_labels += empty * nrow

    # Create legend - the row headers occupy one additional legend column
    ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol+int(row_labels is not None), handletextpad=handletextpad, **kwargs)
    ax.legend_._remove_method = ax._remove_legend
    return ax.legend_
def imshow_text(data, labels=None, ax=None, color_high='w', color_low='k', color_threshold=50, skip_zeros=False):
    """Text labels for individual cells of an imshow plot

    A label is drawn at x=column index, y=row index of each cell.

    Args:
        data (array-like): 2D color values (rows x columns). Decides the text color of each cell.
        labels (array-like, optional): 2D numeric values to print, same shape as `data`.
                                       Printed without decimals. Defaults to None (use `data`).
        ax (optional): axis. Defaults to plt.gca().
        color_high (str, optional): Text color for cells with `data` above `color_threshold`.
                                    Defaults to 'w'.
        color_low (str, optional): Text color for all other cells. Defaults to 'k'.
        color_threshold (int, optional): Threshold on `data` for switching the text color.
                                         Defaults to 50.
        skip_zeros (bool, optional): Do not draw labels that are equal to 0. Defaults to False.
    """
    if ax is None:
        ax = plt.gca()

    data = np.asarray(data)
    labels = data if labels is None else np.asarray(labels)

    n_rows, n_cols = data.shape[0], data.shape[1]
    # Rows map to y and columns to x - iterating the other way round breaks for non-square data
    for row, col in itertools.product(range(n_rows), range(n_cols)):
        value = labels[row, col]
        if skip_zeros and value == 0:
            continue
        text_color = color_high if data[row, col] > color_threshold else color_low
        ax.text(col, row, f'{value:1.0f}', ha='center', va='center', c=text_color)
def bar_text(ax=None, spacing=-20, to_int=True):
    """Add labels to the end of each bar in a bar chart.

    Each patch of the axis is labelled with its height, in white text.

    Arguments:
        ax (matplotlib.axes.Axes): The matplotlib object containing the axes
            of the plot to annotate. Defaults to current axis.
        spacing (int): The distance (in points) between the labels and the end of the bars.
            Defaults to -20 (inside bar).
        to_int (bool): Label with int(height) (True) or float(height) (False).
            Defaults to True.
    """
    if ax is None:
        ax = plt.gca()

    as_label = int if to_int else float

    for bar in ax.patches:
        # The label is anchored at the horizontal center of the bar's end
        height = bar.get_height()
        center = bar.get_x() + bar.get_width() / 2

        if height < 0:
            # Negative bar: mirror the offset and hang the label from its top edge
            offset, valign = spacing * -1, 'top'
        else:
            offset, valign = spacing, 'bottom'

        ax.annotate(
            as_label(height),            # text of the label
            (center, height),            # end of the bar
            xytext=(0, offset),          # vertical shift relative to the bar end
            textcoords="offset points",  # interpret `xytext` as offset in points
            ha='center',
            va=valign,
            color='w')
def generate_colors(nb_colors: int = 1, start_color=None, start=0, step=1):
    """Take colors from colorcet's 'glasbey_light' palette.

    Args:
        nb_colors (int, optional): Maximum number of palette entries to take. Defaults to 1.
        start_color (optional): If not None, this color is put in front of the
                                palette entries (the list is then one element longer).
                                Defaults to None.
        start (int, optional): Index of the first palette entry to use. Defaults to 0.
        step (int, optional): Stride between palette entries. Defaults to 1.

    Returns:
        list: The optional `start_color` followed by up to `nb_colors` palette entries.
    """
    palette = colorcet.palette['glasbey_light']
    colors = list(palette[start::step])[:nb_colors]
    if start_color is not None:
        colors = [start_color] + colors
    return colors
def annotate_events(event_seconds, event_names=None, tmin: float = 0, tmax: float = np.inf, color: Optional[List] = None):
    """Add events as vertical lines to the current plot.

    Only the first drawn event of each name gets a legend label.

    Args:
        event_seconds (List[float]): List of event times in seconds
        event_names (List[str], optional): List of event names, one per event.
                                 Defaults to None, in which case all events are
                                 treated as one unnamed type (drawn without label).
        tmin (float, optional): Start of the range of events annotated in seconds (exclusive).
                                Defaults to 0.
        tmax (float, optional): End of the range of events annotated in seconds (exclusive).
                                Defaults to np.inf.
        color (optional): Either a single valid matplotlib color (in that case all event types will have that color)
                          or a list/tuple of matplotlib colors, one for each event type, in order of
                          first appearance of the type in `event_names`.
                          Defaults to None, in which case a list of distinct colors will be created
                          automatically from colorcet's 'glasbey_light' palette.

    Raises:
        ValueError: If the number of colors does not match the number of unique event names.
    """
    event_seconds = list(event_seconds)
    if event_names is None:
        event_names = [None] * len(event_seconds)
    else:
        event_names = list(event_names)

    # dict.fromkeys keeps first-seen order, so the name->color assignment is
    # reproducible (set() order varies between interpreter runs for strings)
    unique_names = list(dict.fromkeys(event_names))

    if color is None:  # generate a color for each event name
        color = generate_colors(len(unique_names))
    elif not isinstance(color, (list, tuple)):  # use given color for each event name
        color = [color for _ in unique_names]

    if len(color) != len(unique_names):
        raise ValueError(f"Number of colors ({len(color)}) does not match number of event types ({len(unique_names)}).")

    color_of = dict(zip(unique_names, color))
    labelled = set()  # names that already have a legend entry
    for seconds, name in zip(event_seconds, event_names):
        if not (tmin < seconds < tmax):
            continue
        if name in labelled:
            label = None
        else:
            label = name
            labelled.add(name)
        plt.axvline(seconds, c=color_of[name], alpha=0.5, label=label)
def annotate_segments(onset_seconds, offset_seconds, segment_names=None, tmin: float = 0, tmax: float = np.inf, color=None):
    """Add segments as shaded vertical spans to the current plot.

    Only the first drawn segment of each name gets a legend label.

    Args:
        onset_seconds (List[float]): List of segment onset times in seconds.
        offset_seconds (List[float]): List of segment offset times in seconds.
        segment_names (List[str], optional): List of segment names, one per segment.
                                 Defaults to None, in which case all segments are
                                 treated as one unnamed type (drawn without label).
        tmin (float, optional): Only segments with an onset after `tmin` are drawn. Defaults to 0.
        tmax (float, optional): Only segments with an offset before `tmax` are drawn. Defaults to np.inf.
        color (optional): Either a single valid matplotlib color (in that case all segment types will have that color)
                          or a list/tuple of matplotlib colors, one for each segment type, in order of
                          first appearance of the type in `segment_names`.
                          Defaults to None, in which case a list of distinct colors will be created
                          automatically from colorcet's 'glasbey_light' palette.

    Raises:
        ValueError: If the number of colors does not match the number of unique segment names.
    """
    onset_seconds = list(onset_seconds)
    if segment_names is None:
        segment_names = [None] * len(onset_seconds)
    else:
        segment_names = list(segment_names)

    # dict.fromkeys keeps first-seen order, so the name->color assignment is
    # reproducible (set() order varies between interpreter runs for strings)
    unique_names = list(dict.fromkeys(segment_names))

    if color is None:  # generate a color for each segment name
        color = generate_colors(len(unique_names))
    elif not isinstance(color, (list, tuple)):  # use given color for each segment name
        color = [color for _ in unique_names]

    if len(color) != len(unique_names):
        raise ValueError(f"Number of colors ({len(color)}) does not match number of segment types ({len(unique_names)}).")

    color_of = dict(zip(unique_names, color))
    labelled = set()  # names that already have a legend entry
    for on, off, name in zip(onset_seconds, offset_seconds, segment_names):
        if not (on > tmin and off < tmax):
            continue
        if name in labelled:
            label = None
        else:
            label = name
            labelled.add(name)
        plt.axvspan(xmin=on, xmax=off, facecolor=color_of[name], alpha=0.25, label=label)