diff --git a/README.md b/README.md index 867e8e52f71d731037907b2cca5c2d56019f06c6..483f0c609b1d5f246245312059f93ad4d507fc77 100644 --- a/README.md +++ b/README.md @@ -11,5 +11,9 @@ python setup.py develop ``` This will link the files into your environment instead of copying them. If you are on windows you can use [SourceTree](https://www.sourcetreeapp.com/) which is a nice GUI for git. +# qutil.plotting +`cycle_plots` helps you cycling through many plots with the arrow keys (there are probably much better functions for this out there) +`plot_2d_dataframe` helps you plot 2d data frames with numeric indices + ## qutil.matlab In this module there are functions that are helpful for reading `.mat` files, especially those created with special measure. If you simply want to open a random `.mat` file you can use `hdf5storage.loadmat`. diff --git a/qutil/plotting.py b/qutil/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..2b495184bb22a3799ebf5d0dcf7ebd520aed4d1e --- /dev/null +++ b/qutil/plotting.py @@ -0,0 +1,174 @@ +"""This module contains some useful plotting functions""" +from weakref import WeakValueDictionary +from typing import Tuple +import warnings + +from matplotlib import pyplot as plt +import pandas as pd +import numpy as np + + +__all__ = ["plot_2d_dataframe"] + + +def _to_corner_coords(index: pd.Index) -> np.ndarray: + """Helper function to transform N coordinates pointing at centers of bins to N+1 coords pointing to the edges""" + coords = index.values + + delta = coords[-1] - coords[-2] + + return np.concatenate((coords, [coords[-1] + delta])) - delta / 2 + + +def _data_hash(*args: np.ndarray) -> int: + return hash(tuple(arg.tobytes() for arg in args)) + + +def plot_2d_dataframe(df: pd.DataFrame, + ax: plt.Axes = None, square=True, + column=None, index_is_y=True, + update_mode: str = 'auto') -> plt.Axes: + """Plot pandas data frames using pcolormesh. This function expects numeric labels and it can update an existing + plot. Have a look at seaborn.heatmap if you need something else. + + 'auto': 'rescale' if x-label, y-label, and title are the same else 'clear' + 'clear': Clear axis before drawing + 'overwrite': Just plot new data frame on top (no colorbar is drawn) + 'rescale': Recalculate and redraw the colorbar + + Details: + - plotted meshes are stored in ax.meshes + - The colorbar is stored in ax.custom_colorbar (DO NOT RELY ON THIS) + - If the plotted data is already present we just shift it to the top using set_zorder + - Uses _data_hash(x, y, c) to identify previously plotted data + + :param df: pandas dataframe to plot + :param ax: Axes object + :param square: + :param column: Select this column from the dataframe and unstack the index + :param index_is_y: If true the index are on the y-axis and the columns on the x-axis + :param update_mode: 'auto', 'overwrite' or 'rescale' + :return: + """ + if ax is None: + ax = plt.gca() + + if square: + ax.set_aspect("equal") + + if column is None and len(df.columns) == 1 and len(df.index.levshape) == 2: + column = df.columns[0] + + if column is not None: + title = column + series = df[column] + df = series.unstack() + + else: + title = None + + c = df.values + x_idx = df.columns + y_idx = df.index + + if not index_is_y: + c = np.transpose(c) + x_idx, y_idx = y_idx, x_idx + + x_label = x_idx.name + y_label = y_idx.name + + if update_mode == 'auto': + if (x_label, y_label, title) == (ax.get_xlabel(), ax.get_ylabel(), ax.get_title()): + update_mode = 'rescale' + else: + update_mode = 'clear' + if update_mode not in ('clear', 'rescale', 'overwrite'): + raise ValueError('%s is an invalid value for update_mode' % update_mode) + + if update_mode == 'clear': + if hasattr(ax, 'custom_colorbar'): + # clear colorbar axis + ax.custom_colorbar.ax.clear() + ax.clear() + ax.meshes = WeakValueDictionary() + + y = _to_corner_coords(y_idx) + x = _to_corner_coords(x_idx) + + if not hasattr(ax, 'meshes'): + ax.meshes = WeakValueDictionary() + + df_hash = _data_hash(x, y, c) + current_mesh = ax.meshes.get(df_hash, None) + + if current_mesh is None: + # data not yet drawn -> draw it + current_mesh = ax.pcolormesh(x, y, c) + ax.meshes[df_hash] = current_mesh + + # push to foreground + max_z = max(mesh.get_zorder() for mesh in ax.meshes.values()) if ax.meshes else 0 + current_mesh.set_zorder(max_z + 1) + + if update_mode != 'overwrite': + all_data = [mesh.get_array() + for mesh in ax.meshes.values()] + vmin = min(map(np.min, all_data)) + vmax = max(map(np.max, all_data)) + + if not hasattr(ax, 'custom_colorbar'): + ax.custom_colorbar = plt.colorbar(ax=ax, mappable=current_mesh) + + for mesh in ax.meshes.values(): + mesh.set_clim(vmin, vmax) + + ax.custom_colorbar.set_clim(vmin, vmax) + else: + # TODO: fix + warnings.warn("for update_mode='overwrite' the colorbar code is stupid") + + ax.set(ylabel=y_label, xlabel=x_label, title=title) + + return ax + + +def cycle_plots(plot_callback, *args, + fig: plt.Figure = None, ax: plt.Axes = None, **kwargs) -> Tuple[plt.Figure, plt.Axes]: + """Call plot_callback(fig, ax, curr_pos, *args, **kwargs) on each left/right arrow key press. + Initially curr_pos = 0. The right arrow increases and the left arrow decreases the current position. + There is no limit so you need to do the wraparound yourself if needed: + + >>> plot_data = [(x1, y1), (x2, y2), ...] + >>> def example_plot_callback(fig, ax, pos): + >>> idx = pos % len(plot_data) + >>> ax.plot(*plot_data[idx]) + """ + def key_event(e): + if e.key == "right": + key_event.curr_pos += 1 + elif e.key == "left": + key_event.curr_pos -= 1 + else: + return + plot_callback(fig, ax, key_event.curr_pos, *args, **kwargs) + plt.draw_all() + + key_event.curr_pos = 0 + + if fig is None: + if ax is None: + fig = plt.figure() + else: + fig = ax.get_figure() + if ax is None: + ax = fig.add_subplot(111) + + assert ax in fig.axes, "axes not in figure" + + fig.canvas.mpl_connect('key_press_event', key_event) + + plot_callback(fig, ax, key_event.curr_pos, *args, **kwargs) + plt.draw_all() + + return fig, ax