# Source code for scqubits.utils.plot_utils

# plot_utils.py
#
# This file is part of scqubits: a Python package for superconducting qubits,
# Quantum 5, 583 (2021). https://quantum-journal.org/papers/q-2021-11-17-583/
#
#    Copyright (c) 2019 and later, Jens Koch and Peter Groszkowski
#    All rights reserved.
#
#    This source code is licensed under the BSD-style license found in the
#    LICENSE file in the root directory of this source tree.
############################################################################
import functools
import operator
import os

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib as mpl
import numpy as np

from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy import ndarray

from scqubits import settings as settings
from scqubits.settings import matplotlib_settings
from scqubits.utils import plot_defaults as defaults

if TYPE_CHECKING:
    from scqubits.core.storage import WaveFunction


# A dictionary of plotting options that are directly passed to specific matplotlib's
# plot commands.
_direct_plot_options = {
    "plot": (
        "alpha",
        "color",
        "linestyle",
        "linewidth",
        "marker",
        "markersize",
        "label",
    ),
    "imshow": ("interpolation",),
    "contourf": tuple(),  # empty for now
}


@mpl.rc_context(matplotlib_settings)
def _extract_kwargs_options(
    kwargs: Dict[str, Any],
    plot_type: str,
    direct_plot_options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Select options from kwargs for a given plot_type and return them in a dictionary.

    Parameters
    ----------
    kwargs:
        dictionary with options that can be passed to different plotting commands
    plot_type:
        a type of plot for which the options should be selected; looked up as a key
        of `direct_plot_options`
    direct_plot_options:
        a lookup dictionary mapping each plot_type to the option names it supports;
        when None, the module-level `_direct_plot_options` table is used

    Returns
    ----------
        dictionary with the key/value pairs from kwargs whose keys are supported for
        plot_type (in kwargs order); empty if plot_type is not in the lookup
    """
    # Test for None rather than truthiness: an explicitly passed empty lookup means
    # "no supported options", not "fall back to the module defaults".
    if direct_plot_options is None:
        direct_plot_options = _direct_plot_options
    if plot_type not in direct_plot_options:
        return {}

    supported_options = direct_plot_options[plot_type]
    return {key: value for key, value in kwargs.items() if key in supported_options}


@mpl.rc_context(matplotlib_settings)
def _process_options(
    figure: Figure, axes: Axes, opts: Optional[Dict[str, Any]] = None, **kwargs
) -> None:
    """
    Processes plotting options.

    Options from `opts` and `kwargs` are merged, with `kwargs` taking precedence on
    key clashes. Keys of `kwargs` that are forwarded directly to matplotlib plot
    commands (see `_direct_plot_options`) are skipped here. Each remaining key is
    either handled by `_process_special_option` (if listed in
    `defaults.SPECIAL_PLOT_OPTIONS`) or passed to the matching `axes.set_<key>`
    method. If `settings.DESPINE` is set and the axes are not 3d, the axes are
    despined. Finally, if `kwargs` contains a truthy `filename`, the figure is saved
    as a PDF (any extension on `filename` is replaced by ".pdf").

    Parameters
    ----------
    figure:
        matplotlib Figure that owns `axes`; used for special options and saving
    axes:
        matplotlib Axes to which the options are applied
    opts:
        keyword dictionary with custom options
    **kwargs:
        standard plotting option (see separate documentation)

    Raises
    ------
    AttributeError
        if a non-special option has no corresponding `axes.set_<key>` method
    """
    opts = opts or {}

    # Option names handled by _extract_kwargs_options(); built once instead of being
    # recomputed for every kwarg inside the comprehension's condition.
    direct_keys = {
        name for names in _direct_plot_options.values() for name in names
    }

    # Only process items in kwargs that would not have been
    # processed through _extract_kwargs_options()
    filtered_kwargs = {
        key: value for key, value in kwargs.items() if key not in direct_keys
    }

    option_dict = {**opts, **filtered_kwargs}

    for key, value in option_dict.items():
        if key in defaults.SPECIAL_PLOT_OPTIONS:
            _process_special_option(figure, axes, key, value)
        else:
            set_method = getattr(axes, f"set_{key}")
            set_method(value)

    # Despine before saving so the written file matches what is displayed.
    if settings.DESPINE and axes.name != "3d":
        despine_axes(axes)

    filename = kwargs.get("filename")
    if filename:
        figure.savefig(os.path.splitext(filename)[0] + ".pdf")


@mpl.rc_context(matplotlib_settings)
def _process_special_option(figure: Figure, axes: Axes, key: str, value: Any) -> None:
    """Apply one scqubits-internal ("special") plotting option that is not forwarded
    to matplotlib as-is.

    Handled keys:
      - "ymax": set the upper y-limit to `value`, and lower the current bottom limit
        by 5% of the span between that bottom and `value`
      - "figsize": resize `figure` to `value` inches
      - "grid": call `axes.grid` with `value` unpacked as keywords if it is a dict,
        otherwise with `value` as the single positional argument
    Any other key is silently ignored.
    """
    if key == "ymax":
        bottom, _ = axes.get_ylim()
        axes.set_ylim(bottom - (value - bottom) * 0.05, value)
        return

    if key == "figsize":
        figure.set_size_inches(value)
        return

    if key == "grid":
        if isinstance(value, dict):
            axes.grid(**value)
        else:
            axes.grid(value)


@mpl.rc_context(matplotlib_settings)
def despine_axes(axes: Axes) -> None:
    """
    Hide the right and top spines of `axes` and place ticks only on the left and
    bottom spines. Modifies `axes` in place.

    Parameters
    ----------
    axes:
        matplotlib Axes to modify
    """
    for side in ("right", "top"):
        axes.spines[side].set_visible(False)
    # Ticks only along the spines that remain visible.
    axes.yaxis.set_ticks_position("left")
    axes.xaxis.set_ticks_position("bottom")
@mpl.rc_context(matplotlib_settings)
def scale_wavefunctions(
    wavefunc_list: List["WaveFunction"],
    potential_vals: np.ndarray,
    scaling: Optional[float],
) -> List["WaveFunction"]:
    """
    Rescale a list of wave functions for plotting alongside a potential.

    Two passes are applied to every wave function:
      1. `rescale` by the largest of the individual
         `amplitude_scale_factor(potential_vals)` values, so all share one factor;
      2. `rescale` by `scaling`, or — if `scaling` is falsy (None or 0) — by the
         factor returned by `defaults.set_wavefunction_scaling(wavefunc_list,
         potential_vals)`.

    Parameters
    ----------
    wavefunc_list:
        wave functions to rescale; presumably `WaveFunction.rescale` modifies each
        object in place (TODO confirm against scqubits.core.storage)
    potential_vals:
        potential values passed to the scale-factor computations
    scaling:
        explicit second-pass scale factor; falsy values select the adaptive default

    Returns
    ----------
        the same list object that was passed in
    """
    if wavefunc_list:
        # All scale factors are computed before any rescaling takes place; the
        # common (maximum) factor is loop-invariant, so compute it once.
        scale_factors = np.array(
            [wavefunc.amplitude_scale_factor(potential_vals) for wavefunc in wavefunc_list]
        )
        common_factor = np.max(scale_factors)
        for wavefunc in wavefunc_list:
            wavefunc.rescale(common_factor)

    adaptive_scalefactor = scaling or defaults.set_wavefunction_scaling(
        wavefunc_list, potential_vals
    )
    for wavefunc in wavefunc_list:
        wavefunc.rescale(adaptive_scalefactor)
    return wavefunc_list
@mpl.rc_context(matplotlib_settings)
def plot_wavefunction_to_axes(
    axes: Axes, wavefunction: "WaveFunction", energy_offset: float, **kwargs
) -> None:
    """
    Draw a wave function shifted vertically by `energy_offset`, and shade the area
    between the curve and the horizontal baseline at `energy_offset`.

    Parameters
    ----------
    axes:
        matplotlib Axes to draw on
    wavefunction:
        object providing `basis_labels` (x values) and `amplitudes` (y values)
    energy_offset:
        vertical shift applied to the amplitudes; also the baseline of the shading
    **kwargs:
        plotting options; only those supported by the "plot" command are applied,
        and only to the curve (not to the shaded area)
    """
    basis_labels = wavefunction.basis_labels
    shifted_amplitudes = energy_offset + wavefunction.amplitudes
    baseline = [energy_offset] * len(basis_labels)

    line_options = _extract_kwargs_options(kwargs, "plot")
    axes.plot(basis_labels, shifted_amplitudes, **line_options)

    # Shade only where the curve departs from the baseline.
    axes.fill_between(
        basis_labels,
        shifted_amplitudes,
        baseline,
        where=(shifted_amplitudes != baseline),
        interpolate=True,
    )
@mpl.rc_context(matplotlib_settings)
def plot_potential_to_axes(
    axes: Axes,
    x_vals: ndarray,
    potential_vals: Union[ndarray, List[float]],
    offset_list: Union[ndarray, List[float]],
    **kwargs,
) -> None:
    """
    Draw the potential as a gray curve and set the y-limits so that both the
    potential and all energy offsets fit.

    With `y_range = max(offset_list) - min(potential_vals)`, the y-limits are
    `[min(potential_vals) - 0.1 * y_range, max(offset_list) + 0.3 * y_range]`.

    Parameters
    ----------
    axes:
        matplotlib Axes to draw on
    x_vals:
        x coordinates of the potential
    potential_vals:
        potential values at `x_vals`
    offset_list:
        energy offsets that must be visible in the y-range
    **kwargs:
        plotting options; those supported by the "plot" command are applied to the
        curve, except "color", which is always gray for the potential
    """
    potential_min = np.min(potential_vals)
    offset_max = np.max(offset_list)
    y_range = offset_max - potential_min
    axes.set_ylim([potential_min - 0.1 * y_range, offset_max + 0.3 * y_range])

    # Merge into a single dict: passing color="gray" alongside **options that also
    # contain "color" would raise "got multiple values for keyword argument". The
    # potential keeps its gray color (kwargs may be shared with other curves).
    plot_options = {**_extract_kwargs_options(kwargs, "plot"), "color": "gray"}
    axes.plot(x_vals, potential_vals, **plot_options)
@mpl.rc_context(matplotlib_settings)
def add_numbers_to_axes(
    axes: Axes, matrix: ndarray, modefunc: Callable, fontsize: int = 8
) -> None:
    """
    Annotate every cell of a 2d matrix plot with its value.

    Each entry `matrix[row, col]` is passed through `modefunc` and written, with three
    decimals, at position (x=col, y=row) in white, rotated by 45 degrees and centered.

    Parameters
    ----------
    axes:
        matplotlib Axes to draw the text on
    matrix:
        array indexed as matrix[row, col]; only its first two dimensions are iterated
    modefunc:
        function converting a matrix entry into a number to display
    fontsize:
        font size of the annotations
    """
    # Row-major traversal: rows in the outer position, columns varying fastest.
    for row, col in np.ndindex(matrix.shape[0], matrix.shape[1]):
        label = f"{modefunc(matrix[row, col]):.03f}"
        axes.text(
            col,
            row,
            label,
            va="center",
            ha="center",
            fontsize=fontsize,
            rotation=45,
            color="white",
        )
@mpl.rc_context(matplotlib_settings)
def color_normalize(vals, mode: str) -> Tuple[float, float, mpl.colors.Normalize]:
    """
    Build a linear color normalization spanning the values in `vals`.

    Parameters
    ----------
    vals:
        iterable of numbers; iterated twice (once by `min`, once by `max`)
    mode:
        for "abs" and "abs_sqr" the lower end of the scale is pinned to 0

    Returns
    ----------
        (minval, maxval, Normalize(minval, maxval))
    """
    lower, upper = min(vals), max(vals)
    if mode in ("abs", "abs_sqr"):
        lower = 0
    return lower, upper, mpl.colors.Normalize(lower, upper)