Source code for tdg.plot.ScatterMatrix

#!/usr/bin/env python3

import numpy as np
import torch
from itertools import product

import matplotlib.pyplot as plt

[docs]class ScatterMatrix: r''' Different observables are correlated as a function of Markov Chain time, because they are measured on the same state. We can visualize the correlations between many different observables in a grid, each panel a two-dimensional projection of the many-dimensional space of observable values. Parameters ---------- fields: int Number of rows and columns. labels: iterable of strings of length :code:`fields` Names for the different axes that will correspond to the plotted fields. wspace, hspace: float [inches] White space between panels. kwargs: Forwarded to ``matplotlib.pyplot.subplots`` .. note:: If you prefer more whitespace, consider a :class:`ScatterTriangle` over a :class:`~.ScatterMatrix`. It has the same interface. ''' def __init__(self, fields=2, labels=None, wspace=0.05, hspace=0.05, **kwargs ): width_ratios = np.ones(fields) height_ratios = np.ones(fields) self.fig, self.grid = plt.subplots( fields, fields, gridspec_kw={ 'width_ratios': width_ratios, 'height_ratios': height_ratios, 'wspace': wspace, 'hspace': hspace }, **kwargs ) grid = self.grid for (i, j) in product(range(fields), range(fields)): # Share axes if i == 0 and j == 0: grid[i,j].sharex(grid[1,0]) elif i == j: grid[i,j].sharex(grid[0,j]) elif i == 0: grid[i,j].sharey(grid[0,1]) else: grid[i,j].sharex(grid[0,j]) grid[i,j].sharey(grid[i,0]) # Only allow ticks on the left, bottom frames. if j != 0: [label.set_visible(False) for label in grid[i,j].get_yticklabels()] if i != fields-1: [label.set_visible(False) for label in grid[i,j].get_xticklabels()] # Labels if labels is not None and len(labels) == fields: if j == 0: grid[i,j].set_ylabel(labels[i]) if i == fields-1: grid[i,j].set_xlabel(labels[j])
[docs] def plot(self, data, label=None, density=True, scatter_alpha=0.1, histogram_alpha=0.5, bins=31, color=None, **kwargs): r''' Parameters ---------- data: iterable of length fields density: bool Should the histograms be normalized? scatter_alpha: float Transparency of plotted points. histogram_alpha: float Transparency of the histograms. bins: int Number of bins in each histogram. color: Forwarded `matplotlib color <https://matplotlib.org/stable/tutorials/colors/colors.html>`_. kwargs: Currently ignored. ''' d = tuple(d.clone().detach().cpu().numpy() if isinstance(d, torch.Tensor) else d for d in data) for ((i, y), (j, x)) in product(enumerate(d), enumerate(d)): if i != j: self.grid[i,j].scatter(x,y, color=color, alpha=scatter_alpha, edgecolors='none', ) else: self.grid[i,j].hist( x, label=label, orientation='vertical', bins=bins, density=density, color=color, alpha=histogram_alpha, )
class ScatterTriangle: r''' If you prefer more whitespace, consider a ScatterTriangle over a ScatterMatrix. Parameters ---------- fields: int Number of rows and columns. wspace, hspace: float [inches] White space between panels. kwargs: Forwarded to ``matplotlib.pyplot.subplots`` ''' def __init__(self, fields=2, labels=None, wspace=0.05, hspace=0.05, **kwargs ): width_ratios = np.ones(fields) height_ratios = np.ones(fields) self.fig, self.grid = plt.subplots( fields, fields, gridspec_kw={ 'width_ratios': width_ratios, 'height_ratios': height_ratios, 'wspace': wspace, 'hspace': hspace }, **kwargs ) grid = self.grid for (i, j) in product(range(fields), range(fields)): if j > i: grid[i,j].axis('off') continue # Share axes if i == 0 and j == 0: grid[i,j].sharex(grid[1,0]) elif i == j: grid[i,j].sharey(grid[0,j]) elif i == 0: grid[i,j].sharey(grid[0,1]) else: grid[i,j].sharex(grid[0,j]) grid[i,j].sharey(grid[i,0]) # Only allow ticks on the left, bottom frames. if j != 0: [label.set_visible(False) for label in grid[i,j].get_yticklabels()] if i < fields-1 or j == fields-1: [label.set_visible(False) for label in grid[i,j].get_xticklabels()] grid[0,0].set_yticks([]) # Labels if labels is not None and len(labels) == fields: if j == 0 and i > 0: grid[i,j].set_ylabel(labels[i]) if i == fields-1 and j < i: grid[i,j].set_xlabel(labels[j]) def plot(self, data, label=None, density=True, scatter_alpha=0.1, histogram_alpha=0.5, bins=31, color=None, **kwargs): r''' Parameters ---------- data: iterable of length fields density: bool Should the histograms be normalized? scatter_alpha: float Transparency of plotted points. histogram_alpha: float Transparency of the histograms. bins: int Number of bins in each histogram. color: Forwarded `matplotlib color <https://matplotlib.org/stable/tutorials/colors/colors.html>`_. kwargs: Currently ignored. ''' d = tuple(d.clone().detach().cpu().numpy() if isinstance(d, torch.Tensor) else d for d in data) for ((i, y), (j, x)) in product(enumerate(d), enumerate(d)): if j > i: continue if i != j: self.grid[i,j].scatter(x,y, color=color, alpha=scatter_alpha, edgecolors='none', ) else: self.grid[i,j].hist( x, label=label, orientation=('vertical' if i ==0 else 'horizontal'), bins=bins, density=density, color=color, alpha=histogram_alpha, )