Skip to content
Snippets Groups Projects
Select Git revision
  • master
  • v1.1.1
  • v1.1.0 protected
  • v1.0.0 protected
4 results

README.md

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    To learn more about this project, read the wiki.
    default_plots.py 21.66 KiB
    #!/usr/bin/env python3
    # pylint: disable=too-many-locals,too-many-arguments
    """
    This module contains a few functions, which can be used to
    generate plots quickly and with useful defaults"""
    from __future__ import annotations
    from pathlib import Path
    from functools import wraps
    from typing import TypeVar, List, Tuple, Union, Callable, Optional
    from warnings import warn, filterwarnings, catch_warnings
    from textwrap import dedent
    
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator
    import numpy as np
    from numpy import amin, amax
    
    from .plot_settings import apply_styles, rwth_cycle
    from .types_ import Vector, Matrix
    
    mpl.use("Agg")
    
    In = TypeVar("In", List[float], Tuple[float],
                 Vector)
    
    # pylint: disable=invalid-name
    In2D = TypeVar("In2D", List[List[float]], List[Vector], Tuple[Vector],
                   Matrix)
    
    
    def fix_inputs(input_1: In, input_2: In)\
            -> tuple[Vector, Vector]:
        """
        Remove nans and infinities from the input vectors.
    
        Parameters
        ---------
        input_1, input_2:
            X/Y-axis data of plot
    
        Returns
        ------
        New vectors x and y with nans removed.
        """
        if len(input_1) != len(input_2):
            raise ValueError(
                "The sizes of the input vectors are not the same.")
        nan_count = np.count_nonzero(np.isnan(input_2))
        inf_count = np.count_nonzero(np.isinf(input_2))
        if nan_count != 0 or inf_count != 0:
            new_input_1 = np.empty(len(input_1)-nan_count-inf_count)
            new_input_2 = np.empty(len(new_input_1))
            position = 0
            for x_input, y_input in zip(input_1, input_2):
                if not np.isnan(y_input) or np.isinf(y_input):
                    new_input_1[position] = x_input
                    new_input_2[position] = y_input
                    position += 1
            return new_input_1, new_input_2
    
        return np.array(input_1), np.array(input_2)
    
    
    def check_inputs(input_1: In, input_2: In, label_1: str, label_2: str)\
            -> bool:
        """
        Check the input vectors to see, if they are large enough.
    
        Parameters
        ---------
        input_1, input_2:
            X/Y-axis data of plot
    
        label_1, label_2:
            Labels of the X/Y axis
    
        Returns
        ------
        True, if the plot can be created.
        """
        if len(input_1) <= 1 or len(input_2) <= 1:
            warn(
                "There are not enough points in the following plots:"
                f"label1: {label_1} label2: {label_2}. It cannot be drawn.")
            return False
    
        if min(input_1) == max(input_1):
            warn(
                "The area of the x-axis is not large enough in the following plot:"
                f"label1: {label_1} label2: {label_2}. It cannot be drawn.")
            return False
    
        if min(input_2) == max(input_2):
            warn(
                "The area of the y-axis is not large enough in the following plot:"
                f"label1: {label_1} label2: {label_2}. It cannot be drawn.")
            return False
    
        infinity = np.isinf(input_1).any() or np.isinf(input_2).any()
        if infinity:
            warn(dedent(f"""There are infinities in the data of the following plot:
                 label1: {label_1}, label2: {label_2}. It cannot be drawn."""),
                 RuntimeWarning)
            return False
    
        nan = np.isnan(input_1).any() or np.isnan(input_2).any()
        if nan:
            warn(dedent(f"""There are nans in the data of the following plot:
                 label1: {label_1}, label2: {label_2}. It cannot be drawn."""),
                 RuntimeWarning)
            return False
    
        return True
    
    
    def get_ylims(
            X: In, Y: In,
            xlim: Optional[tuple[float, float]] = None,
            logscale: bool = False) -> tuple[float, float]:
        """
        Calculate the optimised limits in y-direction.
        """
        if xlim is not None:
            X_plot = np.array(X)
            Y_plot = np.array(Y)[
                np.logical_and(X_plot >= xlim[0], X_plot <= xlim[1])]
        else:
            Y_plot = Y
    
        if logscale:
            ylim = (
                min(Y_plot) * 0.97, max(Y_plot) * 1.02)
        else:
            ylim = (
                min(Y_plot) - (max(Y_plot) - min(Y_plot)) * 0.02,
                max(Y_plot) + (max(Y_plot) - min(Y_plot)) * 0.02)
        return ylim
    
    
    # pylint: disable=too-many-positional-arguments
    def set_lims(
            X: In, Y: In,
            logscale: bool = False,
            single_log: bool = False,
            single_log_y: bool = False,
            xlim: Optional[tuple[float, float]] = None,
            ylim: Optional[tuple[float, float]] = None) -> None:
        """
        Set the limits of the current figure to the preferred values, and also
        consider the given manual limits.
        """
        if xlim is None:
            xlim = (min(X), max(X))
    
        if ylim is None:
            ylim = get_ylims(
                X, Y, xlim=xlim, logscale=(logscale or single_log_y))
    
        if logscale:
            plt.xscale("log")
            plt.yscale("log")
        elif single_log:
            plt.xscale("log")
        elif single_log_y:
            plt.yscale("log")
    
        plt.xlim(*xlim)
        plt.ylim(*ylim)
    
    
    # pylint: disable=too-many-positional-arguments
    @apply_styles
    def plot_fit(X: In, Y: In,
                 fit_function: Callable[..., float],
                 xlabel: str, ylabel: str, filename: Union[str, Path], *,
                 args: Optional[Tuple[float]] = None,
                 logscale: bool = False,
                 single_log: bool = False,
                 single_log_y: bool = False,
                 xlim: Optional[Tuple[float, float]] = None,
                 ylim: Optional[Tuple[float, float]] = None) -> None:
        """Creates a plot of data and a fit and saves it to 'filename'."""
        X, Y = fix_inputs(X, Y)  # type: ignore
        if not check_inputs(
                X, Y, xlabel, ylabel):
            return
    
        n_fit = 1000
    
        @wraps(fit_function)
        def _fit_function(x: float) -> float:
            """This is the function, which has been fitted"""
            if args is not None:
                return fit_function(x, *args)
            return fit_function(x)
    
        plt.plot(X, Y, label="data")
        X_fit = [
            min(X) + (max(X) - min(X)) * i / (n_fit - 1) for i in range(n_fit)]
    
        Y_fit = [_fit_function(x) for x in X_fit]
        plt.plot(X_fit, Y_fit, label="fit")  # type: ignore
    
        set_lims(
            X, Y, logscale=logscale, single_log=single_log,
            single_log_y=single_log_y,
            xlim=xlim, ylim=ylim)
        if logscale:
            plt.xscale("log")
            plt.yscale("log")
        plt.xlim(min(X), max(X))
        if logscale:
            plt.ylim(min(Y) * 0.97, max(Y) * 1.02)
        else:
            plt.ylim(
                min(Y) - (max(Y) - min(Y)) * 0.02,
                max(Y) + (max(Y) - min(Y)) * 0.02
            )
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()
    
    
    # pylint: disable=too-many-positional-arguments
    @apply_styles(three_d=True)
    def plot_surface(X: In2D, Y: In2D, Z: In2D,
                     xlabel: str, ylabel: str, zlabel: str,
                     filename: Union[str, Path], *,
                     log_scale: bool = False,
                     set_z_lim: bool = True,
                     colorscheme: str = "rwth_gradient_simple",
                     figsize: Tuple[float, float] = (4.33, 3.5),
                     labelpad: Optional[float] = None,
                     nbins: Optional[int] = None,
                     xlim: Optional[Tuple[float, float]] = None,
                     ylim: Optional[Tuple[float, float]] = None,
                     zlim: Optional[Tuple[float, float]] = None,
                     linewidth: float = 0,
                     antialiased: bool = True,
                     rcount: int = 200, ccount: int = 200) -> None:
        """create a 2D surface plot of meshgrid-like valued Xs, Ys and Zs"""
        # pylint: disable=too-many-branches
        if not check_inputs(
                np.array(X).flatten(),
                np.array(Z).flatten(), xlabel, zlabel):
            return
        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")
        fig.subplots_adjust(left=-0.02, right=0.75, bottom=0.15, top=0.98)
        ax.plot_surface(X, Y, Z, cmap=colorscheme,
                        rcount=rcount, ccount=ccount,
                        linewidth=linewidth, antialiased=antialiased)
        ax.set_box_aspect(aspect=None, zoom=.8)
    
        if labelpad is None:
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_zlabel(zlabel, rotation=90)
        else:
            ax.set_xlabel(xlabel, labelpad=labelpad)
            ax.set_ylabel(ylabel, labelpad=labelpad)
            ax.set_zlabel(zlabel, rotation=90, labelpad=labelpad)
    
        assert ax.zaxis is not None
    
        if xlim is None:
            ax.set_xlim(amin(X), amax(X))  # type: ignore
        else:
            ax.set_xlim(*xlim)
        if ylim is None:
            ax.set_ylim(amin(Y), amax(Y))  # type: ignore
        else:
            ax.set_ylim(*ylim)
    
        if set_z_lim:
            if not log_scale:
                ax.set_zlim(
                    amin(Z) - (amax(Z) - amin(Z)) * 0.02,  # type: ignore
                    amax(Z) + (amax(Z) - amin(Z)) * 0.02  # type: ignore
                )
            else:
                ax.set_zlim(
                    amin(Z) * 0.97, amax(Z) * 1.02)  # type: ignore
        if zlim is not None:
            ax.set_zlim(*zlim)
    
        if log_scale:
            ax.set_yscale("log")
            ax.set_xscale("log")
            ax.set_zscale("log")
    
        for spine in ax.spines.values():
            spine.set_visible(False)
    
        ax.xaxis.pane.set_alpha(0.3)
        ax.yaxis.pane.set_alpha(0.3)
        ax.zaxis.pane.set_alpha(0.3)
    
        if nbins is not None:
            ax.xaxis.set_major_locator(
                MaxNLocator(nbins)
            )
            ax.yaxis.set_major_locator(
                MaxNLocator(nbins)
            )
    
        fig.set_size_inches(*figsize)
    
        with catch_warnings():
            filterwarnings("ignore", message=".*Tight layout")
            plt.tight_layout()
            plt.savefig(filename)
    
        plt.close()
    
    
    @apply_styles
    def plot(X: In, Y: In, xlabel: str, ylabel: str,
             filename: Union[Path, str], *, logscale: bool = False,
             ylim: Optional[tuple[float, float]] = None,
             yticks: bool = True, cycler: int = 0,
             xlim: Optional[tuple[float, float]] = None,
             single_log: bool = False,
             single_log_y: bool = False) -> None:
        """Create a simple 1D plot"""
        X, Y = fix_inputs(X, Y)  # type: ignore
        if not check_inputs(
                X, Y, xlabel, ylabel):
            return
        if len(X) <= 1 or len(Y) <= 1:
            raise ValueError(
                f"The data for plot {filename} contains empty rows!")
    
        if cycler > 0:
            for _ in range(cycler):
                plt.plot([], [])
    
        set_lims(
            X, Y, logscale=logscale, single_log=single_log,
            single_log_y=single_log_y,
            xlim=xlim, ylim=ylim)
    
        plt.plot(X, Y, linestyle="-")
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
    
        if not yticks:
            plt.yticks([])
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()
    
    
    @apply_styles
    def two_plots(x1: In, y1: In, label1: str,
                  x2: In, y2: In, label2: str,
                  xlabel: str, ylabel: str,
                  filename: Union[Path, str], *,
                  logscale: bool = False, cycle: int = 0,
                  color: tuple[int, int] = (0, 1),
                  outer: bool = False,
                  single_log: bool = False,
                  single_log_y: bool = False,
                  xlim: Optional[Tuple[float, float]] = None,
                  ylim: Optional[Tuple[float, float]] = None,
                  inner_y: bool = False,
                  half_y: bool = False) -> None:
        """Create a simple 1D plot with two different graphs inside of a single
        plot and a single y-axis.
    
        Keyword arguments:
        cycle -- skip this many colours in the colour-wheel before plotting
        color -- use these indeces in the colour-wheel when creating a plot
        outer -- use the outer limits on the x-axis rather than the inner limit
        inner_y -- Use the tighter definition for the limits on the y-axis
        half_y -- Use the upper upper limit but the upper lower limit
        """
        x1, y1 = fix_inputs(x1, y1)  # type: ignore
        x2, y2 = fix_inputs(x2, y2)  # type: ignore
    
        if not (
                check_inputs(x1, y1, xlabel, label1)
                or check_inputs(x2, y2, xlabel, label2)):
            return
        if len(x1) <= 1 or len(y1) <= 1 or len(y2) <= 1 or len(x2) <= 1:
            raise ValueError(
                f"The data for plot {filename} contains empty rows!")
    
        if cycle > 0:
            color = (color[0] + cycle, color[1] + cycle)
    
        # access colour
        prop_cycle = plt.rcParams["axes.prop_cycle"]
        try:
            linestyle = prop_cycle.by_key()["linestyle"]
        except KeyError:
            linestyle = rwth_cycle.by_key()["linestyle"]
        colors = prop_cycle.by_key()["color"]
    
        if max(color) >= len(colors):
            colors += colors
            linestyle += linestyle
    
        plt.plot(x1, y1, label=label1,
                 color=colors[color[0]],
                 linestyle=linestyle[0])
    
        plt.plot(x2, y2, label=label2,
                 color=colors[color[1]],
                 linestyle=linestyle[1])
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
    
        if xlim is not None:
            pass
        elif outer:
            xlim = (min(*x1, *x2),
                    max(*x1, *x2))
        else:
            xlim = (max(min(x1), min(x2)),
                    min(max(x1), max(x2)))
    
        plt.xlim(*xlim)
    
        if ylim is None:
            ylim1 = get_ylims(
                x1, y1, xlim=xlim,
                logscale=(logscale or single_log_y))
    
            ylim2 = get_ylims(
                x2, y2, xlim=xlim,
                logscale=(logscale or single_log_y))
    
            if inner_y:
                plt.ylim(
                    max(ylim1[0], ylim2[0]),
                    min(ylim1[1], ylim2[1]))
            elif half_y:
                plt.ylim(
                    max(ylim1[0], ylim2[0]),
                    max(ylim1[1], ylim2[1]))
            else:
                plt.ylim(
                    min(ylim1[0], ylim2[0]),
                    max(ylim1[1], ylim2[1]))
        else:
            plt.ylim(*ylim)
    
        if logscale or single_log:
            plt.xscale("log")
    
        if logscale or single_log_y:
            plt.yscale("log")
    
        plt.legend()
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()
    
    
    @apply_styles
    def three_plots(x1: In, y1: In, label1: str,
                    x2: In, y2: In, label2: str,
                    x3: In, y3: In, label3: str,
                    xlabel: str, ylabel: str,
                    filename: Union[Path, str], *,
                    logscale: bool = False,
                    xmin: Optional[float] = None,
                    xmax: Optional[float] = None) -> None:
        """Create a simple 1D plot with three different graphs inside of a single
        plot and a single y-axis."""
        x1, y1 = fix_inputs(x1, y1)  # type: ignore
        x2, y2 = fix_inputs(x2, y2)  # type: ignore
        x3, y3 = fix_inputs(x3, y3)  # type: ignore
        if not (
                check_inputs(x1, y1, xlabel, label1)
                or check_inputs(x2, y3, xlabel, label1)
                or check_inputs(x3, y3, xlabel, label3)):
            return
    
        if any(len(x) <= 1 for x in (x1, x2, y1, y2, x3, y3)):
            raise ValueError(
                f"The data for plot {filename} contains empty rows!")
    
        plt.plot(x1, y1, label=label1)
        plt.plot(x2, y2, label=label2, linestyle="dashed")
        plt.plot(x3, y3, label=label3, linestyle="dotted")
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        min_ = min(*y1, *y2, *y3)
        max_ = max(*y1, *y2, *y3)
        if not logscale:
            plt.ylim(
                min_ - (max_ - min_) * 0.02,
                max_ + (max_ - min_) * 0.02
            )
        else:
            plt.xscale("log")
            plt.yscale("log")
            plt.ylim(
                min_ * 0.97, max_ * 1.02)
        if xmin is not None and xmax is not None:
            plt.xlim(xmin, xmax)
        else:
            plt.xlim(min(x1), max(x1))
        plt.legend()
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()
    
    
    @apply_styles
    def two_axis_plots(x1: In, y1: In, label1: str,
                       x2: In, y2: In, label2: str,
                       xlabel: str, ylabel: str,
                       ylabel2: str,
                       filename: Union[Path, str], *,
                       ticks: Optional[Tuple[List[float], List[str]]] = None,
                       xlim: Optional[Tuple[float, float]] = None,
                       color: Tuple[int, int] = (0, 1))\
            -> None:
        """Create a simple 1D plot with two different graphs inside of a single
        plot with two y-axis.
        The variable "ticks" sets costum y-ticks on the second y-axis. The first
        argument gives the position of the ticks and the second argument gives the
        values to be shown.
        Color selects the indeces of the chosen color-wheel, which should be taken
        for the different plots. The default is (1,2)."""
        x1, y1 = fix_inputs(x1, y1)  # type: ignore
        x2, y2 = fix_inputs(x2, y2)  # type: ignore
        if not check_inputs(
                y1, y2, label1, label2):
            return
        if len(x1) <= 1 or len(y1) <= 1 or len(y2) <= 1 or len(x2) <= 1:
            raise ValueError(
                f"The data for plot {filename} contains empty rows!")
    
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        # access colour
        prop_cycle = plt.rcParams["axes.prop_cycle"]
        try:
            linestyle = prop_cycle.by_key()["linestyle"]
        except KeyError:
            linestyle = rwth_cycle.by_key()["linestyle"]
        colors = prop_cycle.by_key()["color"]
    
        if max(color) >= len(colors):
            colors += colors
            linestyle += linestyle
    
        # first plot
        lines = ax1.plot(x1, y1, label=label1,
                         color=colors[color[0]],
                         linestyle=linestyle[0])
        ax1.set_xlabel(xlabel)
        ax1.set_ylabel(ylabel)
        ax1.set_ylim(
            min(y1) - (max(y1) - min(y1)) * 0.02,
            max(y1) + (max(y1) - min(y1)) * 0.02
        )
    
        # second plot
        ax2 = ax1.twinx()
        lines += ax2.plot(x2, y2, label=label2,
                          color=colors[color[1]],
                          linestyle=linestyle[1])
        ax2.set_ylabel(ylabel2)
        ax2.set_ylim(
            min(y2) - (max(y2) - min(y2)) * 0.02,
            max(y2) + (max(y2) - min(y2)) * 0.02
        )
    
        # general settings
        if xlim is None:
            plt.xlim(min(x1), max(x1))
        else:
            plt.xlim(*xlim)
        labels = [line.get_label() for line in lines]
        plt.legend(lines, labels)
        # ticks
        if ticks is not None:
            ax2.set_yticks(ticks[0])
            ax2.set_yticklabels(ticks[1])
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()
    
    
    def make_invisible(ax: plt.Axes) -> None:
        """Make all patch spines invisible."""
        ax.set_frame_on(True)
        ax.patch.set_visible(False)
        for spine in ax.spines.values():
            spine.set_visible(False)
    
    
    @apply_styles
    def three_axis_plots(x1: In, y1: In, label1: str,
                         x2: In, y2: In, label2: str,
                         x3: In, y3: In, label3: str,
                         xlabel: str, ylabel: str,
                         ylabel2: str, ylabel3: str,
                         filename: Union[Path, str], *,
                         ticks: Optional[Tuple[List[float], List[str]]] = None,
                         xlim: Optional[Tuple[float, float]] = None,
                         color: Tuple[int, int, int] = (0, 1, 2),
                         legend: bool = True)\
            -> None:
        """Create a simple 1D plot with two different graphs inside of a single
        plot with two y-axis.
        The variable "ticks" sets costum y-ticks on the second y-axis. The first
        argument gives the position of the ticks and the second argument gives the
        values to be shown.
        Color selects the indeces of the chosen color-wheel, which should be taken
        for the different plots. The default is (1,2)."""
        # pylint: disable=R0915
        x1, y1 = fix_inputs(x1, y1)  # type: ignore
        x2, y2 = fix_inputs(x2, y2)  # type: ignore
        x3, y3 = fix_inputs(x3, y3)  # type: ignore
        if not check_inputs(
                y1, y2, label1, label2):
            return
        if not check_inputs(
                x3, y3, xlabel, label3):
            return
    
        if len(x1) <= 1 or len(y1) <= 1 or len(y2) <= 1 or len(x2) <= 1:
            raise ValueError(
                f"The data for plot {filename} contains empty rows!")
        assert len(color) == 3
    
        fig, ax1 = plt.subplots()
        fig.subplots_adjust(right=0.75)
        # access colour
        prop_cycle = plt.rcParams["axes.prop_cycle"]
        try:
            linestyle = prop_cycle.by_key()["linestyle"]
        except KeyError:
            linestyle = rwth_cycle.by_key()["linestyle"]
        colors = prop_cycle.by_key()["color"]
    
        if max(color) >= len(colors):
            colors += colors
            linestyle += linestyle
    
        # first plot
        lines = ax1.plot(x1, y1, label=label1,
                         color=colors[color[0]],
                         linestyle=linestyle[0])
        ax1.set_xlabel(xlabel)
        ax1.set_ylabel(ylabel)
        ax1.set_ylim(
            min(y1) - (max(y1) - min(y1)) * 0.02,
            max(y1) + (max(y1) - min(y1)) * 0.02
        )
        ax1.yaxis.label.set_color(colors[color[0]])
        ax1.tick_params(axis="y", colors=colors[color[0]])
    
        # second plot
        ax2 = ax1.twinx()
        lines += ax2.plot(x2, y2, label=label2,
                          color=colors[color[1]],
                          linestyle=linestyle[1])
        ax2.set_ylabel(ylabel2)
        ax2.set_ylim(
            min(y2) - (max(y2) - min(y2)) * 0.02,
            max(y2) + (max(y2) - min(y2)) * 0.02
        )
        ax2.yaxis.label.set_color(colors[color[1]])
        ax2.tick_params(axis="y", colors=colors[color[1]])
    
        # third plot
        ax3 = ax1.twinx()
        make_invisible(ax3)
        ax3.spines["right"].set_position(("axes", 1.25))
        ax3.spines["right"].set_visible(True)
        lines += ax3.plot(x3, y3, label=label3,
                          color=colors[color[2]],
                          linestyle=linestyle[2])
        ax3.set_ylabel(ylabel3)
        ax3.set_ylim(
            min(y3) - (max(y3) - min(y3)) * 0.02,
            max(y3) + (max(y3) - min(y3)) * 0.02
        )
        ax3.yaxis.label.set_color(colors[color[2]])
        ax3.tick_params(axis="y", colors=colors[color[2]])
    
        # general settings
        if xlim is None:
            plt.xlim(min(x1), max(x1))
        else:
            plt.xlim(*xlim)
        labels = [line.get_label() for line in lines]
        if legend:
            plt.legend(lines, labels)
        # ticks
        if ticks is not None:
            ax2.set_yticks(ticks[0])
            ax2.set_yticklabels(ticks[1])
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()