Commit 985d587e authored by Simon Sebastian Humpohl's avatar Simon Sebastian Humpohl
Browse files

Merge branch 'simon' into 'master'

Add plotting util functions

See merge request !1
parents d299a286 dd639bb1
......@@ -11,5 +11,9 @@ python develop
This will link the files into your environment instead of copying them. If you are on windows you can use [SourceTree]( which is a nice GUI for git.
# qutil.plotting
`cycle_plots` helps you cycling through many plots with the arrow keys (there are probably much better functions for this out there)
`plot_2d_dataframe` helps you plot 2d data frames with numeric indices
## qutil.matlab
In this module there are functions that are helpful for reading `.mat` files, especially those created with special measure. If you simply want to open a random `.mat` file you can use `hdf5storage.loadmat`.
"""This module contains some useful plotting functions"""
from weakref import WeakValueDictionary
from typing import Tuple
import warnings
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
__all__ = ["plot_2d_dataframe"]
def _to_corner_coords(index: pd.Index) -> np.ndarray:
"""Helper function to transform N coordinates pointing at centers of bins to N+1 coords pointing to the edges"""
coords = index.values
delta = coords[-1] - coords[-2]
return np.concatenate((coords, [coords[-1] + delta])) - delta / 2
def _data_hash(*args: np.ndarray) -> int:
return hash(tuple(arg.tobytes() for arg in args))
def plot_2d_dataframe(df: pd.DataFrame,
ax: plt.Axes = None, square=True,
column=None, index_is_y=True,
update_mode: str = 'auto') -> plt.Axes:
"""Plot pandas data frames using pcolormesh. This function expects numeric labels and it can update an existing
plot. Have a look at seaborn.heatmap if you need something else.
'auto': 'rescale' if x-label, y-label, and title are the same else 'clear'
'clear': Clear axis before drawing
'overwrite': Just plot new data frame on top (no colorbar is drawn)
'rescale': Recalculate and redraw the colorbar
- plotted meshes are stored in ax.meshes
- The colorbar is stored in ax.custom_colorbar (DO NOT RELY ON THIS)
- If the plotted data is already present we just shift it to the top using set_zorder
- Uses _data_hash(x, y, c) to identify previously plotted data
:param df: pandas dataframe to plot
:param ax: Axes object
:param square:
:param column: Select this column from the dataframe and unstack the index
:param index_is_y: If true the index are on the y-axis and the columns on the x-axis
:param update_mode: 'auto', 'overwrite' or 'rescale'
if ax is None:
ax = plt.gca()
if square:
if column is None and len(df.columns) == 1 and len(df.index.levshape) == 2:
column = df.columns[0]
if column is not None:
title = column
series = df[column]
df = series.unstack()
title = None
c = df.values
x_idx = df.columns
y_idx = df.index
if not index_is_y:
c = np.transpose(c)
x_idx, y_idx = y_idx, x_idx
x_label =
y_label =
if update_mode == 'auto':
if (x_label, y_label, title) == (ax.get_xlabel(), ax.get_ylabel(), ax.get_title()):
update_mode = 'rescale'
update_mode = 'clear'
if update_mode not in ('clear', 'rescale', 'overwrite'):
raise ValueError('%s is an invalid value for update_mode' % update_mode)
if update_mode == 'clear':
if hasattr(ax, 'custom_colorbar'):
# clear colorbar axis
ax.meshes = WeakValueDictionary()
y = _to_corner_coords(y_idx)
x = _to_corner_coords(x_idx)
if not hasattr(ax, 'meshes'):
ax.meshes = WeakValueDictionary()
df_hash = _data_hash(x, y, c)
current_mesh = ax.meshes.get(df_hash, None)
if current_mesh is None:
# data not yet drawn -> draw it
current_mesh = ax.pcolormesh(x, y, c)
ax.meshes[df_hash] = current_mesh
# push to foreground
max_z = max(mesh.get_zorder() for mesh in ax.meshes.values()) if ax.meshes else 0
current_mesh.set_zorder(max_z + 1)
if update_mode != 'overwrite':
all_data = [mesh.get_array()
for mesh in ax.meshes.values()]
vmin = min(map(np.min, all_data))
vmax = max(map(np.max, all_data))
if not hasattr(ax, 'custom_colorbar'):
ax.custom_colorbar = plt.colorbar(ax=ax, mappable=current_mesh)
for mesh in ax.meshes.values():
mesh.set_clim(vmin, vmax)
ax.custom_colorbar.set_clim(vmin, vmax)
# TODO: fix
warnings.warn("for update_mode='overwrite' the colorbar code is stupid")
ax.set(ylabel=y_label, xlabel=x_label, title=title)
return ax
def cycle_plots(plot_callback, *args,
fig: plt.Figure = None, ax: plt.Axes = None, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
"""Call plot_callback(fig, ax, curr_pos, *args, **kwargs) on each left/right arrow key press.
Initially curr_pos = 0. The right arrow increases and the left arrow decreases the current position.
There is no limit so you need to do the wraparound yourself if needed:
>>> plot_data = [(x1, y1), (x2, y2), ...]
>>> def example_plot_callback(fig, ax, pos):
>>> idx = pos % len(plot_data)
>>> ax.plot(*plot_data[idx])
def key_event(e):
if e.key == "right":
key_event.curr_pos += 1
elif e.key == "left":
key_event.curr_pos -= 1
plot_callback(fig, ax, key_event.curr_pos, *args, **kwargs)
key_event.curr_pos = 0
if fig is None:
if ax is None:
fig = plt.figure()
fig = ax.get_figure()
if ax is None:
ax = fig.add_subplot(111)
assert ax in fig.axes, "axes not in figure"
fig.canvas.mpl_connect('key_press_event', key_event)
plot_callback(fig, ax, key_event.curr_pos, *args, **kwargs)
return fig, ax
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment